1212from dataclasses import dataclass
1313from contextlib import contextmanager
1414import time
15- import torch_tensorrt
1615
17-
18- def load_class_from_file (file_path : str , class_name : str ) -> Type [torch .nn .Module ]:
16+ try :
17+ import torch_tensorrt
18+ except ImportError :
19+ torch_tensorrt = None
20+
21+
# Lookup table keyed by the value of ``args.compiler``. Each entry bundles the
# compile entry point, the backend name, and the function used to synchronize
# the device. "default" carries the same settings as "inductor".
registry_backend = {
    compiler_name: {
        "compiler": torch.compile,
        "backend": backend_name,
        "synchronizer": torch.cuda.synchronize,
    }
    for compiler_name, backend_name in (
        ("inductor", "inductor"),
        ("tensorrt", "tensorrt"),
        ("default", "inductor"),
    )
}
39+
40+
41+ def load_class_from_file (
42+ args : argparse .Namespace , class_name : str
43+ ) -> Type [torch .nn .Module ]:
44+ file_path = f"{ args .model_path } /model.py"
1945 file = Path (file_path ).resolve ()
2046 module_name = file .stem
2147
2248 with open (file_path , "r" , encoding = "utf-8" ) as f :
2349 original_code = f .read ()
24- if torch . cuda . is_available () :
50+ if args . device == " cuda" :
2551 modified_code = original_code .replace ("cpu" , "cuda" )
2652 else :
2753 modified_code = original_code
@@ -36,38 +62,32 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
3662
3763
def get_compiler(args):
    """Return the compile entry point registered for ``args.compiler``.

    Args:
        args: Parsed command-line namespace; only ``args.compiler`` is read.

    Returns:
        The ``"compiler"`` entry of the matching ``registry_backend`` record.

    Raises:
        ValueError: If ``args.compiler`` is not a key of ``registry_backend``.
    """
    # Explicit raise instead of ``assert``: assertions are removed under
    # ``python -O``, which would turn a bad flag into an opaque KeyError.
    if args.compiler not in registry_backend:
        raise ValueError(f"Unknown compiler: {args.compiler}")
    return registry_backend[args.compiler]["compiler"]
4467
4568
def get_backend(args):
    """Return the backend name registered for ``args.compiler``.

    Args:
        args: Parsed command-line namespace; only ``args.compiler`` is read.

    Returns:
        The ``"backend"`` entry of the matching ``registry_backend`` record.

    Raises:
        ValueError: If ``args.compiler`` is not a key of ``registry_backend``.
    """
    # Explicit raise instead of ``assert``: assertions are removed under
    # ``python -O``, which would turn a bad flag into an opaque KeyError.
    if args.compiler not in registry_backend:
        raise ValueError(f"Unknown compiler: {args.compiler}")
    return registry_backend[args.compiler]["backend"]
5272
5373
def get_synchronizer_func(args):
    """Return the synchronization function registered for ``args.compiler``.

    Args:
        args: Parsed command-line namespace; only ``args.compiler`` is read.

    Returns:
        The ``"synchronizer"`` entry of the matching ``registry_backend``
        record.

    Raises:
        ValueError: If ``args.compiler`` is not a key of ``registry_backend``.
    """
    # Explicit raise instead of ``assert``: assertions are removed under
    # ``python -O``, which would turn a bad flag into an opaque KeyError.
    if args.compiler not in registry_backend:
        raise ValueError(f"Unknown compiler: {args.compiler}")
    return registry_backend[args.compiler]["synchronizer"]
5777
5878
def get_model(args):
    """Build the ``GraphModule`` model for ``args`` and place it on ``args.device``.

    The class is obtained through ``load_class_from_file``, instantiated with
    no arguments, and moved to the device named by ``args.device``.
    """
    graph_module_cls = load_class_from_file(args, class_name="GraphModule")
    model = graph_module_cls()
    target_device = torch.device(args.device)
    return model.to(target_device)
6582
6683
def get_input_dict(args):
    """Build the model's input dictionary, with every value on ``args.device``.

    Reads the converted data for ``args.model_path`` via
    ``utils.load_converted_from_text`` and takes its ``"weight_info"`` mapping.
    Each value is passed through ``utils.replay_tensor`` and the result is
    moved to ``args.device``.
    """
    converted = utils.load_converted_from_text(f"{args.model_path}")
    weight_info = converted["weight_info"]

    input_dict = {}
    for name, spec in weight_info.items():
        # presumably replay_tensor yields a torch tensor — confirm in utils
        replayed = utils.replay_tensor(spec)
        input_dict[name] = replayed.to(torch.device(args.device))
    return input_dict
7191
7292
7393@dataclass
@@ -228,6 +248,13 @@ def main(args):
228248 default = "default" ,
229249 help = "Path to customized compiler python file" ,
230250 )
251+ parser .add_argument (
252+ "--device" ,
253+ type = str ,
254+ required = False ,
255+ default = "cpu" ,
256+ help = "Device for testing the compiler" ,
257+ )
231258 parser .add_argument (
232259 "--warmup" , type = int , required = False , default = 5 , help = "Number of warmup steps"
233260 )
0 commit comments