Skip to content

Commit 8a289c9

Browse files
Call Ops.fill in similar for Broadcasted argument (#1305)
* call `Ops.fill` in `similar` for `Broadcasted` argument * fix * Update test/ops.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8032427 commit 8a289c9

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/TracedRArray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,14 +675,14 @@ function Base.similar(
675675
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
676676
) where {T<:Reactant.ReactantPrimitive,N}
677677
@assert N isa Int
678-
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
678+
return Ops.fill(zero(unwrapped_eltype(T)), dims)
679679
end
680680

681681
function Base.similar(
682682
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims
683683
) where {T<:Reactant.ReactantPrimitive,N}
684684
@assert N isa Int
685-
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
685+
return Ops.fill(zero(T), dims)
686686
end
687687

688688
function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle{0}})

test/ops.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,9 @@ end
11351135
mod = @code_hlo optimize = false const_dedup(x)
11361136
hlo_ir = repr(mod)
11371137
csts = collect(x for x in eachsplit(hlo_ir, "\n") if occursin("stablehlo.constant", x))
1138+
# calls to similar give rise to dense<0> constants (that are not deduplicated):
1139+
csts = filter(x -> !occursin("dense<0>", x), csts)
1140+
11381141
@test length(csts) == 2
11391142
idx = findfirst(x -> occursin("1, 2, 3, 4", x), csts)
11401143
@test idx !== nothing

0 commit comments

Comments
 (0)