@@ -1693,20 +1693,42 @@ def _index_tensor(self, node: fx.Node) -> relax.Var:
16931693 axis , index_tensor = non_none_indices [0 ]
16941694 return self .block_builder .emit (relax .op .take (data , index_tensor , axis = axis ))
16951695
1696- # General case: multiple non-None indices require advanced indexing
1696+ # Check if all indices can be squeezed to 1D for sequential take
1697+ def is_squeezable (idx ):
1698+ if idx .struct_info .ndim == 1 :
1699+ return True
1700+ if idx .struct_info .ndim == 2 :
1701+ shape = idx .struct_info .shape
1702+ for d in shape :
1703+ if isinstance (d , int ) and d == 1 :
1704+ return True
1705+ # Check for tir.IntImm
1706+ if hasattr (d , "value" ) and d .value == 1 :
1707+ return True
1708+ return False
1709+
1710+ all_squeezable = all (is_squeezable (idx ) for _ , idx in non_none_indices )
1711+ if all_squeezable :
1712+ result = data
1713+ for axis , idx in reversed (non_none_indices ):
1714+ if idx .struct_info .ndim > 1 :
1715+ idx = self .block_builder .emit (relax .op .squeeze (idx ))
1716+ result = self .block_builder .emit (relax .op .take (result , idx , axis = axis ))
1717+ return result
1718+
1719+ # General case: replace None with arange, reshaped for broadcasting
1720+ max_ndim = max ((idx .struct_info .ndim for _ , idx in non_none_indices ), default = 1 )
16971721 processed_indices = []
16981722 data_shape = self .shape_of (data )
16991723
17001724 for i , idx in enumerate (indices ):
17011725 if idx is None :
1702- dim_size = data_shape [i ]
17031726 arange_idx = self .block_builder .emit (
1704- relax .op .arange (
1705- start = relax .PrimValue (0 ),
1706- end = dim_size ,
1707- step = relax .PrimValue (1 ),
1708- dtype = "int64" ,
1709- )
1727+ relax .op .arange (relax .PrimValue (0 ), data_shape [i ], relax .PrimValue (1 ), "int64" )
1728+ )
1729+ # Reshape to [dim_size, 1, 1, ...] for broadcasting
1730+ arange_idx = self .block_builder .emit (
1731+ relax .op .reshape (arange_idx , [data_shape [i ]] + [1 ] * (max_ndim - 1 ))
17101732 )
17111733 processed_indices .append (arange_idx )
17121734 else :
0 commit comments