Skip to content

Commit 5abb52b

Browse files
Merge pull request #957 from SciML/enzyme
Fix enzyme extension
2 parents 4b37fb1 + 719ca7b commit 5abb52b

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ jobs:
2222
- Downstream2
2323
version:
2424
- '1'
25-
- '1.6'
2625
steps:
2726
- uses: actions/checkout@v4
2827
- uses: julia-actions/setup-julia@v1

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
strategy:
1919
fail-fast: false
2020
matrix:
21-
julia-version: [1,1.6]
21+
julia-version: [1]
2222
os: [ubuntu-latest]
2323
package:
2424
- {user: SciML, repo: DelayDiffEq.jl, group: Interface}

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4646
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
4747

4848
[extensions]
49+
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
4950
DiffEqBaseDistributionsExt = "Distributions"
50-
DiffEqBaseEnzymeExt = "Enzyme"
51+
DiffEqBaseEnzymeExt = ["ChainRulesCore","Enzyme"]
5152
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
5253
DiffEqBaseMPIExt = "MPI"
5354
DiffEqBaseMeasurementsExt = "Measurements"

ext/DiffEqBaseChainRulesCoreExt.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
module DiffEqBaseChainRulesCoreExt
22

33
using DiffEqBase
4-
import DiffEqBase: numargs
4+
using DiffEqBase.SciMLBase
5+
import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem
56

67
import ChainRulesCore
78
import ChainRulesCore: NoTangent
89

910
ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), NoTangent()))
10-
ChainRulesCore.@non_differentiable checkkwargs(kwargshandle)
11+
ChainRulesCore.@non_differentiable DiffEqBase.checkkwargs(kwargshandle)
1112

12-
function ChainRulesCore.frule(::typeof(solve_up), prob,
13+
function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob,
1314
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
1415
u0, p, args...;
1516
kwargs...)
16-
_solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
17+
DiffEqBase._solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
1718
kwargs...)
1819
end
1920

20-
function ChainRulesCore.rrule(::typeof(solve_up), prob::AbstractDEProblem,
21+
function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem,
2122
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
2223
u0, p, args...;
2324
kwargs...)
24-
_solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
25+
DiffEqBase._solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
2526
kwargs...)
2627
end
2728

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ module DiffEqBaseEnzymeExt
22

33
using DiffEqBase
44
import DiffEqBase: value
5-
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
6-
5+
using Enzyme
6+
import Enzyme: Const
77
using ChainRulesCore
8-
using EnzymeCore
98

10-
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
9+
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
1110
@inline function copy_or_reuse(val, idx)
12-
if EnzymeCore.EnzymeRules.overwritten(config)[idx] && ismutable(val)
11+
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
1312
return deepcopy(val)
1413
else
1514
return val
@@ -28,15 +27,15 @@ function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.
2827
v.= 0
2928
end
3029
tup = (dres, res[2])
31-
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
30+
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
3231
end
3332

34-
function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{<:Duplicated{RT}}, tape, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
33+
function Enzyme.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{<:Duplicated{RT}}, tape, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
3534
dres, clos = tape
3635
dres = dres::RT
3736
dargs = clos(dres)
3837
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
39-
if ptr isa EnzymeCore.Const
38+
if ptr isa Enzyme.Const
4039
continue
4140
end
4241
if darg == ChainRulesCore.NoTangent()

0 commit comments

Comments
 (0)