1313from contextlib import contextmanager
1414import time
1515
16+ try :
17+ import torch_tensorrt
18+ except ImportError :
19+ torch_tensorrt = None
1620
17- def load_class_from_file (file_path : str , class_name : str ) -> Type [torch .nn .Module ]:
21+
22+ class GraphCompilerBackend :
23+ def __call__ (self , model ):
24+ raise NotImplementedError ()
25+
26+ def synchronize (self ):
27+ raise NotImplementedError ()
28+
29+
30+ class InductorBackend (GraphCompilerBackend ):
31+ def __call__ (self , model ):
32+ return torch .compile (model , backend = "inductor" )
33+
34+ def synchronize (self ):
35+ torch .cuda .synchronize ()
36+
37+
38+ class TensorRTBackend (GraphCompilerBackend ):
39+ def __call__ (self , model ):
40+ return torch .compile (model , backend = "tensorrt" )
41+
42+ def synchronize (self ):
43+ torch .cuda .synchronize ()
44+
45+
46+ registry_backend = {
47+ "inductor" : InductorBackend (),
48+ "tensorrt" : TensorRTBackend (),
49+ "default" : InductorBackend (),
50+ }
51+
52+
53+ def load_class_from_file (
54+ args : argparse .Namespace , class_name : str
55+ ) -> Type [torch .nn .Module ]:
56+ file_path = f"{ args .model_path } /model.py"
1857 file = Path (file_path ).resolve ()
1958 module_name = file .stem
2059
2160 with open (file_path , "r" , encoding = "utf-8" ) as f :
2261 original_code = f .read ()
23- import_stmt = "import torch"
24- modified_code = f"{ import_stmt } \n { original_code } "
62+ if args .device == "cuda" :
63+ modified_code = original_code .replace ("cpu" , "cuda" )
64+ else :
65+ modified_code = original_code
2566 spec = importlib .util .spec_from_loader (module_name , loader = None )
2667 module = importlib .util .module_from_spec (spec )
2768 sys .modules [module_name ] = module
@@ -32,27 +73,23 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
3273 return model_class
3374
3475
35- def get_compiler (args ):
36- assert args .compiler == "default"
37- return torch .compile
38-
39-
40- def get_synchronizer_func (args ):
41- assert args .compiler == "default"
42- return torch .cuda .synchronize
76+ def get_compiler_backend (args ) -> GraphCompilerBackend :
77+ assert args .compiler in registry_backend , f"Unknown compiler: { args .compiler } "
78+ return registry_backend [args .compiler ]
4379
4480
4581def get_model (args ):
46- model_class = load_class_from_file (
47- f"{ args .model_path } /model.py" , class_name = "GraphModule"
48- )
49- return model_class ()
82+ model_class = load_class_from_file (args , class_name = "GraphModule" )
83+ return model_class ().to (torch .device (args .device ))
5084
5185
5286def get_input_dict (args ):
5387 inputs_params = utils .load_converted_from_text (f"{ args .model_path } " )
5488 params = inputs_params ["weight_info" ]
55- return {k : utils .replay_tensor (v ) for k , v in params .items ()}
89+ return {
90+ k : utils .replay_tensor (v ).to (torch .device (args .device ))
91+ for k , v in params .items ()
92+ }
5693
5794
5895@dataclass
@@ -71,15 +108,14 @@ def naive_timer(duration_box, get_synchronizer_func):
71108
72109
73110def test_single_model (args ):
74- compiler = get_compiler (args )
75- synchronizer_func = get_synchronizer_func (args )
111+ compiler = get_compiler_backend (args )
76112 input_dict = get_input_dict (args )
77113 model = get_model (args )
78114 compiled_model = compiler (model )
79115
80116 # eager
81117 eager_duration_box = DurationBox (- 1 )
82- with naive_timer (eager_duration_box , synchronizer_func ):
118+ with naive_timer (eager_duration_box , compiler . synchronize ):
83119 expected_out = model (** input_dict )
84120
85121 # warmup
@@ -88,7 +124,7 @@ def test_single_model(args):
88124
89125 # compiled
90126 compiled_duration_box = DurationBox (- 1 )
91- with naive_timer (compiled_duration_box , synchronizer_func ):
127+ with naive_timer (compiled_duration_box , compiler . synchronize ):
92128 compiled_out = compiled_model (** input_dict )
93129
94130 def print_cmp (key , func , ** kwargs ):
@@ -157,11 +193,11 @@ def test_multi_models(args):
157193 cmd = "" .join (
158194 [
159195 sys .executable ,
160- "-m graph_net.torch.test_compiler" ,
161- f"--model-path { model_path } " ,
162- f"--compiler { args .compiler } " ,
163- f"--warmup { args .warmup } " ,
164- f"--log-prompt { args .log_prompt } " ,
196+ " -m graph_net.torch.test_compiler" ,
197+ f" --model-path { model_path } " ,
198+ f" --compiler { args .compiler } " ,
199+ f" --warmup { args .warmup } " ,
200+ f" --log-prompt { args .log_prompt } " ,
165201 ]
166202 )
167203 cmd_ret = os .system (cmd )
@@ -212,6 +248,13 @@ def main(args):
212248 default = "default" ,
213249 help = "Path to customized compiler python file" ,
214250 )
251+ parser .add_argument (
252+ "--device" ,
253+ type = str ,
254+ required = False ,
255+ default = "cpu" ,
256+ help = "Device for testing the compiler" ,
257+ )
215258 parser .add_argument (
216259 "--warmup" , type = int , required = False , default = 5 , help = "Number of warmup steps"
217260 )
0 commit comments