Skip to content

Commit 45f01bd

Browse files
authored
make_zero(!) bugfixes and improved tests (#1961)
* Fix make_zero(!) bugs * Add make_zero(!) tests Aiming for full coverage of both new and old implementations of make_zero(!) * Fix more make_zero(!) bugs and add more tests * Improve make_zero! error message * Simplify likely dead branch * Reinstate single-arg StaticArrays methods
1 parent 06e791e commit 45f01bd

File tree

5 files changed

+912
-130
lines changed

5 files changed

+912
-130
lines changed

ext/EnzymeStaticArraysExt.jl

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,50 @@ end
3232
end
3333
end
3434

35-
@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray}
36-
return Base.zero(x)
35+
@inline function Enzyme.EnzymeCore.make_zero(
36+
prev::FT
37+
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}}
38+
return Base.zero(prev)::FT
3739
end
38-
@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray}
39-
return Base.zero(x)
40+
@inline function Enzyme.EnzymeCore.make_zero(
41+
prev::FT
42+
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
43+
return Base.zero(prev)::FT
44+
end
45+
46+
@inline function Enzyme.EnzymeCore.make_zero(
47+
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
48+
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive}
49+
return Base.zero(prev)::FT
50+
end
51+
@inline function Enzyme.EnzymeCore.make_zero(
52+
::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
53+
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive}
54+
if haskey(seen, prev)
55+
return seen[prev]
56+
end
57+
new = Base.zero(prev)::FT
58+
seen[prev] = new
59+
return new
60+
end
61+
62+
@inline function Enzyme.EnzymeCore.make_zero!(
63+
prev::FT, seen
64+
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
65+
if !isnothing(seen)
66+
if prev in seen
67+
return nothing
68+
end
69+
push!(seen, prev)
70+
end
71+
fill!(prev, zero(T))
72+
return nothing
73+
end
74+
@inline function Enzyme.EnzymeCore.make_zero!(
75+
prev::FT
76+
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
77+
Enzyme.EnzymeCore.make_zero!(prev, nothing)
78+
return nothing
4079
end
4180

4281
end

0 commit comments

Comments
 (0)