Commit 1ecdcdd

Merge branch 'develop' of https://github.com/PaddlePaddle/GraphNet into develop
2 parents 6092ffe + 0d154bc

File tree

6 files changed: +27 -8 lines changed


graph_net/torch/blade_disc_backend.py renamed to graph_net/torch/backend/blade_disc_backend.py

Lines changed: 3 additions & 0 deletions

@@ -34,3 +34,6 @@ def __call__(self, model):
     def synchronize(self):
         if torch.cuda.is_available():
             torch.cuda.synchronize()
+
+    def version(self):
+        return torch_blade.version
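
The new version() hook lets callers read the compiler framework version through the backend object instead of importing torch_blade at the call site. A minimal usage sketch, assuming torch_blade is installed and using the renamed module path from this commit (the surrounding script is illustrative, not part of the change):

    # Illustrative caller: instantiate the backend and report its version.
    from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend

    compiler = BladeDISCBackend()
    print(f"BladeDISC {compiler.version()}")  # version() returns torch_blade.version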
graph_net/torch/graph_compiler_backend.py renamed to graph_net/torch/backend/graph_compiler_backend.py

File renamed without changes.

graph_net/torch/inductor_backend.py renamed to graph_net/torch/backend/inductor_backend.py

File renamed without changes.

graph_net/torch/tensorrt_backend.py renamed to graph_net/torch/backend/tensorrt_backend.py

Lines changed: 4 additions & 2 deletions

@@ -1,12 +1,11 @@
 import torch
+from .graph_compiler_backend import GraphCompilerBackend
 
 try:
     import torch_tensorrt
 except ImportError:
     torch_tensorrt = None
 
-from .graph_compiler_backend import GraphCompilerBackend
-
 
 class TensorRTBackend(GraphCompilerBackend):
     def __call__(self, model):
@@ -16,3 +15,6 @@ def __call__(self, model):
 
     def synchronize(self):
         torch.cuda.synchronize()
+
+    def version(self):
+        return torch_tensorrt.version
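
Both backends now follow the same GraphCompilerBackend shape: __call__ compiles a model, synchronize() drains the device, and version() reports the underlying framework version. A sketch of how an additional backend in graph_net/torch/backend/ might follow that pattern (the class name and its no-op compile step are hypothetical, not part of this commit):

    # Hypothetical backend illustrating the interface; not part of this commit.
    import torch

    from .graph_compiler_backend import GraphCompilerBackend


    class EagerBackend(GraphCompilerBackend):
        def __call__(self, model):
            # No compilation step: run the captured graph eagerly.
            return model

        def synchronize(self):
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        def version(self):
            # Report the framework that actually executes the graph.
            return torch.__version__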

graph_net/torch/test_compiler.py

Lines changed: 7 additions & 6 deletions

@@ -14,10 +14,10 @@
 import json
 import numpy as np
 import platform
-from graph_net.torch.graph_compiler_backend import GraphCompilerBackend
-from graph_net.torch.inductor_backend import InductorBackend
-from graph_net.torch.tensorrt_backend import TensorRTBackend
-from graph_net.torch.blade_disc_backend import BladeDISCBackend
+from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
+from graph_net.torch.backend.inductor_backend import InductorBackend
+from graph_net.torch.backend.tensorrt_backend import TensorRTBackend
+from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
 
 registry_backend = {
     "inductor": InductorBackend(),
@@ -35,6 +35,7 @@ def load_class_from_file(
 
     with open(file_path, "r", encoding="utf-8") as f:
         model_code = f.read()
+    model_code = utils.update_device(model_code, args.device)
     spec = importlib.util.spec_from_loader(module_name, loader=None)
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
@@ -196,11 +197,11 @@ def test_single_model(args):
     elif args.compiler == "tensorrt":
         result_data["configuration"][
             "compile_framework_version"
-        ] = f"TensorRT {torch_tensorrt.version}"
+        ] = f"TensorRT {compiler.version}"
     elif args.compiler == "bladedisc":
         result_data["configuration"][
             "compile_framework_version"
-        ] = f"BladeDISC {torch_blade.version}"
+        ] = f"BladeDISC {compiler.version}"
     else:
         result_data["configuration"]["compiler_version"] = "unknown"
 
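
load_class_from_file now passes the model source through utils.update_device before it is executed, so a graph extracted on one device type can be replayed on another. A condensed sketch of that loading path, assuming a model file that pins tensors to CUDA and a run targeting "cuda"; the file name, module name, and the final exec of the rewritten source are illustrative assumptions, not shown in this hunk:

    # Illustrative loading flow: read the extracted model, rewrite its device
    # literals, then execute it inside a fresh in-memory module.
    import importlib.util
    import sys

    from graph_net.torch import utils

    with open("model.py", "r", encoding="utf-8") as f:
        model_code = f.read()
    model_code = utils.update_device(model_code, "cuda")

    spec = importlib.util.spec_from_loader("extracted_model", loader=None)
    module = importlib.util.module_from_spec(spec)
    sys.modules["extracted_model"] = module
    exec(model_code, module.__dict__)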

graph_net/torch/utils.py

Lines changed: 13 additions & 0 deletions

@@ -270,3 +270,16 @@ def replay_tensor(info):
     if dtype is torch.bool:
         return (torch.randn(size=shape) > 0.5).to(dtype).to(device)
     return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
+
+
+def update_device(code, device):
+    if device == "cuda":
+        pattern = r'device\(type="cpu"\)'
+        replacement = f'device(type="cuda", index={torch.cuda.current_device()})'
+        updated_code = re.sub(pattern, replacement, code)
+        return updated_code
+    else:
+        pattern = r'device\(type="cuda"(?:, index=\d+)?\)'
+        replacement = 'device(type="cpu")'
+        updated_code = re.sub(pattern, replacement, code)
+        return updated_code
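
The update_device helper rewrites the device(type=...) literals that appear in extracted model code, so the same source can target either CPU or the current CUDA device. A small demonstration of the CPU direction, assuming re and torch are already imported at the top of utils.py (the CUDA direction additionally requires a visible GPU for torch.cuda.current_device()):

    # Example line as it might appear in an extracted model file.
    code = 'x = torch.empty([4, 4], device=torch.device(type="cuda", index=0))'

    # Rewriting for CPU drops the index and swaps the device type.
    print(update_device(code, "cpu"))
    # x = torch.empty([4, 4], device=torch.device(type="cpu"))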
