Skip to content

Commit 2a546d9

Browse files
committed
fix
1 parent 0cd7c35 commit 2a546d9

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

graph_net/torch/backend/tvm_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def forward(self, **kwargs):
3333
param = kwargs[name]
3434
self.tvm_input.append(tvm.nd.array(param.cpu(), self.dev))
3535

36-
output = self.compiled_vm["subgraph_0"](*self.tvm_input).numpy()
36+
output = self.compiled_vm["subgraph_0"](*self.tvm_input)
3737
self.counter += 1
38-
return torch.from_numpy(output)
38+
return torch.from_dlpack(output)
3939

4040
def compile(self, module, **kwargs):
4141
with torch.no_grad():

graph_net/torch/test_compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,6 @@ def test_single_model(args):
251251

252252
expected_out = eager_model_call()
253253
compiled_out = compiled_model_call()
254-
compiled_out = (tensor.to(args.device) for tensor in compiled_out)
255254

256255
def print_and_store_cmp(key, func, **kwargs):
257256
cmp_ret = func(expected_out, compiled_out, **kwargs)

0 commit comments

Comments
 (0)