@@ -596,6 +596,36 @@ def forward(self, tokens, h, c):
596596 onnx_program = torch .onnx .export (model , (tokens , h , c ), dynamo = True , verbose = False )
597597 _testing .assert_onnx_program (onnx_program )
598598
599+ def test_unbind_dynamic_dim0 (self ):
600+ """Test unbind with dynamic dimension 0 - triggers SplitToSequence"""
601+
602+ class UnbindModel (torch .nn .Module ):
603+ def forward (self , x ):
604+ tensors = torch .unbind (x , dim = 0 )
605+ return sum (tensors )
606+
607+ model = UnbindModel ()
608+ x = torch .randn (3 , 4 , 5 )
609+ onnx_program = torch .onnx .export (
610+ model , (x ,), dynamo = True , verbose = False , dynamic_shapes = ({0 : "batch_size" },)
611+ )
612+ _testing .assert_onnx_program (onnx_program )
613+
614+ def test_unbind_dynamic_dim1 (self ):
615+ """Test unbind with dynamic dimension 1 - triggers SplitToSequence"""
616+
617+ class UnbindModel (torch .nn .Module ):
618+ def forward (self , x ):
619+ tensors = torch .unbind (x , dim = 1 )
620+ return sum (tensors )
621+
622+ model = UnbindModel ()
623+ x = torch .randn (2 , 3 , 4 )
624+ onnx_program = torch .onnx .export (
625+ model , (x ,), dynamo = True , verbose = False , dynamic_shapes = ({1 : "seq_len" },)
626+ )
627+ _testing .assert_onnx_program (onnx_program )
628+
599629
600630if __name__ == "__main__" :
601631 unittest .main ()
0 commit comments