
Commit 302edeb

committed
fix bug
1 parent defb7e4 commit 302edeb


2 files changed, +50 -23 lines changed


graph_net/torch/test_compiler.py

Lines changed: 49 additions & 22 deletions
@@ -12,16 +12,42 @@
 from dataclasses import dataclass
 from contextlib import contextmanager
 import time
-import torch_tensorrt
 
-
-def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
+try:
+    import torch_tensorrt
+except ImportError:
+    torch_tensorrt = None
+
+
+registry_backend = {
+    "inductor": {
+        "compiler": torch.compile,
+        "backend": "inductor",
+        "synchronizer": torch.cuda.synchronize,
+    },
+    "tensorrt": {
+        "compiler": torch.compile,
+        "backend": "tensorrt",
+        "synchronizer": torch.cuda.synchronize,
+    },
+    "default": {
+        "compiler": torch.compile,
+        "backend": "inductor",
+        "synchronizer": torch.cuda.synchronize,
+    },
+}
+
+
+def load_class_from_file(
+    args: argparse.Namespace, class_name: str
+) -> Type[torch.nn.Module]:
+    file_path = f"{args.model_path}/model.py"
     file = Path(file_path).resolve()
     module_name = file.stem
 
     with open(file_path, "r", encoding="utf-8") as f:
         original_code = f.read()
-    if torch.cuda.is_available():
+    if args.device == "cuda":
         modified_code = original_code.replace("cpu", "cuda")
     else:
         modified_code = original_code
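The hunk above does two things: the hard dependency on torch_tensorrt becomes an optional, guarded import, and the per-compiler choices are collected into a single registry_backend dict keyed by the --compiler value. As a rough standalone sketch of that dispatch pattern (not the project's code; only the keys and values shown in the diff are assumed):

import torch

registry_backend = {
    "tensorrt": {"compiler": torch.compile, "backend": "tensorrt", "synchronizer": torch.cuda.synchronize},
    "default": {"compiler": torch.compile, "backend": "inductor", "synchronizer": torch.cuda.synchronize},
}

def lookup(compiler_name, field):
    # A single membership check replaces the old per-function if/else chains
    # in get_compiler, get_backend, and get_synchronizer_func.
    assert compiler_name in registry_backend, f"Unknown compiler: {compiler_name}"
    return registry_backend[compiler_name][field]

lookup("tensorrt", "backend")   # -> "tensorrt"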
@@ -36,38 +62,32 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
 
 
 def get_compiler(args):
-    if args.compiler == "tensorrt":
-        return torch.compile
-    else:
-        assert args.compiler == "default"
-        return torch.compile
+    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
+    return registry_backend[args.compiler]["compiler"]
 
 
 def get_backend(args):
-    if args.compiler == "tensorrt":
-        return "tensorrt"
-    else:
-        assert args.compiler == "default"
-        return "inductor"
+    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
+    return registry_backend[args.compiler]["backend"]
 
 
 def get_synchronizer_func(args):
-    assert args.compiler == "default" or args.compiler == "tensorrt"
-    return torch.cuda.synchronize
+    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
+    return registry_backend[args.compiler]["synchronizer"]
 
 
 def get_model(args):
-    model_class = load_class_from_file(
-        f"{args.model_path}/model.py", class_name="GraphModule"
-    )
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    return model_class().to(device)
+    model_class = load_class_from_file(args, class_name="GraphModule")
+    return model_class().to(torch.device(args.device))
 
 
 def get_input_dict(args):
     inputs_params = utils.load_converted_from_text(f"{args.model_path}")
     params = inputs_params["weight_info"]
-    return {k: utils.replay_tensor(v) for k, v in params.items()}
+    return {
+        k: utils.replay_tensor(v).to(torch.device(args.device))
+        for k, v in params.items()
+    }
 
 
 @dataclass
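With these helpers now reading everything from args, a driver only needs to pick the registry entry once and place the model and inputs on args.device. A minimal sketch of how the pieces could be wired together, assuming the helpers above are in scope (the run_once wrapper and the keyword-argument call are assumptions, not code from this commit):

def run_once(args):
    model = get_model(args)                          # instantiated on args.device
    input_dict = get_input_dict(args)                # tensors replayed, then moved to args.device
    compiled = get_compiler(args)(model, backend=get_backend(args))
    output = compiled(**input_dict)
    if args.device == "cuda":
        get_synchronizer_func(args)()                # torch.cuda.synchronize for every current entry
    return output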
@@ -228,6 +248,13 @@ def main(args):
         default="default",
         help="Path to customized compiler python file",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=False,
+        default="cpu",
+        help="Device for testing the compiler",
+    )
     parser.add_argument(
         "--warmup", type=int, required=False, default=5, help="Number of warmup steps"
     )
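The new --device flag defaults to "cpu", so CUDA is now opt-in rather than inferred from torch.cuda.is_available(). A small illustration of the parsing behaviour, using only the flags visible in this diff (the --compiler definition here approximates the one partially shown above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--compiler", type=str, required=False, default="default")
parser.add_argument("--device", type=str, required=False, default="cpu")

args = parser.parse_args([])                                        # no flags: device stays "cpu"
assert args.device == "cpu"
args = parser.parse_args(["--compiler", "tensorrt", "--device", "cuda"])
assert args.device == "cuda"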

graph_net/torch/utils.py

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ def extract_dynamic_shapes(example_inputs):
 
 
 def replay_tensor(info):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = info["info"]["device"]
     dtype = info["info"]["dtype"]
     shape = info["info"]["shape"]
 
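In graph_net/torch/utils.py, replay_tensor now takes the device from the tensor's recorded metadata instead of probing CUDA availability; get_input_dict (above) then moves the replayed tensor to args.device. A rough sketch of the assumed metadata layout and replay path (field values and the use of torch.empty are assumptions; only the "device", "dtype", and "shape" keys appear in the diff):

import torch

info = {"info": {"device": "cpu", "dtype": "float32", "shape": [2, 3]}}

def replay_tensor_sketch(info):
    meta = info["info"]
    dtype = getattr(torch, meta["dtype"])        # e.g. "float32" -> torch.float32
    # Allocate on the recorded device; callers may still .to(...) a different device.
    return torch.empty(meta["shape"], dtype=dtype, device=meta["device"])

t = replay_tensor_sketch(info)                   # tensor on "cpu" with shape (2, 3)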