-
-
Notifications
You must be signed in to change notification settings - Fork 53
Implement Adjoints for solution of IntervalNonlinearProblems #623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ChrisRackauckas
merged 25 commits into
SciML:master
from
jClugstor:intervalnonlinearsolve_adjoints
May 22, 2025
Merged
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
b5e2ce2
make solve algorithms use __solve
jClugstor 5205bdf
add bracketingnonlinear_solve_up
jClugstor ec41473
add extensions
jClugstor ef32112
add Bracketing ChainRulesCoreExt
jClugstor 1610a51
better error message to make sure problem constructor adjoints exist
jClugstor 845ec7f
add weakdeps
jClugstor b8eb03f
add test
jClugstor a878e8e
fix test
jClugstor db469c1
use SciMLBase instead
jClugstor d52842f
use gradient, p might not be scalar
jClugstor ff43257
add zygote as trigger for chainrulescore extension
jClugstor 634b1c3
account for both derivative and gradient
jClugstor f593cc4
old docstring
jClugstor 53cd09a
add ForwardDiff trigger, more using
jClugstor bac45ad
get rid of unnecessary Zygote
jClugstor 77b9e5a
fix adjoint test
jClugstor 0ed1191
don't need diffeqbase ext stuff
jClugstor 7d48d7d
load bracketing nonlinear solve in test
jClugstor f458c96
fix project.toml
jClugstor 50ce860
add Zygote to test deps
jClugstor d70f9b3
test should use Bisection
jClugstor 7f3db45
account for Thunks, non tangent types
jClugstor 671d23a
fix test
jClugstor 789e04b
make imports explicit, add ompat bounds
jClugstor d8b82af
Update lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChain…
ChrisRackauckas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
lib/BracketingNonlinearSolve/ext/BracketingNonlinearSolveChainRulesCoreExt.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
module BracketingNonlinearSolveChainRulesCoreExt | ||
|
||
using CommonSolve: CommonSolve, solve | ||
using ForwardDiff: ForwardDiff | ||
using SciMLBase: SciMLBase, IntervalNonlinearProblem | ||
using ChainRulesCore: ChainRulesCore, AbstractThunk, NoTangent, Tangent, unthunk | ||
|
||
using BracketingNonlinearSolve: bracketingnonlinear_solve_up | ||
|
||
function ChainRulesCore.rrule( | ||
::typeof(bracketingnonlinear_solve_up), | ||
prob::IntervalNonlinearProblem, | ||
sensealg, p, alg, args...; kwargs... | ||
) | ||
out = solve(prob, alg) | ||
u = out.u | ||
f = SciMLBase.unwrapped_f(prob.f) | ||
function ∇bracketingnonlinear_solve_up(Δ) | ||
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ | ||
# Δ = dg/du | ||
Δ isa Tangent ? delu = Δ.u : delu = Δ | ||
λ = only(ForwardDiff.derivative(u -> f(u, p), only(u)) \ delu) | ||
if p isa Number | ||
dgdp = -λ * ForwardDiff.derivative(p -> f(u, p), p) | ||
else | ||
dgdp = -λ * ForwardDiff.gradient(p -> f(u, p), p) | ||
end | ||
return (NoTangent(), NoTangent(), NoTangent(), | ||
dgdp, NoTangent(), | ||
ntuple(_ -> NoTangent(), length(args))...) | ||
end | ||
return out, ∇bracketingnonlinear_solve_up | ||
end | ||
|
||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
@testitem "Simple Adjoint Test" tags=[:adjoint] begin | ||
using ForwardDiff, Zygote, BracketingNonlinearSolve | ||
|
||
ff(u, p) = u^2 .- p[1] | ||
|
||
function solve_nlprob(p) | ||
prob = IntervalNonlinearProblem{false}(ff, (1.0, 3.0), p) | ||
sol = solve(prob, Bisection()) | ||
res = sol isa AbstractArray ? sol : sol.u | ||
return sum(abs2, res) | ||
end | ||
|
||
p = [2.0, 2.0] | ||
|
||
∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) | ||
∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) | ||
@test ∂p_zygote ≈ ∂p_forwarddiff | ||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.