Skip to content

Commit d7337c9

Browse files
committed
fix: trace_type
1 parent 0d7ad84 commit d7337c9

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/Tracing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode
183183
end
184184
end
185185

186-
function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedRArray,mode}
186+
function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedTypes,mode}
187187
if mode == ConcreteToTraced
188188
throw("TracedRArray $T cannot be traced")
189189
elseif mode == TracedToConcrete
@@ -203,7 +203,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:XLAArray}
203203
end
204204

205205
function traced_type(::Type{A}, seen::ST, ::Val{mode}) where {T,N,A<:Array{T,N},ST,mode}
206-
if mode == ArrayToConcrete && T <: AbstractFloat
206+
if mode == ArrayToConcrete && T <: ReactantPrimitives
207207
return ConcreteRArray{T,N}
208208
else
209209
return Array{traced_type(T, seen, Val(mode)),N}

test/compile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
77
@testset "create_result" begin
88
@testset "NamedTuple" begin
99
x = (; a=rand(4, 3))
10-
x2 = (; a=Reactant.ConcreteRArray(x.a))
10+
x2 = Reactant.to_rarray(x)
1111

1212
f = @compile sum(x2)
1313

0 commit comments

Comments
 (0)