@@ -636,6 +636,50 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
636636 self .fail (f"Conversion failed with 6d inputs: { err } " )
637637 # pylint: enable=broad-except
638638
639+ @googletest .skipIf (
640+ ai_edge_torch .config .in_oss ,
641+ reason = "wait until dependencies are released to tf-nightly" ,
642+ )
643+ def test_convert_model_with_slice_6d_inputs (self ):
644+ """Test converting a simple model with slice and 6d inputs."""
645+
646+ class SampleModel (nn .Module ):
647+
648+ def forward (self , x : torch .Tensor ):
649+ return x [0 :1 , 0 :2 , 0 :3 , 0 :4 , 0 :5 , 0 :1 ]
650+
651+ model = SampleModel ().eval ()
652+ args = (torch .randn ((1 , 2 , 3 , 4 , 5 , 6 )),)
653+
654+ try :
655+ # Expect this to fix the error during conversion
656+ ai_edge_torch .convert (model , args )
657+ except Exception as err :
658+ self .fail (f"Conversion failed with 6d inputs for slice: { err } " )
659+ # pylint: enable=broad-except
660+
661+ @googletest .skipIf (
662+ ai_edge_torch .config .in_oss ,
663+ reason = "wait until dependencies are released to tf-nightly" ,
664+ )
665+ def test_convert_model_with_strided_slice_6d_inputs (self ):
666+ """Test converting a simple model with strided_slice and 6d inputs."""
667+
668+ class SampleModel (nn .Module ):
669+
670+ def forward (self , x : torch .Tensor ):
671+ return x [:, :, :, :, :, ::2 ]
672+
673+ model = SampleModel ().eval ()
674+ args = (torch .randn ((1 , 2 , 3 , 4 , 5 , 6 )),)
675+
676+ try :
677+ # Expect this to fix the error during conversion
678+ ai_edge_torch .convert (model , args )
679+ except Exception as err :
680+ self .fail (f"Conversion failed with 6d inputs for strided_slice: { err } " )
681+ # pylint: enable=broad-except
682+
639683 def test_compile_model (self ):
640684 """Tests AOT compilation of a simple Add module."""
641685
0 commit comments