Skip to content

Commit 1fa3616

Browse files
Merge pull request #1066 from AayushSabharwal/as/immutable-getproperty-adjoint
fix: add `rrule` for `getproperty(::NonlinearProblem` to avoid mutable accumulation
2 parents 39d3c8a + 82238a2 commit 1fa3616

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,21 @@ function ChainRulesCore.rrule(
137137
SciMLBase.IntervalNonlinearProblem(args...; kwargs...), IntervalNonlinearProblemAdjoint
138138
end
139139

140+
# This is a workaround for the fact `NonlinearProblem` is a mutable struct. In SciMLSensitivity, we call
141+
# `back` explicitly while already in a reverse pass causing a nested gradient call. The mutable struct
142+
# causes accumulation anytime `getfield/property` is called, accumulating multiple times. This tries to treat
143+
# AbstractDEProblem as immutable for the purposes of reverse mode AD.
144+
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, ::typeof(Base.getproperty), x::NonlinearProblem, f::Symbol)
145+
val = getfield(x, f)
146+
function back(der)
147+
dx = if der === nothing
148+
ChainRulesCore.zero_tangent(x)
149+
else
150+
NamedTuple{(f,)}((der,))
151+
end
152+
return (ChainRulesCore.NoTangent(), ChainRulesCore.ProjectTo(x)(dx), ChainRulesCore.NoTangent())
153+
end
154+
return val, back
155+
end
156+
140157
end

0 commit comments

Comments
 (0)