Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,21 @@ function ChainRulesCore.rrule(
SciMLBase.IntervalNonlinearProblem(args...; kwargs...), IntervalNonlinearProblemAdjoint
end

# This is a workaround for the fact `NonlinearProblem` is a mutable struct. In SciMLSensitivity, we call
# `back` explicitly while already in a reverse pass causing a nested gradient call. The mutable struct
# causes accumulation anytime `getfield/property` is called, accumulating multiple times. This tries to treat
# AbstractDEProblem as immutable for the purposes of reverse mode AD.
Comment on lines +142 to +143
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it's only on NonlinearProblem?

function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::typeof(Base.getproperty), x::NonlinearProblem, f::Symbol)
val = getfield(x, f)
function back(der)
dx = if der === nothing
ChainRulesCore.zero_tangent(x)
else
NamedTuple{(f,)}((der,))
end
return (ChainRulesCore.NoTangent(), ChainRulesCore.ProjectTo(x)(dx), ChainRulesCore.NoTangent())
end
return val, back
end

end
Loading