Skip to content

Commit 36294b1

Browse files
committed
fix: minor fixes to support adjoints
1 parent 0cbc2fc commit 36294b1

10 files changed

+49
-8
lines changed

lib/NonlinearSolveBase/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2020
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2121

2222
[weakdeps]
23+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2324
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2425
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2526

2627
[extensions]
28+
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
2729
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
2830
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
2931

@@ -33,6 +35,7 @@ ArrayInterface = "7.9"
3335
CommonSolve = "0.2.4"
3436
Compat = "4.15"
3537
ConcreteStructs = "0.2.3"
38+
DiffEqBase = "6.149"
3639
DifferentiationInterface = "0.6.1"
3740
EnzymeCore = "0.8"
3841
FastClosures = "0.3"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module NonlinearSolveBaseDiffEqBaseExt
2+
3+
using DiffEqBase: DiffEqBase
4+
using SciMLBase: remake
5+
6+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem
7+
8+
function DiffEqBase.get_concrete_problem(
9+
prob::ImmutableNonlinearProblem, isadapt; kwargs...)
10+
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
11+
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
12+
p = DiffEqBase.get_concrete_p(prob, kwargs)
13+
return remake(prob; u0 = u0, p = p)
14+
end
15+
16+
end

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ CUDA = "5.3"
4646
ChainRulesCore = "1.24"
4747
CommonSolve = "0.2.4"
4848
ConcreteStructs = "0.2.3"
49-
DiffEqBase = "6.155"
49+
DiffEqBase = "6.149"
5050
DifferentiationInterface = "0.6.1"
5151
Enzyme = "0.13"
5252
ExplicitImports = "1.9"
@@ -79,6 +79,7 @@ julia = "1.10"
7979
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
8080
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
8181
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
82+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8283
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
8384
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
8485
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -95,4 +96,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9596
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9697

9798
[targets]
98-
test = ["AllocCheck", "Aqua", "CUDA", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
99+
test = ["AllocCheck", "Aqua", "CUDA", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using DiffEqBase: DiffEqBase
44

55
using SimpleNonlinearSolve: SimpleNonlinearSolve
66

7+
SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true
8+
79
function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...)
810
return DiffEqBase._solve_adjoint(args...; kwargs...)
911
end

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
3232
∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal...)
3333
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
3434
end
35+
36+
return Array(out), ∇simplenonlinearsolve_solve_up
3537
end
3638
end
3739

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
3131
∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Tracker.data(Δ))
3232
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
3333
end
34+
35+
return out, ∇simplenonlinearsolve_solve_up
3436
end
3537
end
3638

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ using FiniteDiff: FiniteDiff
2020
using ForwardDiff: ForwardDiff
2121

2222
using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
23-
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM
23+
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM,
24+
nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution
2425

2526
const DI = DifferentiationInterface
2627

@@ -47,6 +48,20 @@ function CommonSolve.solve(prob::NonlinearProblem,
4748
return solve(prob, alg, args...; kwargs...)
4849
end
4950

51+
function CommonSolve.solve(
52+
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
53+
<:Union{
54+
<:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}},
55+
alg::AbstractSimpleNonlinearSolveAlgorithm,
56+
args...;
57+
kwargs...) where {T, V, P, iip}
58+
prob = convert(ImmutableNonlinearProblem, prob)
59+
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
60+
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
61+
return SciMLBase.build_solution(
62+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
63+
end
64+
5065
function CommonSolve.solve(
5166
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
5267
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
@@ -59,9 +74,8 @@ function CommonSolve.solve(
5974
p === nothing, alg, args...; prob.kwargs..., kwargs...)
6075
end
6176

62-
function simplenonlinearsolve_solve_up(
63-
prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed, p, p_changed,
64-
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
77+
function simplenonlinearsolve_solve_up(prob::ImmutableNonlinearProblem, sensealg, u0,
78+
u0_changed, p, p_changed, alg, args...; kwargs...)
6579
(u0_changed || p_changed) && (prob = remake(prob; u0, p))
6680
return SciMLBase.__solve(prob, alg, args...; kwargs...)
6781
end

lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@testitem "Simple Adjoint Test" tags=[:adjoint] begin
2-
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote
2+
using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, DiffEqBase,
3+
SimpleNonlinearSolve
34

45
ff(u, p) = u .^ 2 .- p
56

lib/SimpleNonlinearSolve/test/core/allocation_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
@test true
3535
catch e
3636
@error e
37-
@test false broken = (alg isa SimpleHalley)
37+
@test false broken=(alg isa SimpleHalley)
3838
end
3939
end
4040
end

lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl

Whitespace-only changes.

0 commit comments

Comments
 (0)