
Commit 0cd7c35

add tvm
1 parent 0d154bc commit 0cd7c35

File tree: 3 files changed (+93, −6 lines)

graph_net/torch/backend/tvm_backend.py
graph_net/torch/test_compiler.py
graph_net/torch/utils.py
graph_net/torch/backend/tvm_backend.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import torch
import inspect

from .graph_compiler_backend import GraphCompilerBackend

try:
    import tvm
    from tvm import relax
    from tvm import dlight as dl
    from tvm.relax.frontend.torch import dynamo_capture_subgraphs
except ImportError:
    tvm = None
    relax = None
    dl = None
    dynamo_capture_subgraphs = None


class TvmCompiledModule(torch.nn.Module):
    def __init__(self, module, device):
        super().__init__()
        self.module = module
        self.counter = 0
        self.tvm_input = []
        self.compiled_vm = None
        self.dev = tvm.device(device)
        self.target = tvm.target.Target.from_device(self.dev)
        self.param_names = list(inspect.signature(module.forward).parameters.keys())

    def forward(self, **kwargs):
        # Compile on the first call and cache the TVM inputs; subsequent
        # calls replay the cached inputs through the compiled VM.
        if self.counter == 0:
            self.compiled_vm = self.compile(self.module, **kwargs)
            for name in self.param_names:
                # "s1" carries a dynamo shape symbol, not a tensor; skip it.
                if name in kwargs and name != "s1":
                    param = kwargs[name]
                    self.tvm_input.append(tvm.nd.array(param.cpu(), self.dev))

        output = self.compiled_vm["subgraph_0"](*self.tvm_input).numpy()
        self.counter += 1
        return torch.from_numpy(output)

    def compile(self, module, **kwargs):
        # Capture the PyTorch graph as a Relax IRModule via torch.dynamo.
        with torch.no_grad():
            mod = dynamo_capture_subgraphs(module, **kwargs, keep_params_as_input=True)
        mod, _ = relax.frontend.detach_params(mod)
        with self.target:
            mod = tvm.ir.transform.Sequential(
                [
                    relax.get_pipeline("zero"),
                    dl.ApplyDefaultSchedule(
                        dl.gpu.Matmul(),
                        dl.gpu.GEMV(),
                        dl.gpu.Reduction(),
                        dl.gpu.GeneralReduction(),
                        dl.gpu.Fallback(),
                    ),
                ]
            )(mod)
        ex = tvm.compile(mod, target=self.target)
        vm = relax.VirtualMachine(ex, self.dev)
        return vm


class TvmBackend(GraphCompilerBackend):
    def __call__(self, model, **kwargs):
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "llvm"
        return TvmCompiledModule(model, device=device)

    def synchronize(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def version(self):
        try:
            from importlib.metadata import version

            return version("tvm")
        except Exception:
            return "unknown"

graph_net/torch/test_compiler.py

Lines changed: 8 additions & 1 deletion
@@ -15,11 +15,13 @@
 import numpy as np
 import platform
 from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
+from graph_net.torch.backend.tvm_backend import TvmBackend
 from graph_net.torch.backend.inductor_backend import InductorBackend
 from graph_net.torch.backend.tensorrt_backend import TensorRTBackend
 from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend

 registry_backend = {
+    "tvm": TvmBackend(),
     "inductor": InductorBackend(),
     "tensorrt": TensorRTBackend(),
     "bladedisc": BladeDISCBackend(),
@@ -35,7 +37,7 @@ def load_class_from_file(

     with open(file_path, "r", encoding="utf-8") as f:
         model_code = f.read()
-    model_code = utils.update_device(model_code, args.device)
+    model_code = utils.modify_code_by_device(model_code, args.device)
     spec = importlib.util.spec_from_loader(module_name, loader=None)
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
@@ -226,6 +228,10 @@ def test_single_model(args):

     if args.compiler == "inductor":
         result_data["configuration"]["compile_framework_version"] = torch.__version__
+    elif args.compiler == "tvm":
+        result_data["configuration"][
+            "compile_framework_version"
+        ] = f"Tvm {compiler.version()}"
     elif args.compiler == "tensorrt":
         result_data["configuration"][
             "compile_framework_version"
@@ -245,6 +251,7 @@ def test_single_model(args):

     expected_out = eager_model_call()
     compiled_out = compiled_model_call()
+    compiled_out = tuple(tensor.to(args.device) for tensor in compiled_out)

     def print_and_store_cmp(key, func, **kwargs):
         cmp_ret = func(expected_out, compiled_out, **kwargs)
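For reference, the new registry entry is consumed the same way as the existing backends; a small sketch (model here is a hypothetical torch.nn.Module):

from graph_net.torch.test_compiler import registry_backend

compiler = registry_backend["tvm"]    # the TvmBackend() instance registered above
compiled_model = compiler(model)      # wraps model in a TvmCompiledModule
print(compiler.version())             # falls back to "unknown" without tvm metadata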

graph_net/torch/utils.py

Lines changed: 5 additions & 5 deletions
@@ -272,14 +272,14 @@ def replay_tensor(info):
     return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean


-def update_device(code, device):
+def modify_code_by_device(code, device):
     if device == "cuda":
         pattern = r'device\(type="cpu"\)'
         replacement = f'device(type="cuda", index={torch.cuda.current_device()})'
-        updated_code = re.sub(pattern, replacement, code)
-        return updated_code
+        modify_code = re.sub(pattern, replacement, code)
+        return modify_code
     else:
         pattern = r'device\(type="cuda"(?:, index=\d+)?\)'
         replacement = 'device(type="cpu")'
-        updated_code = re.sub(pattern, replacement, code)
-        return updated_code
+        modify_code = re.sub(pattern, replacement, code)
+        return modify_code
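To illustrate the renamed helper: modify_code_by_device rewrites serialized device annotations in captured model code with a regex. A quick sketch of the cuda-to-cpu direction, which needs no GPU:

from graph_net.torch import utils

code = 'buf = torch.empty(8, device=device(type="cuda", index=0))'
print(utils.modify_code_by_device(code, "cpu"))
# buf = torch.empty(8, device=device(type="cpu"))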
