Skip to content

Commit 302274f

Browse files
authored
fix: number tracing (EnzymeAD#849)
1 parent ede4493 commit 302274f

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/Tracing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Base.@nospecializeinfer function traced_type_inner(
5555
@nospecialize(track_numbers::Type),
5656
@nospecialize(sharding)
5757
)
58-
if Mode == ArrayToConcrete && T <: track_numbers
58+
if mode == ArrayToConcrete && T <: track_numbers
5959
return ConcretePJRTNumber{
6060
T,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), 0)
6161
}

test/tracing.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ struct Wrapper{A,B}
1515
b::B
1616
end
1717

18+
struct Descent{T}
19+
eta::T
20+
end
21+
22+
struct RMSProp{Teta,Trho,Teps,C<:Bool}
23+
eta::Teta
24+
rho::Trho
25+
epsilon::Teps
26+
centred::C
27+
end
28+
1829
@testset "Tracing" begin
1930
@testset "trace_type" begin
2031
@testset "mode = ConcreteToTraced" begin
@@ -242,4 +253,17 @@ end
242253
st_traced = Reactant.to_rarray(st; track_numbers=Number)
243254
@test st_traced.training isa Val{true}
244255
end
256+
257+
@testset "to_rarray(::AbstractRule)" begin
258+
opt = Descent(0.1)
259+
opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat)
260+
@test opt_traced.eta isa ConcreteRNumber{Float64}
261+
262+
opt = RMSProp(0.1, 0.9, 1e-8, true)
263+
opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat)
264+
@test opt_traced.eta isa ConcreteRNumber{Float64}
265+
@test opt_traced.rho isa ConcreteRNumber{Float64}
266+
@test opt_traced.epsilon isa ConcreteRNumber{Float64}
267+
@test opt_traced.centred isa Bool
268+
end
245269
end

0 commit comments

Comments
 (0)