Skip to content

Commit 43123f6

Browse files
junjiang-labcopybara-github
authored andcommitted
Relax tfl.slice and strided_slice op check to allow 6D inputs.
PiperOrigin-RevId: 805891454
1 parent cca973d commit 43123f6

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

ai_edge_torch/_convert/test/test_convert.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,50 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
636636
self.fail(f"Conversion failed with 6d inputs: {err}")
637637
# pylint: enable=broad-except
638638

639+
@googletest.skipIf(
640+
ai_edge_torch.config.in_oss,
641+
reason="wait until dependencies are released to tf-nightly",
642+
)
643+
def test_convert_model_with_slice_6d_inputs(self):
644+
"""Test converting a simple model with slice and 6d inputs."""
645+
646+
class SampleModel(nn.Module):
647+
648+
def forward(self, x: torch.Tensor):
649+
return x[0:1, 0:2, 0:3, 0:4, 0:5, 0:1]
650+
651+
model = SampleModel().eval()
652+
args = (torch.randn((1, 2, 3, 4, 5, 6)),)
653+
654+
try:
655+
# Expect this to fix the error during conversion
656+
ai_edge_torch.convert(model, args)
657+
except Exception as err:
658+
self.fail(f"Conversion failed with 6d inputs for slice: {err}")
659+
# pylint: enable=broad-except
660+
661+
@googletest.skipIf(
662+
ai_edge_torch.config.in_oss,
663+
reason="wait until dependencies are released to tf-nightly",
664+
)
665+
def test_convert_model_with_strided_slice_6d_inputs(self):
666+
"""Test converting a simple model with strided_slice and 6d inputs."""
667+
668+
class SampleModel(nn.Module):
669+
670+
def forward(self, x: torch.Tensor):
671+
return x[:, :, :, :, :, ::2]
672+
673+
model = SampleModel().eval()
674+
args = (torch.randn((1, 2, 3, 4, 5, 6)),)
675+
676+
try:
677+
# Expect this to fix the error during conversion
678+
ai_edge_torch.convert(model, args)
679+
except Exception as err:
680+
self.fail(f"Conversion failed with 6d inputs for strided_slice: {err}")
681+
# pylint: enable=broad-except
682+
639683
def test_compile_model(self):
640684
"""Tests AOT compilation of a simple Add module."""
641685

0 commit comments

Comments
 (0)