Skip to content

Commit 5ae3db2

Browse files
[Relax][PyTorch] Add stack.default and sum.default to exported programs translator (#17814)
* stack correct * sum correct in side script * all pass
1 parent 5842bdb commit 5ae3db2

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def create_convert_map(
381381
"mean.dim": self._mean,
382382
"prod.default": self._prod,
383383
"std.correction": self._std,
384+
"sum.default": self._sum,
384385
"sum.dim_IntList": self._sum,
385386
"var.correction": self._var,
386387
# search
@@ -409,6 +410,7 @@ def create_convert_map(
409410
"split_with_sizes.default": self._split,
410411
"squeeze.default": self._squeeze,
411412
"squeeze.dim": self._squeeze,
413+
"stack.default": self._stack,
412414
"take.default": self._take,
413415
"tile.default": self._tile,
414416
"topk.default": self._topk,

tests/python/relax/test_from_exported_to_cuda.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,5 +512,33 @@ def forward(self, x):
512512
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
513513

514514

515+
@tvm.testing.parametrize_targets("cuda")
516+
def test_stack(target, dev):
517+
class StackModel(nn.Module):
518+
def forward(self, x):
519+
val1 = x[1, 4]
520+
val2 = x[3, 2]
521+
val3 = x[5, 6]
522+
z = torch.stack([val1, val2, val3])
523+
return z
524+
525+
torch_module = StackModel().eval()
526+
raw_data = np.random.rand(10, 10, 10).astype("float32")
527+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
528+
529+
530+
@tvm.testing.parametrize_targets("cuda")
531+
def test_sum(target, dev):
532+
class SumModel(nn.Module):
533+
def forward(self, x):
534+
new_vec = x[1, 4]
535+
return new_vec.sum()
536+
537+
torch_module = SumModel().eval()
538+
539+
raw_data = np.random.rand(10, 10, 10).astype("float32")
540+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
541+
542+
515543
if __name__ == "__main__":
516544
tvm.testing.main()

0 commit comments

Comments
 (0)