Skip to content

Commit 91a4a00

Browse files
committed
fix: scalar broadcasting case
1 parent 7fd269d commit 91a4a00

File tree

1 file changed

+2
-14
lines changed

1 file changed

+2
-14
lines changed

src/TracedRArray.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -586,20 +586,8 @@ end
586586

587587
function broadcast_to_size(arg::TracedRNumber, rsize)
588588
length(rsize) == 0 && return arg
589-
mlirty = MLIR.IR.type(arg.mlir_data)
590-
return TracedRArray{eltype(arg),length(rsize)}(
591-
(),
592-
MLIR.IR.result(
593-
MLIR.Dialects.stablehlo.broadcast_in_dim(
594-
arg.mlir_data;
595-
result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)),
596-
broadcast_dimensions=MLIR.IR.DenseArrayAttribute([
597-
Int64(i - 1) for i in rsize
598-
]),
599-
),
600-
1,
601-
),
602-
rsize,
589+
return broadcast_to_size_internal(
590+
TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize
603591
)
604592
end
605593

0 commit comments

Comments
 (0)