Skip to content

Commit 784beb5

Browse files
Fix enzyme.make_zero! (#1544)
* Fix enzyme.make_zero * fix * with test * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * allowscalar --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ea444db commit 784beb5

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

src/Enzyme.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,46 @@ const enzyme_dupnoneed = 3
55
const enzyme_outnoneed = 4
66
const enzyme_constnoneed = 5
77

8+
@inline function Enzyme.make_zero(x::RNumber)
9+
return zero(Core.Typeof(x))
10+
end
11+
12+
@inline function Enzyme.make_zero(x::RArray{FT,N})::RArray{FT,N} where {FT<:AbstractFloat,N}
13+
return Base.zero(x)
14+
end
15+
16+
@inline function Enzyme.make_zero(
17+
x::RArray{Complex{FT},N}
18+
)::RArray{Complex{FT},N} where {FT<:AbstractFloat,N}
19+
return Base.zero(x)
20+
end
21+
22+
macro register_make_zero_inplace(sym)
23+
quote
24+
@inline function $sym(prev::RArray{T,N})::Nothing where {T<:AbstractFloat,N}
25+
$sym(prev, nothing)
26+
return nothing
27+
end
28+
29+
@inline function $sym(prev::RArray{T,N}, seen::ST)::Nothing where {T,N,ST}
30+
if Enzyme.Compiler.guaranteed_const_nongen(T, nothing)
31+
return nothing
32+
end
33+
if !isnothing(seen)
34+
if prev in seen
35+
return nothing
36+
end
37+
push!(seen, prev)
38+
end
39+
fill!(prev, zero(T))
40+
return nothing
41+
end
42+
end
43+
end
44+
45+
@register_make_zero_inplace(Enzyme.make_zero!)
46+
@register_make_zero_inplace(Enzyme.remake_zero!)
47+
848
function Enzyme.make_zero(
949
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
1050
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}

test/autodiff.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,13 @@ vector_forward_ad(x) = Enzyme.autodiff(Forward, fn, BatchDuplicated(x, Enzyme.on
231231
@test res[1][4] res_enz[1][4]
232232
end
233233

234+
@testset "make_zero!" begin
235+
x = Reactant.to_rarray([3.1])
236+
@jit Enzyme.make_zero!(x)
237+
238+
@test @allowscalar x[1] 0.0
239+
end
240+
234241
function simple_forward(x, st)
235242
rng = copy(st.rng)
236243
y = similar(x)

0 commit comments

Comments
 (0)