 from graph_net.torch import utils
-import argparse
 import importlib.util
-import inspect
 import shutil
 import torch
-import logging
-from pathlib import Path
-from typing import Type, Any
-import sys
-import json
-import base64
-from contextlib import contextmanager
-
+from typing import Type
 from torch.profiler import profile, record_function, ProfilerActivity
 
 
@@ -34,7 +25,9 @@ def __call__(self, model_path=None):
         params = inputs_params["weight_info"]
         state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
 
-        compiled_num_of_kernels = compile_and_count_kernels(model, state_dict)
+        compiled_model = torch.compile(model)
+        compiled_model(**state_dict)  # warm up so compilation runs before profiling
+        compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
         if compiled_num_of_kernels == 1:
             print(model_path, "can be fully integrated")
             return True
@@ -52,12 +45,12 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
     return model_class
 
 
-def compile_and_count_kernels(gm, sample_inputs) -> int:
+def count_kernels(model, sample_inputs) -> int:
5649 """
5750 Count the number of CUDA kernel launches performed during a model's forward pass.
5851
5952 Args:
60- gm (graph models)
53+ model (graph models)
6154 sample_inputs(tensors)
6255
6356 Returns:
@@ -68,21 +61,19 @@ def compile_and_count_kernels(gm, sample_inputs) -> int:
         - Identifies events with key 'cudaLaunchKernel' or 'cuLaunchKernel', which
           correspond to CUDA kernel launches.
     """
-    gm.eval()
+    model.eval()
     # Use PyTorch Profiler
-    compiled_gm = torch.compile(gm)
-    _ = compiled_gm(**sample_inputs)
 
     with profile(
         activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
         record_shapes=True,
     ) as prof:
         with record_function("model_inference"):
-            output = compiled_gm(**sample_inputs)
+            output = model(**sample_inputs)
     events = prof.key_averages()
-    if_compile_work = any(e.key == "TorchDynamo Cache Lookup" for e in events)
-    if not if_compile_work:
-        return -1
+
+    total_count = 0
     for e in events:
-        if e.key == "cuLaunchKernel":
-            return e.count
+        if e.key in ("cuLaunchKernel", "cudaLaunchKernel"):  # driver- or runtime-API launch
+            total_count += e.count
+    return total_count
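
For context, here is a minimal sketch of how the refactored call site and the renamed `count_kernels` are intended to fit together. The toy module, tensor shapes, and the CUDA device are illustrative assumptions, not part of the patch:

```python
import torch

class ToyModel(torch.nn.Module):
    # Hypothetical stand-in for the model class loaded from model_path.
    def forward(self, x):
        return torch.relu(x) + 1

model = ToyModel().cuda()
sample_inputs = {"x": torch.randn(8, 8, device="cuda")}

# Mirror the new call-site flow: compile first, run once so Dynamo/Inductor
# compilation happens outside the profiled region, then count the
# steady-state kernel launches.
compiled_model = torch.compile(model)
compiled_model(**sample_inputs)  # warm-up
print(count_kernels(compiled_model, sample_inputs))
```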
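
The counting loop matches both `cuLaunchKernel` (CUDA driver API) and `cudaLaunchKernel` (CUDA runtime API), since launches can be reported under either name depending on the build and backend. A quick way to see which keys a given environment emits, assuming `compiled_model` and `sample_inputs` from the sketch above:

```python
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    compiled_model(**sample_inputs)

# Print every aggregated event whose key looks like a kernel launch.
for e in prof.key_averages():
    if "LaunchKernel" in e.key:
        print(e.key, e.count)
```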