Skip to content

Commit 781aacc

Browse files
authored
Fix fwd custom rule handler (#2573)
* Fix fwd custom rule handler * fix * fix * fix enzymetestutils * fix * la * fix * fix cr import
1 parent 0e70137 commit 781aacc

File tree

8 files changed

+174
-86
lines changed

8 files changed

+174
-86
lines changed

ext/EnzymeChainRulesCoreExt.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ function Enzyme._import_frule(fn, tys...)
6060
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval
6161
cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...)
6262
if RetAnnotation <: Const
63-
return cres[2]::eltype(RetAnnotation)
63+
if EnzymeRules.needs_primal(config)
64+
return cres[1]::eltype(RetAnnotation)
65+
else
66+
return nothing
67+
end
6468
elseif RetAnnotation <: Duplicated
6569
return Duplicated(cres[1], cres[2])
6670
elseif RetAnnotation <: DuplicatedNoNeed
@@ -75,7 +79,11 @@ function Enzyme._import_frule(fn, tys...)
7579
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
7680
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
7781
end
78-
return cres[1][2]::eltype(RetAnnotation) # nothing
82+
if EnzymeRules.needs_primal(config)
83+
return cres[1][1]::eltype(RetAnnotation) # nothing
84+
else
85+
return nothing
86+
end
7987
elseif RetAnnotation <: BatchDuplicated
8088
cres1 = begin
8189
i = 1

lib/EnzymeTestUtils/src/finite_difference_calls.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ function _fd_forward(fdm, f, rettype, y, activities)
3131
if rettype <: Union{Duplicated,DuplicatedNoNeed}
3232
all(ignores) && return zero_tangent(y)
3333
sig_arg_dval_vec, _ = to_vec(ẋs[.!ignores])
34-
ret_deval_vec = FiniteDifferences.jvp(fdm, f_vec,
35-
(sig_arg_val_vec, sig_arg_dval_vec))
34+
ret_deval_vec = FiniteDifferences._jvp(fdm, f_vec,
35+
sig_arg_val_vec, sig_arg_dval_vec)
3636
return from_vec_out(ret_deval_vec)
3737
elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed}
3838
all(ignores) && return (var"1"=zero_tangent(y),)
3939
ret_dvals = map(ẋs[.!ignores]...) do sig_args_dvals...
4040
sig_args_dvals_vec, _ = to_vec(sig_args_dvals)
41-
ret_dval_vec = FiniteDifferences.jvp(fdm, f_vec,
42-
(sig_arg_val_vec, sig_args_dvals_vec))
41+
ret_dval_vec = FiniteDifferences._jvp(fdm, f_vec,
42+
sig_arg_val_vec, sig_args_dvals_vec)
4343
return from_vec_out(ret_dval_vec)
4444
end
4545
return NamedTuple{ntuple(Symbol, length(ret_dvals))}(ret_dvals)
@@ -59,6 +59,18 @@ function multi_tovec(active_return, vals)
5959
end
6060
end
6161

62+
function j′vp(fdm, f_vec, ȳ, x)
63+
mat = transpose(first(FiniteDifferences.jacobian(fdm, f_vec, x)))
64+
result = zero(x)
65+
for i in 1:length(ȳ)
66+
tp = @inbounds ȳ[i]
67+
if isfinite(tp) && !iszero(tp)
68+
result .+= mat[:, i] .* tp
69+
end
70+
end
71+
return result
72+
end
73+
6274
#=
6375
_fd_reverse(fdm, f, ȳ, activities, active_return)
6476
@@ -98,13 +110,12 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return)
98110
if !is_batch
99111
ȳ_extended = (ȳ, s̄igargs...)
100112
ȳ_extended_vec = multi_tovec(active_return, ȳ_extended)
101-
fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec))
113+
fd_vec = j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec)
102114
fd = from_vec_in(fd_vec)
103115
else
104116
fd = Tuple(zip(map(ȳ, s̄igargs...) do ȳ_extended...
105117
ȳ_extended_vec = multi_tovec(active_return, ȳ_extended)
106-
fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec,
107-
sigargs_vec))
118+
fd_vec = j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec)
108119
return from_vec_in(fd_vec)
109120
end...))
110121
end

lib/EnzymeTestUtils/src/generate_tangent.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using LinearAlgebra
2+
13
# recursively apply f to all fields of x for which f is implemented; all other fields are
24
# left unchanged
35
function map_fields_recursive(f, x::T...) where {T}
@@ -13,6 +15,16 @@ function map_fields_recursive(f, x::T...) where {T<:Union{Array,Tuple,NamedTuple
1315
map_fields_recursive(f, xi...)
1416
end
1517
end
18+
function map_fields_recursive(f::typeof(Base.copyto!), y::T, x::T) where {T<:LinearAlgebra.HermOrSym{<:Number}}
19+
copyto!(x.uplo == 'U' ? UpperTriangular(parent(y)) : LowerTriangular(parent(y)), x.uplo == 'U' ? UpperTriangular(parent(x)) : LowerTriangular(parent(x)))
20+
return y
21+
end
22+
function map_fields_recursive(f::typeof(Base.copyto!), y::T, x::T) where {T<:AbstractFloat}
23+
return x
24+
end
25+
function map_fields_recursive(f::typeof(Base.copyto!), y::T, x::T) where {T<:Complex}
26+
return x
27+
end
1628
map_fields_recursive(f, x::T...) where {T<:AbstractFloat} = f(x...)
1729
map_fields_recursive(f, x::Array{<:Number}...) = f(x...)
1830

@@ -22,7 +34,12 @@ function rand_tangent(rng, x)
2234
T = eltype(v)
2335
# make numbers prettier sometimes when errors are printed.
2436
v_new = rand(rng, -9:T(0.01):9, length(v))
25-
return from_vec(v_new)
37+
rand_v = from_vec(v_new)
38+
if x isa Number
39+
return rand_v
40+
end
41+
zero_v = from_vec(zero(v))
42+
return map_fields_recursive(Base.copyto!, zero_v, rand_v)
2643
end
2744

2845
# differs from Enzyme.make_zero primarily in that reshaped Arrays in the argument will share

lib/EnzymeTestUtils/src/test_approx.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ function test_approx(x::Array{<:Number}, y::Array{<:Number}, msg; kwargs...)
77
@test_msg msg isapprox(x, y; kwargs...)
88
return nothing
99
end
10-
@static if VERSION < v"1.11-"
11-
else
1210
using LinearAlgebra
1311
function zero_copy(x)
1412
y = zero(parent(x))
@@ -24,7 +22,6 @@ function test_approx(x::LinearAlgebra.HermOrSym{<:Number}, y::LinearAlgebra.Herm
2422
test_approx(x2, y2, msg; kwargs...)
2523
return nothing
2624
end
27-
end
2825
function test_approx(x::AbstractArray{<:Number}, y::AbstractArray{<:Number}, msg; kwargs...)
2926
@test_msg msg isapprox(x, y; kwargs...)
3027
# for custom array types, fields should also match

lib/EnzymeTestUtils/test/helpers.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,12 @@ function f_structured_array(x::Hermitian)
3939
end
4040
return y
4141
end
42+
43+
function f_structured_nan(x::Hermitian)
44+
new = Matrix{Float32}(undef, 2, 2)
45+
new[1,1] = parent(x)[1,1]
46+
new[1,2] = parent(x)[1,2]
47+
new[2,1] = NaN
48+
new[2,2] = parent(x)[2,2]
49+
return Hermitian(new)
50+
end

lib/EnzymeTestUtils/test/test_forward.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ end
120120
end
121121
end
122122
end
123+
124+
@testset "structured NaN array inputs/outputs" begin
125+
@testset for Tret in (Const, Duplicated, BatchDuplicated),
126+
Tx in (Const, Duplicated, BatchDuplicated)
127+
128+
# if some are batch, none must be duplicated
129+
are_activities_compatible(Tret, Tx) || continue
130+
131+
x = Hermitian(Float32[1 2; 3 4])
132+
133+
atol = rtol = 0.01
134+
test_forward(f_structured_nan, Tret, (x, Tx); atol, rtol)
135+
end
136+
end
123137

124138
@testset "structured array inputs/outputs" begin
125139
@testset for Tret in (Const, Duplicated, BatchDuplicated),

lib/EnzymeTestUtils/test/test_reverse.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,22 @@ end
9090
end
9191
end
9292
end
93+
94+
@testset "structured NaN array inputs/outputs" begin
95+
@testset for Tret in (Const, Duplicated, BatchDuplicated),
96+
Tx in (Const, Duplicated, BatchDuplicated)
9397

94-
VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin
98+
# if some are batch, none must be duplicated
99+
are_activities_compatible(Tret, Tx) || continue
100+
101+
x = Hermitian(Float32[1 2; 3 4])
102+
103+
atol = rtol = 0.01
104+
test_reverse(f_structured_nan, Tret, (x, Tx); atol, rtol)
105+
end
106+
end
107+
108+
@testset "structured array inputs/outputs" begin
95109
@testset for Tret in (Const, Duplicated, BatchDuplicated),
96110
Tx in (Const, Duplicated, BatchDuplicated),
97111
T in (Float32, Float64, ComplexF32, ComplexF64)

0 commit comments

Comments
 (0)