Commit 6e09c37

Fix Enzyme reverse rule tangent accumulation for structured parameters
When parameters are SciMLStructure types (e.g. MTKParameters), use the SciMLStructures interface (canonicalize/replace!) to accumulate tangents into the parameter shadow instead of raw broadcasting.

The tangent from SciMLSensitivity may be:
- A tunable gradient vector (EnzymeOriginator path)
- Another SciMLStructure
- A broadcastable array (plain Vector parameters)

All cases are handled through SciMLStructures.canonicalize to extract and accumulate the tunable portion.

Fixes the tangent accumulation portion of #878.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1ce693f commit 6e09c37

1 file changed: +35 -2 lines changed

lib/NonlinearSolveBase/ext/NonlinearSolveBaseEnzymeExt.jl

Lines changed: 35 additions & 2 deletions
@@ -5,6 +5,39 @@ import SciMLBase: SciMLBase, value
 using Enzyme
 import Enzyme: Const, MixedDuplicated
 using ChainRulesCore
+import SciMLStructures
+
+# Accumulate a tangent `darg` into a shadow `dval`.
+# When `dval` is a SciMLStructure (e.g. MTKParameters), `darg` may be:
+# - A tunable gradient vector (from SciMLSensitivity's EnzymeOriginator path)
+# - Another SciMLStructure
+# - A broadcastable array
+# In all cases, accumulation goes through the SciMLStructures interface.
+function _accum_tangent!(dval, darg)
+    if SciMLStructures.isscimlstructure(dval) && !(dval isa AbstractArray)
+        shadow_tunables, _, _ = SciMLStructures.canonicalize(
+            SciMLStructures.Tunable(), dval,
+        )
+        if SciMLStructures.isscimlstructure(darg)
+            darg_tunables, _, _ = SciMLStructures.canonicalize(
+                SciMLStructures.Tunable(), darg,
+            )
+            shadow_tunables .+= darg_tunables
+        elseif darg isa AbstractVector && length(darg) == length(shadow_tunables)
+            # Tunable gradient vector (returned by SciMLSensitivity for
+            # EnzymeOriginator when p is a SciMLStructure)
+            shadow_tunables .+= darg
+        else
+            # Fallback: try direct broadcast (may error for incompatible types)
+            dval .+= darg
+            return nothing
+        end
+        SciMLStructures.replace!(SciMLStructures.Tunable(), dval, shadow_tunables)
+    else
+        dval .+= darg
+    end
+    return nothing
+end
 
 function Enzyme.EnzymeRules.augmented_primal(
     config::Enzyme.EnzymeRules.RevConfigWidth{1},
@@ -57,9 +90,9 @@ function Enzyme.EnzymeRules.reverse(
             continue
         end
         if ptr isa MixedDuplicated
-            ptr.dval[] .+= darg
+            _accum_tangent!(ptr.dval[], darg)
         else
-            ptr.dval .+= darg
+            _accum_tangent!(ptr.dval, darg)
         end
     end
     Enzyme.make_zero!(dres.u)
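
The accumulation pattern in `_accum_tangent!` can be illustrated without any SciML packages. The following is a minimal sketch only: `ToyParams`, `toy_isstructure`, `toy_canonicalize`, and `toy_replace!` are hypothetical stand-ins for a structured parameter type and for the roles played by `SciMLStructures.isscimlstructure`, `canonicalize`, and `replace!`; they are not the real API. One deliberate difference from the commit: the unsupported-tangent branch throws here, whereas the commit falls back to a direct broadcast.

```julia
# Hypothetical stand-in for a structured parameter object such as MTKParameters.
# It is not an AbstractArray, so `p .+= x` is not defined for it.
struct ToyParams
    tunable::Vector{Float64}    # portion that receives gradients
    constants::Vector{Float64}  # data that is not differentiated
end

# Hypothetical helpers mirroring the extract / write-back roles of the
# SciMLStructures interface for the tunable portion (assumed names).
toy_isstructure(x) = x isa ToyParams
toy_canonicalize(p::ToyParams) = copy(p.tunable)
toy_replace!(p::ToyParams, vals) = (p.tunable .= vals; p)

function toy_accum_tangent!(dval, darg)
    if toy_isstructure(dval)
        # Extract the tunable portion of the shadow as a flat vector.
        shadow = toy_canonicalize(dval)
        if toy_isstructure(darg)
            # Tangent is itself structured: accumulate its tunable portion.
            shadow .+= toy_canonicalize(darg)
        elseif darg isa AbstractVector && length(darg) == length(shadow)
            # Tangent is a flat gradient over the tunables.
            shadow .+= darg
        else
            # Unlike the commit (which tries `dval .+= darg`), fail loudly.
            throw(ArgumentError("unsupported tangent type $(typeof(darg))"))
        end
        # Write the accumulated values back into the structured shadow.
        toy_replace!(dval, shadow)
    else
        # Plain arrays accumulate by ordinary in-place broadcast.
        dval .+= darg
    end
    return nothing
end

shadow = ToyParams([0.0, 0.0], [9.0])
toy_accum_tangent!(shadow, [1.0, 2.0])                      # flat gradient vector
toy_accum_tangent!(shadow, ToyParams([0.5, 0.5], [0.0]))    # structured tangent

plain = [1.0, 1.0]
toy_accum_tangent!(plain, [2.0, 3.0])                       # plain Vector parameters
```

After these calls, `shadow.tunable` holds `[1.5, 2.5]`, `shadow.constants` is untouched, and `plain` holds `[3.0, 4.0]`. The write-back step matters because the extracted vector may be a copy rather than a view of the shadow's storage, as it is in this sketch.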
