Skip to content

Commit fa5e8c6

Browse files
Fix Enzyme solve_up rule signature to support DuplicatedNoNeed
Fixes SciML/SciMLSensitivity.jl#1225 The Enzyme rules for solve_up were failing when sensealg was passed via the ODEProblem constructor instead of solve(). The issue was that the return type annotation was restricted to Duplicated{RT}, but Enzyme can also use DuplicatedNoNeed{RT}. Changes: - Updated augmented_primal signature to accept both Duplicated and DuplicatedNoNeed - Updated reverse signature to accept both Duplicated and DuplicatedNoNeed - Fixed type handling to extract the inner result type instead of using the annotation type - Removed incorrect type assertion in reverse function 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 3667bdb commit fa5e8c6

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ module DiffEqBaseEnzymeExt
99

1010
function Enzyme.EnzymeRules.augmented_primal(
1111
config::Enzyme.EnzymeRules.RevConfigWidth{1},
12-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
12+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, prob,
1313
sensealg::Union{
1414
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
15-
u0, p, args...; kwargs...) where {RT}
15+
u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}}
1616
@inline function copy_or_reuse(val, idx)
1717
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
1818
return deepcopy(val)
@@ -31,18 +31,18 @@ module DiffEqBaseEnzymeExt
3131
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
3232
kwargs...)
3333

34-
dres = Enzyme.make_zero(res[1])::RT
34+
ResType = typeof(res[1])
35+
dres = Enzyme.make_zero(res[1])::ResType
3536
tup = (dres, res[2])
36-
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
37+
return Enzyme.EnzymeRules.AugmentedReturn{ResType, ResType, Any}(res[1], dres, tup::Any)
3738
end
3839

3940
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
40-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
41+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, tape, prob,
4142
sensealg::Union{
4243
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
43-
u0, p, args...; kwargs...) where {RT}
44+
u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}}
4445
dres, clos = tape
45-
dres = dres::RT
4646
dargs = clos(dres)
4747
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
4848
if ptr isa Enzyme.Const

0 commit comments

Comments
 (0)