1414import json
1515import numpy as np
1616import platform
17-
18- try :
19- import torch_tensorrt
20- except ImportError :
21- torch_tensorrt = None
22-
23- try :
24- import torch_blade
25- except ImportError :
26- torch_blade = None
27-
28-
29- class GraphCompilerBackend :
30- def __call__ (self , model ):
31- raise NotImplementedError ()
32-
33- def synchronize (self ):
34- raise NotImplementedError ()
35-
36-
class InductorBackend(GraphCompilerBackend):
    """Backend that compiles models via TorchInductor (``torch.compile``)."""

    def __call__(self, model):
        """Wrap *model* for lazy TorchInductor compilation."""
        compiled = torch.compile(model, backend="inductor")
        return compiled

    def synchronize(self):
        """Wait for queued CUDA work, if a GPU is present."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
44-
45-
class TensorRTBackend(GraphCompilerBackend):
    """Backend that compiles models via the TensorRT ``torch.compile`` backend."""

    def __call__(self, model):
        """Wrap *model* for lazy TensorRT compilation."""
        compiled = torch.compile(model, backend="tensorrt")
        return compiled

    def synchronize(self):
        # TensorRT implies CUDA, so synchronize unconditionally.
        torch.cuda.synchronize()
17+ from .graph_compiler_backend import (
18+ GraphCompilerBackend ,
19+ InductorBackend ,
20+ TensorRTBackend ,
21+ )
22+ from .blade_disc_backend import BladeDISCBackend
5223
5324
5425def load_class_from_file (
@@ -70,9 +41,25 @@ def load_class_from_file(
7041 return model_class
7142
7243
# Maps the --compiler CLI choice to the backend class implementing it.
registry_backend_classes = dict(
    inductor=InductorBackend,
    tensorrt=TensorRTBackend,
    bladedisc=BladeDISCBackend,
)
49+
50+
def get_compiler_backend(args) -> GraphCompilerBackend:
    """Instantiate the compiler backend selected by ``args.compiler``.

    Args:
        args: Parsed CLI namespace; must expose ``compiler`` and, for the
            bladedisc backend, whatever ``get_input_dict`` reads from it.

    Returns:
        A fresh ``GraphCompilerBackend`` instance for the chosen compiler.
    """
    assert (
        args.compiler in registry_backend_classes
    ), f"Unknown compiler: {args.compiler}"
    cls = registry_backend_classes[args.compiler]
    # The original if/elif chain compared classes with ``==`` and fell
    # through to an implicit ``return None`` for any class registered but
    # not listed; instantiating from the registry fixes both problems.
    if cls is BladeDISCBackend:
        # BladeDISC needs example inputs to trace the model.
        return cls(get_input_dict(args))
    return cls()
7663
7764
7865def get_model (args ):
@@ -92,33 +79,6 @@ def get_input_dict(args):
9279 }
9380
9481
class BladeDISCBackend(GraphCompilerBackend):
    """Backend that optimizes models with Alibaba BladeDISC (``torch_blade``)."""

    def __init__(self, input_dict=None):
        # Example inputs used to trace the model during optimization.
        self.input_dict = input_dict

    def __call__(self, model):
        """Optimize *model* with ``torch_blade`` using the stored example inputs."""
        torch_config = torch_blade.config.Config()
        torch_config.enable_mlir_amp = False  # keep full precision in MLIR passes
        with torch.no_grad(), torch_config:
            # BUG FIX: the original read the free variable ``args`` here
            # (a NameError unless a matching global happened to exist) and
            # silently ignored the ``input_dict`` passed to ``__init__``.
            dummy_input = tuple(self.input_dict.values())
            compiled_model = torch_blade.optimize(
                model, allow_tracing=True, model_inputs=dummy_input
            )
            return compiled_model

    def synchronize(self):
        """Wait for queued CUDA work, if a GPU is present."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
113-
114-
# Maps the --compiler CLI choice to a ready-made backend instance.
# NOTE: instances are built eagerly at import time, so BladeDISCBackend
# is created here with its default (None) input_dict.
registry_backend = dict(
    inductor=InductorBackend(),
    tensorrt=TensorRTBackend(),
    bladedisc=BladeDISCBackend(),
)
120-
121-
@dataclass
class DurationBox:
    """Mutable holder for a measured duration.

    Wrapping the float lets callers pass the box around and update the
    timing in place. Units are not shown here — presumably seconds;
    confirm at the call sites that write ``value``.
    """

    value: float
0 commit comments