Skip to content

Commit aa72cf2

Browse files
committed
add the ChainRulesCore extension
1 parent 1b09e53 commit aa72cf2

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

lib/NonlinearSolveBase/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2929

3030
[weakdeps]
3131
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
32+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3233
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
3334
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3435
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
@@ -44,13 +45,15 @@ NonlinearSolveBaseLineSearchExt = "LineSearch"
4445
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
4546
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
4647
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
48+
NonlinearSolveBaseChainRulesCoreExt = "ChainRulesCore"
4749

4850
[compat]
4951
ADTypes = "1.9"
5052
Adapt = "4.1.0"
5153
Aqua = "0.8.7"
5254
ArrayInterface = "7.9"
5355
BandedMatrices = "1.5"
56+
ChainRulesCore = "1"
5457
CommonSolve = "0.2.4"
5558
Compat = "4.15"
5659
ConcreteStructs = "0.2.3"
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
module NonlinearSolveBaseChainRulesCoreExt
2+
3+
using NonlinearSolveBase
4+
using NonlinearSolveBase: AbstractNonlinearProblem
5+
using SciMLBase
6+
using SciMLBase: AbstractSensitivityAlgorithm
7+
8+
import ChainRulesCore
9+
import ChainRulesCore: NoTangent
10+
11+
ChainRulesCore.@non_differentiable NonlinearSolveBase.checkkwargs(kwargshandle)
12+
13+
function ChainRulesCore.frule(::typeof(NonlinearSolveBase.solve_up), prob,
14+
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
15+
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
16+
kwargs...)
17+
NonlinearSolveBase._solve_forward(
18+
prob, sensealg, u0, p,
19+
originator, args...;
20+
kwargs...)
21+
end
22+
23+
function ChainRulesCore.rrule(::typeof(NonlinearSolveBase.solve_up), prob::AbstractNonlinearProblem,
24+
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
25+
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
26+
kwargs...)
27+
NonlinearSolveBase._solve_adjoint(
28+
prob, sensealg, u0, p,
29+
originator, args...;
30+
kwargs...)
31+
end
32+
33+
end

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,4 +195,13 @@ NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.valu
195195
@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x))
196196
@inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)
197197

198+
eltypedual(x) = eltype(x) <: ForwardDiff.Dual
199+
isdualtype(::Type{<:ForwardDiff.Dual}) = true
200+
201+
function anyeltypedual(
202+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
203+
::Type{Val{counter}} = Val{0}) where {counter}
204+
anyeltypedual((prob.u0, prob.p))
205+
end
206+
198207
end

0 commit comments

Comments
 (0)