1919 torch_tensorrt = None
2020
2121
22+ class GraphCompilerBackend :
23+ def __call__ (self , model ):
24+ raise NotImplementedError ()
25+
26+ def synchronize (self ):
27+ raise NotImplementedError ()
28+
29+
30+ class InductorBackend (GraphCompilerBackend ):
31+ def __call__ (self , model ):
32+ return torch .compile (model , backend = "inductor" )
33+
34+ def synchronize (self ):
35+ torch .cuda .synchronize ()
36+
37+
38+ class TensorRTBackend (GraphCompilerBackend ):
39+ def __call__ (self , model ):
40+ return torch .compile (model , backend = "tensorrt" )
41+
42+ def synchronize (self ):
43+ torch .cuda .synchronize ()
44+
45+
46+ class DefaultBackend (InductorBackend ):
47+ pass
48+
49+
2250registry_backend = {
23- "inductor" : {
24- "compiler" : torch .compile ,
25- "backend" : "inductor" ,
26- "synchronizer" : torch .cuda .synchronize ,
27- },
28- "tensorrt" : {
29- "compiler" : torch .compile ,
30- "backend" : "tensorrt" ,
31- "synchronizer" : torch .cuda .synchronize ,
32- },
33- "default" : {
34- "compiler" : torch .compile ,
35- "backend" : "inductor" ,
36- "synchronizer" : torch .cuda .synchronize ,
37- },
51+ "inductor" : InductorBackend (),
52+ "tensorrt" : TensorRTBackend (),
53+ "default" : DefaultBackend (),
3854}
3955
4056
@@ -61,19 +77,9 @@ def load_class_from_file(
6177 return model_class
6278
6379
64- def get_compiler (args ):
65- assert args .compiler in registry_backend , f"Unknown compiler: { args .compiler } "
66- return registry_backend [args .compiler ]["compiler" ]
67-
68-
69- def get_backend (args ):
70- assert args .compiler in registry_backend , f"Unknown compiler: { args .compiler } "
71- return registry_backend [args .compiler ]["backend" ]
72-
73-
74- def get_synchronizer_func (args ):
80+ def get_compiler_backend (args ) -> GraphCompilerBackend :
7581 assert args .compiler in registry_backend , f"Unknown compiler: { args .compiler } "
76- return registry_backend [args .compiler ][ "synchronizer" ]
82+ return registry_backend [args .compiler ]
7783
7884
7985def get_model (args ):
@@ -106,16 +112,14 @@ def naive_timer(duration_box, get_synchronizer_func):
106112
107113
108114def test_single_model (args ):
109- compiler = get_compiler (args )
110- backend = get_backend (args )
111- synchronizer_func = get_synchronizer_func (args )
115+ compiler = get_compiler_backend (args )
112116 input_dict = get_input_dict (args )
113117 model = get_model (args )
114- compiled_model = compiler (model , backend = backend )
118+ compiled_model = compiler (model )
115119
116120 # eager
117121 eager_duration_box = DurationBox (- 1 )
118- with naive_timer (eager_duration_box , synchronizer_func ):
122+ with naive_timer (eager_duration_box , compiler . synchronize ):
119123 expected_out = model (** input_dict )
120124
121125 # warmup
@@ -124,7 +128,7 @@ def test_single_model(args):
124128
125129 # compiled
126130 compiled_duration_box = DurationBox (- 1 )
127- with naive_timer (compiled_duration_box , synchronizer_func ):
131+ with naive_timer (compiled_duration_box , compiler . synchronize ):
128132 compiled_out = compiled_model (** input_dict )
129133
130134 def print_cmp (key , func , ** kwargs ):
0 commit comments