@@ -20,36 +20,8 @@ def __init__(self, config):
2020 self .config = config
2121
def __call__(self, model_path=None):
    """Compile the extracted model and report whether it fuses to one kernel.

    Returns True when the compiled graph runs as a single kernel; otherwise
    deletes the model folder and returns False. Returns False immediately
    when no model path is given.
    """
    # Guard: nothing to do without a model path.
    if model_path is None:
        return False

    # Load the GraphModule class from the extracted model folder.
    model_class = load_class_from_file(
        f"{model_path}/model.py", class_name="GraphModule"
    )
    model = model_class()
    print(f"{model_path = }")

    # Rebuild the state dict from the serialized weight descriptions.
    inputs_params = utils.load_converted_from_text(f"{model_path}")
    params = inputs_params["weight_info"]
    state_dict = {name: utils.replay_tensor(spec) for name, spec in params.items()}

    compiled_num_of_kernels = compile_and_count_kernels(model, state_dict)

    # A single launched kernel means the graph is fully fusable/integrable.
    if compiled_num_of_kernels == 1:
        print(model_path, "can be fully integrated ")
        return True

    # Not fully integrable: discard the extracted model folder.
    print(model_path, "can not be fully integrated ")
    shutil.rmtree(model_path)
    return False
7645
7746
def _convert_to_dict(config_str):
    """Decode a base64-encoded JSON object into a dict.

    Returns an empty dict when *config_str* is None. Asserts that the
    decoded JSON payload is a JSON object (dict), not a scalar or list.
    """
    if config_str is None:
        return {}
    decoded_json = base64.b64decode(config_str).decode("utf-8")
    parsed = json.loads(decoded_json)
    assert isinstance(parsed, dict), f"config should be a dict. {config_str = }"
    return parsed
85-
86-
def _get_decorator(args):
    """Build a model-wrapping decorator from the parsed CLI args.

    Falls back to the identity function when no decorator config is given
    or when the config carries no "decorator_path" entry. Otherwise loads
    the decorator class from that path and instantiates it with the nested
    "decorator_config" mapping (empty dict by default).
    """
    identity = lambda model: model
    if args.decorator_config is None:
        return identity

    cfg = _convert_to_dict(args.decorator_config)
    if "decorator_path" not in cfg:
        return identity

    decorator_class = load_class_from_file(
        cfg["decorator_path"],
        class_name=cfg.get("decorator_class_name", "RunModelDecorator"),
    )
    return decorator_class(cfg.get("decorator_config", {}))
99-
100-
10147def load_class_from_file (file_path : str , class_name : str ) -> Type [torch .nn .Module ]:
10248 spec = importlib .util .spec_from_file_location ("unnamed" , file_path )
10349 unnamed = importlib .util .module_from_spec (spec )
@@ -133,11 +79,9 @@ def compile_and_count_kernels(gm, sample_inputs) -> int:
13379 ) as prof :
13480 with record_function ("model_inference" ):
13581 output = compiled_gm (** sample_inputs )
136- print (prof .key_averages ().table ()) # print a table of profiler result
13782 events = prof .key_averages ()
13883 if_compile_work = any (e .key == "TorchDynamo Cache Lookup" for e in events )
13984 if not if_compile_work :
140- print ("Compile failed" )
14185 return - 1
14286 for e in events :
14387 if e .key == "cuLaunchKernel" :
0 commit comments