Skip to content

Commit 51b2302

Browse files
committed
Fix test_compiler
1 parent d790478 commit 51b2302

File tree

4 files changed

+82
-65
lines changed

4 files changed

+82
-65
lines changed

docs/BladeDISC_tech_report.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ The process of compiling and optimizing with `torch.jit.trace` or `torch.jit.scr
1515
```shell
1616
# allow_tracing=True using torch.jit.trace(model, inputs)
1717
compiled_model = torch_blade.optimize(model, allow_tracing=True, model_inputs=tuple(inputs))
18-
# allow_tracing=False using torch.jit.script(model) 在本例中的尝试
18+
# allow_tracing=False using torch.jit.script(model)
1919
compiled_model = torch_blade.optimize(model, allow_tracing=False)
2020
```
2121

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
from .graph_compiler_backend import GraphCompilerBackend
3+
4+
try:
5+
import torch_blade
6+
except ImportError:
7+
torch_blade = None
8+
9+
10+
class BladeDISCBackend(GraphCompilerBackend):
    """Compile a model with BladeDISC via `torch_blade.optimize`.

    Unlike the inductor/tensorrt backends, BladeDISC traces the model
    (`allow_tracing=True`), so concrete example inputs must be supplied
    at construction time.
    """

    def __init__(self, input_dict):
        # input_dict: mapping of input name -> example tensor; its values
        # (in insertion order) become the tracing inputs. Presumably built
        # by the caller's get_input_dict() — confirm ordering matches the
        # model's positional signature.
        self.input_dict = input_dict

    def __call__(self, model):
        """Return `model` optimized by BladeDISC.

        Raises:
            RuntimeError: if `torch_blade` could not be imported (the
                module-level try/except sets it to None in that case).
        """
        if torch_blade is None:
            # Fail with a clear message instead of an opaque
            # AttributeError on None below.
            raise RuntimeError(
                "torch_blade is not installed; the 'bladedisc' backend is unavailable"
            )
        torch_config = torch_blade.config.Config()
        # Keep full precision: disable automatic mixed precision in MLIR.
        torch_config.enable_mlir_amp = False
        with torch.no_grad(), torch_config:
            dummy_input = tuple(self.input_dict.values())
            # allow_tracing=True -> optimize via torch.jit.trace(model, dummy_input)
            compiled_model = torch_blade.optimize(
                model, allow_tracing=True, model_inputs=dummy_input
            )
        return compiled_model

    def synchronize(self):
        """Block until pending CUDA work finishes; no-op on CPU-only hosts."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
3+
try:
4+
import torch_tensorrt
5+
except ImportError:
6+
torch_tensorrt = None
7+
8+
9+
class GraphCompilerBackend:
    """Abstract interface shared by all graph-compiler backends.

    Concrete subclasses must implement `__call__` (produce a compiled
    version of a model) and `synchronize` (wait for outstanding device
    work to finish).
    """

    def __call__(self, model):
        """Compile `model` and return the compiled version."""
        raise NotImplementedError()

    def synchronize(self):
        """Block until all device work launched so far has completed."""
        raise NotImplementedError()
15+
16+
17+
class InductorBackend(GraphCompilerBackend):
    """Backend that compiles models with TorchInductor."""

    def __call__(self, model):
        """Return `model` wrapped by `torch.compile` using inductor."""
        return torch.compile(model, backend="inductor")

    def synchronize(self):
        """Drain the CUDA stream; a no-op on CPU-only machines."""
        if not torch.cuda.is_available():
            return
        torch.cuda.synchronize()
24+
25+
26+
class TensorRTBackend(GraphCompilerBackend):
    """Backend that compiles models via torch.compile's "tensorrt" backend.

    Requires `torch_tensorrt` to be importable so that the "tensorrt"
    backend name is registered with torch.compile (see the module-level
    try/except import).
    """

    def __call__(self, model):
        """Return `model` wrapped by `torch.compile` using tensorrt."""
        return torch.compile(model, backend="tensorrt")

    def synchronize(self):
        """Drain the CUDA stream; no-op on CPU-only hosts.

        The original called torch.cuda.synchronize() unconditionally,
        which raises when CUDA is unavailable; guard it to match
        InductorBackend.synchronize.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()

graph_net/torch/test_compiler.py

Lines changed: 24 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,12 @@
1414
import json
1515
import numpy as np
1616
import platform
17-
18-
try:
19-
import torch_tensorrt
20-
except ImportError:
21-
torch_tensorrt = None
22-
23-
try:
24-
import torch_blade
25-
except ImportError:
26-
torch_blade = None
27-
28-
29-
class GraphCompilerBackend:
30-
def __call__(self, model):
31-
raise NotImplementedError()
32-
33-
def synchronize(self):
34-
raise NotImplementedError()
35-
36-
37-
class InductorBackend(GraphCompilerBackend):
38-
def __call__(self, model):
39-
return torch.compile(model, backend="inductor")
40-
41-
def synchronize(self):
42-
if torch.cuda.is_available():
43-
torch.cuda.synchronize()
44-
45-
46-
class TensorRTBackend(GraphCompilerBackend):
47-
def __call__(self, model):
48-
return torch.compile(model, backend="tensorrt")
49-
50-
def synchronize(self):
51-
torch.cuda.synchronize()
17+
from .graph_compiler_backend import (
18+
GraphCompilerBackend,
19+
InductorBackend,
20+
TensorRTBackend,
21+
)
22+
from .blade_disc_backend import BladeDISCBackend
5223

5324

5425
def load_class_from_file(
@@ -70,9 +41,25 @@ def load_class_from_file(
7041
return model_class
7142

7243

44+
# Registry of backend classes, keyed by the --compiler CLI value.
registry_backend_classes = {
    "inductor": InductorBackend,
    "tensorrt": TensorRTBackend,
    "bladedisc": BladeDISCBackend,
}


def get_compiler_backend(args) -> GraphCompilerBackend:
    """Instantiate the backend selected by `args.compiler`.

    BladeDISC traces the model and therefore needs concrete example
    inputs at construction; every other registered backend takes no
    constructor arguments.

    Raises:
        AssertionError: if `args.compiler` is not a registered backend.
    """
    assert (
        args.compiler in registry_backend_classes
    ), f"Unknown compiler: {args.compiler}"
    cls = registry_backend_classes[args.compiler]
    if cls is BladeDISCBackend:
        # Special case: BladeDISC requires example inputs for tracing.
        return cls(get_input_dict(args))
    # Default construction covers all other backends. (The original
    # if/elif chain had no else branch, so a class present in the
    # registry but missing from the chain silently returned None.)
    return cls()
7663

7764

7865
def get_model(args):
@@ -92,33 +79,6 @@ def get_input_dict(args):
9279
}
9380

9481

95-
class BladeDISCBackend(GraphCompilerBackend):
96-
def __init__(self, input_dict=None):
97-
self.input_dict = input_dict
98-
99-
def __call__(self, model):
100-
torch_config = torch_blade.config.Config()
101-
torch_config.enable_mlir_amp = False
102-
with torch.no_grad(), torch_config:
103-
input_dict = get_input_dict(args)
104-
dummy_input = tuple(input_dict.values())
105-
compiled_model = torch_blade.optimize(
106-
model, allow_tracing=True, model_inputs=dummy_input
107-
)
108-
return compiled_model
109-
110-
def synchronize(self):
111-
if torch.cuda.is_available():
112-
torch.cuda.synchronize()
113-
114-
115-
registry_backend = {
116-
"inductor": InductorBackend(),
117-
"tensorrt": TensorRTBackend(),
118-
"bladedisc": BladeDISCBackend(),
119-
}
120-
121-
12282
@dataclass
12383
class DurationBox:
12484
value: float

0 commit comments

Comments
 (0)