Skip to content

Commit 8153008

Browse files
committed
feat: add simplenonlinearsolve AD specific dispatches
1 parent c238914 commit 8153008

File tree

7 files changed

+168
-2
lines changed

7 files changed

+168
-2
lines changed

lib/NonlinearSolveBase/src/NonlinearSolveBase.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,4 @@ export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTermi
2727
AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
2828
RelNormSafeNormTerminationMode, AbsNormSafeNormTerminationMode
2929

30-
export ImmutableNonlinearProblem
31-
3230
end

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,35 @@ version = "1.13.0"
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
10+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1011
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1112
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1213
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1314
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1415

16+
[weakdeps]
17+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
18+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
19+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
20+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
21+
22+
[extensions]
23+
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
24+
SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase"
25+
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
26+
SimpleNonlinearSolveTrackerExt = "Tracker"
27+
1528
[compat]
1629
ADTypes = "1.2"
1730
ArrayInterface = "7.16"
1831
BracketingNonlinearSolve = "1"
32+
ChainRulesCore = "1.24"
33+
CommonSolve = "0.2.4"
34+
DiffEqBase = "6.155"
1935
NonlinearSolveBase = "1"
2036
PrecompileTools = "1.2"
2137
Reexport = "1.2"
38+
ReverseDiff = "1.15"
2239
SciMLBase = "2.50"
40+
Tracker = "0.2.35"
2341
julia = "1.10"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module SimpleNonlinearSolveChainRulesCoreExt
2+
3+
using ChainRulesCore: ChainRulesCore, NoTangent
4+
using NonlinearSolveBase: ImmutableNonlinearProblem
5+
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
6+
7+
using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up,
8+
solve_adjoint
9+
10+
function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up),
11+
prob::Union{InternalNonlinearProblem, NonlinearLeastSquaresProblem},
12+
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
13+
out, ∇internal = solve_adjoint(
14+
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
15+
function ∇simplenonlinearsolve_solve_up(Δ)
16+
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ)
17+
return (
18+
∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...)
19+
end
20+
return out, ∇simplenonlinearsolve_solve_up
21+
end
22+
23+
end
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module SimpleNonlinearSolveDiffEqBaseExt
2+
3+
using DiffEqBase: DiffEqBase
4+
5+
using SimpleNonlinearSolve: SimpleNonlinearSolve
6+
7+
function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...)
8+
return DiffEqBase._solve_adjoint(args...; kwargs...)
9+
end
10+
11+
end
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
module SimpleNonlinearSolveReverseDiffExt
2+
3+
using ArrayInterface: ArrayInterface
4+
using NonlinearSolveBase: ImmutableNonlinearProblem
5+
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
6+
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake
7+
8+
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
9+
10+
for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
11+
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
12+
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
13+
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
14+
prob::$(pType), sensealg, u0::$(uT), u0_changed,
15+
p::$(pT), p_changed, alg, args...; kwargs...)
16+
return ReverseDiff.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up,
17+
prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
18+
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
19+
end
20+
end
21+
22+
@eval ReverseDiff.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
23+
tprob::$(pType), sensealg, tu0, u0_changed,
24+
tp, p_changed, alg, args...; kwargs...)
25+
u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp)
26+
prob = remake(tprob; u0, p)
27+
out, ∇internal = solve_adjoint(
28+
prob, sensealg, u0, p, ReverseDiffOriginator(), alg, args...; kwargs...)
29+
30+
function ∇simplenonlinearsolve_solve_up...)
31+
∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal...)
32+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
33+
end
34+
end
35+
end
36+
37+
end
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
module SimpleNonlinearSolveTrackerExt
2+
3+
using ArrayInterface: ArrayInterface
4+
using NonlinearSolveBase: ImmutableNonlinearProblem
5+
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
6+
using Tracker: Tracker, TrackedArray, TrackedReal
7+
8+
using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
9+
10+
for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
11+
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
12+
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
13+
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
14+
prob::$(pType), sensealg, u0::$(uT), u0_changed,
15+
p::$(pT), p_changed, alg, args...; kwargs...)
16+
return Tracker.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up, prob,
17+
sensealg, ArrayInterface.aos_to_soa(u0), true,
18+
ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
19+
end
20+
end
21+
22+
@eval Tracker.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
23+
tprob::$(pType), sensealg, tu0, u0_changed,
24+
tp, p_changed, alg, args...; kwargs...)
25+
u0, p = Tracker.data(tu0), Tracker.data(tp)
26+
prob = remake(tprob; u0, p)
27+
out, ∇internal = solve_adjoint(
28+
prob, sensealg, u0, p, TrackerOriginator(), alg, args...; kwargs...)
29+
30+
function ∇simplenonlinearsolve_solve_up(Δ)
31+
∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Tracker.data(Δ))
32+
return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
33+
end
34+
end
35+
end
36+
37+
end

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,53 @@ module SimpleNonlinearSolve
22

33
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
44
AutoPolyesterForwardDiff
5+
using CommonSolve: CommonSolve, solve
56
using PrecompileTools: @compile_workload, @setup_workload
67
using Reexport: @reexport
78
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
9+
using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
810

911
using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
12+
using NonlinearSolveBase: ImmutableNonlinearProblem
13+
14+
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
15+
16+
is_extension_loaded(::Val) = false
17+
18+
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
19+
function CommonSolve.solve(prob::NonlinearProblem,
20+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
21+
prob = convert(ImmutableNonlinearProblem, prob)
22+
return solve(prob, alg, args...; kwargs...)
23+
end
24+
25+
function CommonSolve.solve(
26+
prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
27+
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
28+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
29+
sensealg = prob.kwargs[:sensealg]
30+
end
31+
new_u0 = u0 !== nothing ? u0 : prob.u0
32+
new_p = p !== nothing ? p : prob.p
33+
return simplenonlinearsolve_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
34+
p === nothing, alg, args...; prob.kwargs..., kwargs...)
35+
end
36+
37+
function simplenonlinearsolve_solve_up(
38+
prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed, p, p_changed,
39+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
40+
(u0_changed || p_changed) && (prob = remake(prob; u0, p))
41+
return SciMLBase.__solve(prob, alg, args...; kwargs...)
42+
end
43+
44+
# NOTE: This is defined like this so that we don't have to keep have 2 args for the
45+
# extensions
46+
function solve_adjoint(args...; kws...)
47+
is_extension_loaded(Val(:DiffEqBase)) && return solve_adjoint_internal(args...; kws...)
48+
error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.")
49+
end
50+
51+
function solve_adjoint_internal end
1052

1153
@setup_workload begin
1254
@compile_workload begin end

0 commit comments

Comments
 (0)