Skip to content

Commit 36a9f6d

Browse files
jClugstorChrisRackauckas
authored andcommitted
add Bracketing ChainRulesCoreExt
1 parent 9bff381 commit 36a9f6d

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
module BracketingNonlinearSolveChainRulesCoreExt
2+
3+
using CommonSolve: CommonSolve
4+
using ForwardDiff: ForwardDiff
5+
using DiffEqBase
6+
7+
using BracketingNonlinearSolve: bracketingnonlinear_solve_up, is_extension_loaded
8+
9+
function ChainRulesCore.rrule(
10+
::typeof(bracketingnonlinear_solve_up),
11+
prob::IntervalNonlinearProblem,
12+
sensealg, p, alg, args...; kwargs...
13+
)
14+
# DiffEqBase is needed for problem/function constructor adjoint
15+
!is_extension_loaded(Val(:DiffEqBase)) &&
16+
error("Adjoint sensitivity analysis requires `DiffEqBase.jl` to be explicitly loaded.")
17+
out = solve(prob)
18+
u = out.u
19+
f = SciMLBase.unwrapped_f(prob.f)
20+
function ∇bracketingnonlinear_solve_up(Δ)
21+
# Δ = dg/du
22+
λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ Δ.u)
23+
dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), only(p))
24+
return (NoTangent(), NoTangent(), NoTangent(),
25+
dgdp, NoTangent(),
26+
ntuple(_ -> NoTangent(), length(args))...)
27+
end
28+
return out, ∇bracketingnonlinear_solve_up
29+
end
30+
31+
end

0 commit comments

Comments
 (0)