Skip to content

Commit b8e683c

Browse files
committed
fix code
1 parent 302edeb commit b8e683c

File tree

1 file changed

+37
-33
lines changed

1 file changed

+37
-33
lines changed

graph_net/torch/test_compiler.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,38 @@
1919
torch_tensorrt = None
2020

2121

22+
class GraphCompilerBackend:
23+
def __call__(self, model):
24+
raise NotImplementedError()
25+
26+
def synchronize(self):
27+
raise NotImplementedError()
28+
29+
30+
class InductorBackend(GraphCompilerBackend):
31+
def __call__(self, model):
32+
return torch.compile(model, backend="inductor")
33+
34+
def synchronize(self):
35+
torch.cuda.synchronize()
36+
37+
38+
class TensorRTBackend(GraphCompilerBackend):
39+
def __call__(self, model):
40+
return torch.compile(model, backend="tensorrt")
41+
42+
def synchronize(self):
43+
torch.cuda.synchronize()
44+
45+
46+
class DefaultBackend(InductorBackend):
47+
pass
48+
49+
2250
registry_backend = {
23-
"inductor": {
24-
"compiler": torch.compile,
25-
"backend": "inductor",
26-
"synchronizer": torch.cuda.synchronize,
27-
},
28-
"tensorrt": {
29-
"compiler": torch.compile,
30-
"backend": "tensorrt",
31-
"synchronizer": torch.cuda.synchronize,
32-
},
33-
"default": {
34-
"compiler": torch.compile,
35-
"backend": "inductor",
36-
"synchronizer": torch.cuda.synchronize,
37-
},
51+
"inductor": InductorBackend(),
52+
"tensorrt": TensorRTBackend(),
53+
"default": DefaultBackend(),
3854
}
3955

4056

@@ -61,19 +77,9 @@ def load_class_from_file(
6177
return model_class
6278

6379

64-
def get_compiler(args):
65-
assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
66-
return registry_backend[args.compiler]["compiler"]
67-
68-
69-
def get_backend(args):
70-
assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
71-
return registry_backend[args.compiler]["backend"]
72-
73-
74-
def get_synchronizer_func(args):
80+
def get_compiler_backend(args) -> GraphCompilerBackend:
7581
assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
76-
return registry_backend[args.compiler]["synchronizer"]
82+
return registry_backend[args.compiler]
7783

7884

7985
def get_model(args):
@@ -106,16 +112,14 @@ def naive_timer(duration_box, get_synchronizer_func):
106112

107113

108114
def test_single_model(args):
109-
compiler = get_compiler(args)
110-
backend = get_backend(args)
111-
synchronizer_func = get_synchronizer_func(args)
115+
compiler = get_compiler_backend(args)
112116
input_dict = get_input_dict(args)
113117
model = get_model(args)
114-
compiled_model = compiler(model, backend=backend)
118+
compiled_model = compiler(model)
115119

116120
# eager
117121
eager_duration_box = DurationBox(-1)
118-
with naive_timer(eager_duration_box, synchronizer_func):
122+
with naive_timer(eager_duration_box, compiler.synchronize):
119123
expected_out = model(**input_dict)
120124

121125
# warmup
@@ -124,7 +128,7 @@ def test_single_model(args):
124128

125129
# compiled
126130
compiled_duration_box = DurationBox(-1)
127-
with naive_timer(compiled_duration_box, synchronizer_func):
131+
with naive_timer(compiled_duration_box, compiler.synchronize):
128132
compiled_out = compiled_model(**input_dict)
129133

130134
def print_cmp(key, func, **kwargs):

0 commit comments

Comments
 (0)