Skip to content

Commit 158e6cd

Browse files
committed
prims_broadcast_in_dim
Signed-off-by: Justin Chu <[email protected]>
1 parent 0ec1674 commit 158e6cd

File tree

1 file changed

+2
-15
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+2
-15
lines changed

onnxscript/function_libs/torch_lib/ops/prims.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,6 @@ def prims_broadcast_in_dim(
188188
# Special case: no broadcast dimensions - all target dims should be 1
189189
return op.Expand(a, common_ops.merge_dims(shape))
190190

191-
# Build intermediate shape using a simpler approach than ScatterElements
192-
# We'll construct it by concatenating the right values for each position
193-
194191
# Create base shape of all 1s
195192
ones = [1] * target_rank
196193

@@ -201,20 +198,10 @@ def prims_broadcast_in_dim(
201198
for i, broadcast_dim in enumerate(broadcast_dimensions):
202199
# Get the input dimension value
203200
input_dim_value = op.Shape(a, start=i, end=i + 1)
204-
205-
# Create a one-hot mask for this position
206-
indices = op.Range(op.Constant(value_int=0), op.Constant(value_int=target_rank), op.Constant(value_int=1))
207-
mask = op.Equal(indices, op.Constant(value_int=broadcast_dim))
208-
209-
# Use Where to replace the 1 with the input dimension value at this position
210-
intermediate_shape = op.Where(
211-
mask,
212-
op.Cast(input_dim_value, to=ir.TensorType.INT64),
213-
intermediate_shape
214-
)
201+
intermediate_shape[broadcast_dim] = input_dim_value
215202

216203
# Reshape input to intermediate shape and expand to target
217-
reshaped = op.Reshape(a, intermediate_shape)
204+
reshaped = op.Reshape(a, common_ops.merge_dims(intermediate_shape))
218205
return op.Expand(reshaped, shape)
219206

220207

0 commit comments

Comments
 (0)