@@ -178,47 +178,41 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType:
178178
@torch_op("prims::broadcast_in_dim", trace_only=True)
def prims_broadcast_in_dim(
    a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int]
) -> TensorType:
    """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)

    Broadcasts ``a`` to ``shape``: ``broadcast_dimensions[i]`` is the position in
    the target shape that input dimension ``i`` maps to; every other target
    position is treated as size 1 and expanded.
    """
    target_rank = len(shape)

    if not broadcast_dimensions:
        # Special case: no broadcast dimensions - the input broadcasts directly
        # against the whole target shape. merge_dims folds the (possibly
        # symbolic) dims into the single 1-D INT64 tensor Expand requires.
        return op.Expand(a, common_ops.merge_dims(shape))

    # Build the intermediate shape (input dims placed at their mapped target
    # positions, 1 everywhere else) without ScatterElements, exploiting that
    # broadcast_dimensions is known at trace time.

    # Start from a base shape of all 1s; op.Where below lifts it to a tensor.
    intermediate_shape = [1] * target_rank

    for i, broadcast_dim in enumerate(broadcast_dimensions):
        # Runtime size of input dimension i as a 1-D tensor. ONNX Shape always
        # outputs int64, so no Cast is needed afterwards.
        input_dim_value = op.Shape(a, start=i, end=i + 1)

        # One-hot boolean mask selecting target position `broadcast_dim`.
        indices = op.Range(
            op.Constant(value_int=0),
            op.Constant(value_int=target_rank),
            op.Constant(value_int=1),
        )
        mask = op.Equal(indices, op.Constant(value_int=broadcast_dim))

        # Replace the 1 at the selected position with the input dimension.
        intermediate_shape = op.Where(mask, input_dim_value, intermediate_shape)

    # Reshape to the intermediate shape, then expand to the full target.
    # BUGFIX: Expand's second input must be a 1-D int64 tensor, so the
    # Sequence[INT64] target shape is merged exactly as in the
    # empty-broadcast_dimensions branch above (previously the raw sequence
    # was passed through).
    reshaped = op.Reshape(a, intermediate_shape)
    return op.Expand(reshaped, common_ops.merge_dims(shape))
0 commit comments