
Commit 0ec1674

wip

Signed-off-by: Justin Chu <[email protected]>

1 parent: 75f8a51

File tree: 1 file changed (+13 −19)

  • onnxscript/function_libs/torch_lib/ops/prims.py

onnxscript/function_libs/torch_lib/ops/prims.py

Lines changed: 13 additions & 19 deletions
@@ -178,47 +178,41 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType:
 
 @torch_op("prims::broadcast_in_dim", trace_only=True)
 def prims_broadcast_in_dim(
-    a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
+    a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int]
 ) -> TensorType:
     """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""
-
-    # Simplified approach that replaces ScatterElements with more basic operations
-    # while still leveraging compile-time knowledge of broadcast_dimensions
-
-    input_shape = op.Shape(a)
+
     target_rank = len(shape)
-
+
     if not broadcast_dimensions:
-        # Special case: no broadcast dimensions - all target dims should be 1
-        ones = op.ConstantOfShape(op.Constant(value_ints=[target_rank]), value=op.Constant(value_int=1))
-        reshaped = op.Reshape(a, ones)
-        return op.Expand(reshaped, shape)
-
+        # Special case: no broadcast dimensions - all target dims should be 1
+        return op.Expand(a, common_ops.merge_dims(shape))
+
     # Build intermediate shape using a simpler approach than ScatterElements
     # We'll construct it by filling in the right value at each position
-
+
     # Create base shape of all 1s
     ones = [1] * target_rank
-
+
     # For each broadcast dimension, we'll replace the 1 with the actual input dimension
     # Since broadcast_dimensions is compile-time known, we can do this with individual operations
     intermediate_shape = ones
-
+
     for i, broadcast_dim in enumerate(broadcast_dimensions):
         # Get the input dimension value
-        input_dim_value = op.Gather(input_shape, op.Constant(value_int=i))
-
+        input_dim_value = op.Shape(a, start=i, end=i + 1)
+
         # Create a one-hot mask for this position
         indices = op.Range(op.Constant(value_int=0), op.Constant(value_int=target_rank), op.Constant(value_int=1))
         mask = op.Equal(indices, op.Constant(value_int=broadcast_dim))
-
+
         # Use Where to replace the 1 with the input dimension value at this position
         intermediate_shape = op.Where(
             mask,
             op.Cast(input_dim_value, to=ir.DataType.INT64),
             intermediate_shape
         )
-
+
     # Reshape input to intermediate shape and expand to target
     reshaped = op.Reshape(a, intermediate_shape)
     return op.Expand(reshaped, shape)
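
For context, the decomposition above implements broadcast_in_dim by reshaping the input to an intermediate all-ones shape with the input dimensions inserted at broadcast_dimensions, then expanding to the target shape. The sketch below shows the same semantics in plain NumPy; it is not part of this commit, and the name broadcast_in_dim_ref is invented for illustration.

import numpy as np

def broadcast_in_dim_ref(a, shape, broadcast_dimensions):
    # Hypothetical reference for prims::broadcast_in_dim, mirroring the
    # Reshape + Expand strategy used in the diff above.
    # Start from an all-ones shape of the target rank...
    intermediate_shape = [1] * len(shape)
    # ...then place each input dimension at its broadcast position.
    for i, broadcast_dim in enumerate(broadcast_dimensions):
        intermediate_shape[broadcast_dim] = a.shape[i]
    # Reshape to the intermediate shape and broadcast to the target.
    return np.broadcast_to(a.reshape(intermediate_shape), shape)

# Example: place a (3,) vector along dim 1 of a (2, 3, 4) target.
a = np.arange(3)
out = broadcast_in_dim_ref(a, (2, 3, 4), [1])
assert out.shape == (2, 3, 4)

The empty broadcast_dimensions case falls out the same way as in the diff: the intermediate shape is all ones, so the input is broadcast directly to the target shape, which is what the op.Expand special case does.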
