Skip to content

Commit 5e7bafc

Browse files
committed
Fix forward AD
1 parent 47135d4 commit 5e7bafc

File tree

2 files changed

+48
-39
lines changed

2 files changed

+48
-39
lines changed

src/ad.jl

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,61 @@
11
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
22
f = prob.f
33
p = value(prob.p)
4-
54
u0 = value(prob.u0)
65
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
76

87
sol = solve(newprob, alg, args...; kwargs...)
98

109
uu = sol.u
11-
if p isa Number
12-
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
13-
else
14-
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
15-
end
10+
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
11+
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)
12+
13+
z_arr = -inv(f_x) * f_p
1614

17-
f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
1815
pp = prob.p
19-
sumfun = let f_x′ = -f_x
20-
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
16+
sumfun = ((z, p),) -> [zᵢ * ForwardDiff.partials(p) for zᵢ in z]
17+
if uu isa Number
18+
partials = sum(sumfun, zip(z_arr, pp))
19+
else
20+
partials = sum(sumfun, zip(eachcol(z_arr), pp))
2121
end
22-
partials = sum(sumfun, zip(f_p, pp))
22+
2323
return sol, partials
2424
end
2525

26-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
27-
<:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
26+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
27+
iip, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
2828
kwargs...) where {iip, T, V, P}
2929
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
30-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
31-
sol.retcode)
30+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
31+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
3232
end
3333

34-
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
35-
<:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
34+
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},
35+
iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
3636
kwargs...) where {iip, T, V, P}
3737
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
38-
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
39-
sol.retcode)
38+
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
39+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
40+
end
41+
42+
function scalar_nlsolve_∂f_∂p(f, u, p)
43+
ff = p isa Number ? ForwardDiff.derivative :
44+
(u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian)
45+
return ff(Base.Fix1(f, u), p)
46+
end
47+
48+
function scalar_nlsolve_∂f_∂u(f, u, p)
49+
ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian
50+
return ff(Base.Fix2(f, p), u)
51+
end
52+
53+
function scalar_nlsolve_dual_soln(u::Number, partials,
54+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
55+
return Dual{T, V, P}(u, partials[1])
56+
end
57+
58+
function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
59+
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
60+
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials))
4061
end

test/basictests.jl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,13 @@ end
5757
@test (@ballocated solve!($cache)) 64
5858
end
5959

60-
# FIXME: Even the previous tests were broken, but due to a typo in the tests they
61-
# accidentally passed
6260
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
6361
@test begin
6462
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
6563
res_true = sqrt(p)
6664
all(res.u .≈ res_true)
6765
end
68-
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
66+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
6967
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
7068
end
7169

@@ -101,11 +99,9 @@ end
10199
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
102100
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
103101

104-
probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
105-
@testset "ADType: $(autodiff) u0: $(u0)" for autodiff in (false, true,
102+
@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
106103
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
107-
AutoSparseZygote(),
108-
AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
104+
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
109105
probN = NonlinearProblem(quadratic_f, u0, 2.0)
110106
@test all(solve(probN, NewtonRaphson(; autodiff)).u .≈ sqrt(2.0))
111107
end
@@ -149,8 +145,6 @@ end
149145
@test (@ballocated solve!($cache)) 64
150146
end
151147

152-
# FIXME: Even the previous tests were broken, but due to a typo in the tests they
153-
# accidentally passed
154148
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
155149
p in 1.0:0.1:100.0
156150

@@ -160,7 +154,7 @@ end
160154
res_true = sqrt(p)
161155
all(res.u .≈ res_true)
162156
end
163-
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
157+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
164158
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) 1 / (2 * sqrt(p))
165159
end
166160

@@ -204,11 +198,9 @@ end
204198
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
205199
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
206200

207-
probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
208-
@testset "ADType: $(autodiff) u0: $(u0) radius_update_scheme: $(radius_update_scheme)" for autodiff in (false,
201+
@testset "ADType: $(autodiff) u0: $(_nameof(u0)) radius_update_scheme: $(radius_update_scheme)" for autodiff in (false,
209202
true, AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
210-
AutoSparseZygote(), AutoSparseEnzyme()),
211-
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]),
203+
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0]),
212204
radius_update_scheme in radius_update_schemes
213205

214206
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -302,15 +294,13 @@ end
302294
@test (@ballocated solve!($cache)) 64
303295
end
304296

305-
# FIXME: Even the previous tests were broken, but due to a typo in the tests they
306-
# accidentally passed
307297
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
308298
@test begin
309299
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
310300
res_true = sqrt(p)
311301
all(res.u .≈ res_true)
312302
end
313-
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
303+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
314304
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
315305
end
316306

@@ -330,11 +320,9 @@ end
330320
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], p)
331321
ForwardDiff.jacobian(t, p)
332322

333-
probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
334-
@testset "ADType: $(autodiff) u0: $(u0)" for autodiff in (false, true,
323+
@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
335324
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
336-
AutoSparseZygote(),
337-
AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
325+
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
338326
probN = NonlinearProblem(quadratic_f, u0, 2.0)
339327
@test all(solve(probN, LevenbergMarquardt(; autodiff)).u .≈ sqrt(2.0))
340328
end

0 commit comments

Comments
 (0)