Skip to content

Commit 9f3b9d9

Browse files
committed
Revert "Fix Enzyme solve_up rule signature to support DuplicatedNoNeed"
This reverts commit fa5e8c6.
1 parent 64f65be commit 9f3b9d9

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ module DiffEqBaseEnzymeExt
99

1010
function Enzyme.EnzymeRules.augmented_primal(
1111
config::Enzyme.EnzymeRules.RevConfigWidth{1},
12-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, prob,
12+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
1313
sensealg::Union{
1414
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
15-
u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}}
15+
u0, p, args...; kwargs...) where {RT}
1616
@inline function copy_or_reuse(val, idx)
1717
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
1818
return deepcopy(val)
@@ -31,18 +31,18 @@ module DiffEqBaseEnzymeExt
3131
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
3232
kwargs...)
3333

34-
ResType = typeof(res[1])
35-
dres = Enzyme.make_zero(res[1])::ResType
34+
dres = Enzyme.make_zero(res[1])::RT
3635
tup = (dres, res[2])
37-
return Enzyme.EnzymeRules.AugmentedReturn{ResType, ResType, Any}(res[1], dres, tup::Any)
36+
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
3837
end
3938

4039
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
41-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{RT}, tape, prob,
40+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
4241
sensealg::Union{
4342
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
44-
u0, p, args...; kwargs...) where {RT <: Union{Duplicated, DuplicatedNoNeed}}
43+
u0, p, args...; kwargs...) where {RT}
4544
dres, clos = tape
45+
dres = dres::RT
4646
dargs = clos(dres)
4747
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
4848
if ptr isa Enzyme.Const

0 commit comments

Comments
 (0)