@@ -188,9 +188,6 @@ def prims_broadcast_in_dim(
188188 # Special case: no broadcast dimensions - all target dims should be 1
189189 return op .Expand (a , common_ops .merge_dims (shape ))
190190
191- # Build intermediate shape using a simpler approach than ScatterElements
192- # We'll construct it by concatenating the right values for each position
193-
194191 # Create base shape of all 1s
195192 ones = [1 ] * target_rank
196193
@@ -201,20 +198,10 @@ def prims_broadcast_in_dim(
201198 for i , broadcast_dim in enumerate (broadcast_dimensions ):
202199 # Get the input dimension value
203200 input_dim_value = op .Shape (a , start = i , end = i + 1 )
204-
205- # Create a one-hot mask for this position
206- indices = op .Range (op .Constant (value_int = 0 ), op .Constant (value_int = target_rank ), op .Constant (value_int = 1 ))
207- mask = op .Equal (indices , op .Constant (value_int = broadcast_dim ))
208-
209- # Use Where to replace the 1 with the input dimension value at this position
210- intermediate_shape = op .Where (
211- mask ,
212- op .Cast (input_dim_value , to = ir .TensorType .INT64 ),
213- intermediate_shape
214- )
201+ intermediate_shape [broadcast_dim ] = input_dim_value
215202
216203 # Reshape input to intermediate shape and expand to target
217- reshaped = op .Reshape (a , intermediate_shape )
204+ reshaped = op .Reshape (a , common_ops . merge_dims ( intermediate_shape ) )
218205 return op .Expand (reshaped , shape )
219206
220207
0 commit comments