1414import math
1515from typing import Any , Optional , Sequence , Tuple , Union
1616
17+ import numpy as np
18+
1719from onnxscript import (
1820 BFLOAT16 ,
1921 BOOL ,
@@ -4303,17 +4305,21 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
43034305 while len (reshape_list ) > len (values_shape ) and 1 in reshape_list :
43044306 reshape_list .remove (1 )
43054307
4308+ # Or add ones until the rank of reshape_list matches values_shape.
4309+ while len (reshape_list ) < len (values_shape ):
4310+ reshape_list .append (1 )
4311+
43064312 # Now ensure each dimension is broadcastable:
43074313 # This is mandatory when mixing basic and advanced indexing
43084314 # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
43094315 # the reshape list should be : [[2, 1], [1, 3], [2, 1]]
43104316 for i , r in enumerate (reshape_list ):
4311- if r != 1 and r != values_shape [i ]:
4312- one_index = reshape_list .index (1 )
4317+ if r not in ( 1 , values_shape [i ]) :
4318+ value_index = values_shape .index (r )
43134319 # Swap elements
43144320 # For the example above the current reshape list is [1, 2] for last dim,
43154321 # to make it broadcastable, we swap the elements
4316- reshape_list [one_index ], reshape_list [i ] = reshape_list [ i ] , 1
4322+ reshape_list [value_index ], reshape_list [i ] = r , 1
43174323
43184324 return reshape_list
43194325
@@ -4322,8 +4328,8 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
43224328 if len (indices ) < self_rank :
43234329 indices = list (indices ) + [None ] * (self_rank - len (indices ))
43244330
4325- # Get values shape (we use .numpy to make it hashable)
4326- values_shape = values .shape . numpy ( )
4331+ # Get values shape
4332+ values_shape = tuple ( values .shape )
43274333
43284334 index_vectors = []
43294335 for i in range (self_rank ):
@@ -4333,7 +4339,15 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
43334339 reshape_update = self .shape [i ]
43344340 else :
43354341 idx = indices [i ]
4336- reshape_update = indices [i ].shape [0 ]
4342+ reshape_update = np .prod (idx .shape ).item ()
4343+ # when Index is more than 1D, flatten it and also the values shape
4344+ # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
4345+ # Indices -> (2*4,) and values shape (2*4, 32)
4346+ if len (idx .shape ) > 1 :
4347+ values_shape = (reshape_update ,) + values_shape [len (idx .shape ) :]
4348+
4349+ # Flatten index (always working with 1D index in each dim)
4350+ idx = op .Reshape (idx , [- 1 ])
43374351
43384352 # Create a reshape pattern: one value per index dimension,
43394353 # with the current dimension set to the update size.
0 commit comments