Skip to content

Commit 8b90501

Browse files
authored
fix: ensure printing of wrapped ConcreteRArrays goes through our show (EnzymeAD#367)
* fix: ensure printing of wrapped ConcreteRArrays goes through our show * fix: allow wrapped arrays in mapreduce
1 parent ea97be3 commit 8b90501

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

src/ConcreteRArray.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,18 @@ function Base.show(io::IO, X::ConcreteRScalar{T}) where {T}
195195
return nothing
196196
end
197197

198-
function Base.print_array(io::IO, X::ConcreteRArray)
199-
if X.data == XLA.AsyncEmptyBuffer
198+
function Base.print_array(io::IO, X::AnyConcreteRArray)
199+
data = ancestor(X).data
200+
if data == XLA.AsyncEmptyBuffer
200201
println(io, "<Empty buffer>")
201202
return nothing
202203
end
203204
return Base.print_array(io, convert(Array, X))
204205
end
205206

206-
function Base.show(io::IO, X::ConcreteRArray)
207-
if X.data == XLA.AsyncEmptyBuffer
207+
function Base.show(io::IO, X::AnyConcreteRArray)
208+
data = ancestor(X).data
209+
if data == XLA.AsyncEmptyBuffer
208210
println(io, "<Empty buffer>")
209211
return nothing
210212
end

src/Reactant.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ function Enzyme.make_zero(
8585
return res
8686
end
8787

88+
function ancestor(x::AbstractArray)
89+
p_x = parent(x)
90+
p_x === x && return x
91+
return ancestor(p_x)
92+
end
93+
8894
include("mlir/MLIR.jl")
8995
include("XLA.jl")
9096
include("Interpreter.jl")

src/TracedRArray.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,6 @@ function set_mlir_data!(x::AnyTracedRArray, data)
122122
return x
123123
end
124124

125-
ancestor(x::TracedRArray) = x
126-
ancestor(x::WrappedTracedRArray) = ancestor(parent(x))
127-
128125
get_ancestor_indices(::TracedRArray, indices...) = indices
129126
function get_ancestor_indices(x::WrappedTracedRArray, indices...)
130127
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
@@ -388,10 +385,12 @@ end
388385
function Base.mapreduce(
389386
@nospecialize(f),
390387
@nospecialize(op),
391-
@nospecialize(A::TracedRArray{T,N});
388+
@nospecialize(A::AnyTracedRArray{T,N});
392389
dims=:,
393390
init=nothing,
394391
) where {T,N}
392+
A = materialize_traced_array(A)
393+
395394
if dims isa Int
396395
dims = [dims]
397396
end

0 commit comments

Comments
 (0)