Skip to content

Commit e1b8d3b

Browse files
authored
Fix anytracedarray override (#1290)
* Fix anytracedarray override * fix * fix2 * add test
1 parent 5cb30f3 commit e1b8d3b

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/TracedUtils.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ ReactantCore.materialize_traced_array(x::AbstractArray) = x
2222

2323
ReactantCore.materialize_traced_array(x::TracedRArray) = x
2424

25-
ReactantCore.materialize_traced_array(x::AnyTracedRArray) = x[axes(x)...]
26-
2725
function ReactantCore.materialize_traced_array(x::AbstractRange)
2826
return Reactant.aos_to_soa(collect(x))
2927
end
@@ -54,7 +52,11 @@ function ReactantCore.materialize_traced_array(
5452
end
5553

5654
function ReactantCore.materialize_traced_array(x::AbstractArray{TracedRNumber{T}}) where {T}
57-
return Reactant.aos_to_soa(x)
55+
as = Reactant.aos_to_soa(x)
56+
if as === x
57+
as = x[axes(x)...]
58+
end
59+
return ReactantCore.materialize_traced_array(as)
5860
end
5961

6062
get_mlir_data(x::TracedRNumber) = x.mlir_data

src/Types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end
7474
@leaf TracedRArray
7575
Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N}
7676

77-
const AnyTracedRArray{T,N} = AbstractArray{<:TracedRNumber{T},N}
77+
const AnyTracedRArray{T,N} = AbstractArray{TracedRNumber{T},N}
7878
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
7979
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
8080
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}

test/tracing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ struct RMSProp{Teta,Trho,Teps,C<:Bool}
2626
centred::C
2727
end
2828

29+
@testset "Traced Type" begin
30+
@test !(Vector{Union{}} <: Reactant.AnyTracedRArray)
31+
end
32+
2933
@testset "Tracing" begin
3034
@testset "trace_type" begin
3135
@testset "mode = ConcreteToTraced" begin

0 commit comments

Comments
 (0)