Skip to content

Commit 9825c36

Browse files
committed
add new tests to improve coverage
1 parent 7af0d47 commit 9825c36

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,36 @@ def forward(self, tokens, h, c):
596596
onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False)
597597
_testing.assert_onnx_program(onnx_program)
598598

599+
def test_unbind_dynamic_dim0(self):
600+
"""Test unbind with dynamic dimension 0 - triggers SplitToSequence"""
601+
602+
class UnbindModel(torch.nn.Module):
603+
def forward(self, x):
604+
tensors = torch.unbind(x, dim=0)
605+
return sum(tensors)
606+
607+
model = UnbindModel()
608+
x = torch.randn(3, 4, 5)
609+
onnx_program = torch.onnx.export(
610+
model, (x,), dynamo=True, verbose=False, dynamic_shapes=({0: "batch_size"},)
611+
)
612+
_testing.assert_onnx_program(onnx_program)
613+
614+
def test_unbind_dynamic_dim1(self):
615+
"""Test unbind with dynamic dimension 1 - triggers SplitToSequence"""
616+
617+
class UnbindModel(torch.nn.Module):
618+
def forward(self, x):
619+
tensors = torch.unbind(x, dim=1)
620+
return sum(tensors)
621+
622+
model = UnbindModel()
623+
x = torch.randn(2, 3, 4)
624+
onnx_program = torch.onnx.export(
625+
model, (x,), dynamo=True, verbose=False, dynamic_shapes=({1: "seq_len"},)
626+
)
627+
_testing.assert_onnx_program(onnx_program)
628+
599629

600630
if __name__ == "__main__":
601631
unittest.main()

0 commit comments

Comments
 (0)