Skip to content

Commit 8704130

Browse files
Merge pull request #871 from SciML/weakdep
Change weak dep naming scheme and add the AD weak dep overloads
2 parents 7e84b9c + 9ea338b commit 8704130

13 files changed

+332
-22
lines changed

Project.toml

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,21 @@ GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
3737
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
3838
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
3939
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
40+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
41+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4042
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
43+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4144

4245
[extensions]
43-
DistributionsExt = "Distributions"
44-
MeasurementsExt = "Measurements"
45-
MPIExt = "MPI"
46-
MonteCarloMeasurementsExt = "MonteCarloMeasurements"
47-
GeneralizedGeneratedExt = "GeneralizedGenerated"
48-
UnitfulExt = "Unitful"
46+
DiffEqBaseZygoteExt = "Zygote"
47+
DiffEqBaseReverseDiffExt = "ReverseDiff"
48+
DiffEqBaseTrackerExt = "Tracker"
49+
DiffEqBaseDistributionsExt = "Distributions"
50+
DiffEqBaseMeasurementsExt = "Measurements"
51+
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
52+
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
53+
DiffEqBaseUnitfulExt = "Unitful"
54+
DiffEqBaseMPIExt = "MPI"
4955

5056
[compat]
5157
ArrayInterfaceCore = "0.1.26"
@@ -83,10 +89,13 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
8389
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8490
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
8591
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
92+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
8693
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
8794
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
8895
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
96+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
8997
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
98+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9099

91100
[targets]
92101
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "SafeTestsets", "Statistics", "Test", "Distributions"]

ext/DistributionsExt.jl renamed to ext/DiffEqBaseDistributionsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module DistributionsExt
1+
module DiffEqBaseDistributionsExt
22

33
using Distributions, DiffEqBase
44

ext/GeneralizedGeneratedExt.jl renamed to ext/DiffEqBaseGeneralizedGeneratedExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module GeneralizedGeneratedExt
1+
module DiffEqBaseGeneralizedGeneratedExt
22

33
using DiffEqBase
44
isdefined(Base, :get_extension) ? (using GeneralizedGenerated) :

ext/MPI.jl renamed to ext/DiffEqBaseMPIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module MPIExt
1+
module DiffEqBaseMPIExt
22

33
import DiffEqBase
44
isdefined(Base, :get_extension) ? (import MPI) : (import ..MPI)

ext/MeasurementsExt.jl renamed to ext/DiffEqBaseMeasurementsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module MeasurementsExt
1+
module DiffEqBaseMeasurementsExt
22

33
using DiffEqBase
44
import DiffEqBase: value

ext/MonteCarloMeasurementsExt.jl renamed to ext/DiffEqBaseMonteCarloMeasurementsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module MonteCarloMeasurementsExt
1+
module DiffEqBaseMonteCarloMeasurementsExt
22

33
using DiffEqBase
44
import DiffEqBase: value

ext/DiffEqBaseReverseDiffExt.jl

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
module DiffEqBaseReverseDiffExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff)
6+
7+
DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
8+
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value
9+
10+
DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
11+
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
12+
p::ReverseDiff.TrackedArray, t0)
13+
u0
14+
end
15+
function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray,
16+
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
17+
u0
18+
end
19+
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
20+
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
21+
u0
22+
end
23+
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
24+
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
25+
26+
# Support adaptive with non-tracked time
27+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t)
28+
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
29+
end
30+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
31+
t) where {N}
32+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
33+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
34+
end
35+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
36+
t) where {N}
37+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
38+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
39+
end
40+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t)
41+
abs(DiffEqBase.value(u))
42+
end
43+
44+
# Support TrackedReal time, don't drop tracking on the adaptivity there
45+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray,
46+
t::ReverseDiff.TrackedReal)
47+
sqrt(sum(abs2, u) / length(u))
48+
end
49+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
50+
t::ReverseDiff.TrackedReal) where {N}
51+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
52+
length(u))
53+
end
54+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
55+
t::ReverseDiff.TrackedReal) where {N}
56+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
57+
length(u))
58+
end
59+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal,
60+
t::ReverseDiff.TrackedReal)
61+
abs(u)
62+
end
63+
64+
# `ReverseDiff.TrackedArray`
65+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
66+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
67+
Nothing}, u0::ReverseDiff.TrackedArray,
68+
p::ReverseDiff.TrackedArray, args...; kwargs...)
69+
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
70+
end
71+
72+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
73+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
74+
Nothing}, u0, p::ReverseDiff.TrackedArray,
75+
args...; kwargs...)
76+
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
77+
end
78+
79+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
80+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
81+
Nothing}, u0::ReverseDiff.TrackedArray, p,
82+
args...; kwargs...)
83+
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
84+
end
85+
86+
# `AbstractArray{<:ReverseDiff.TrackedReal}`
87+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
88+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
89+
Nothing},
90+
u0::AbstractArray{<:ReverseDiff.TrackedReal},
91+
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
92+
kwargs...)
93+
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...;
94+
kwargs...)
95+
end
96+
97+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
98+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
99+
Nothing}, u0,
100+
p::AbstractArray{<:ReverseDiff.TrackedReal},
101+
args...; kwargs...)
102+
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
103+
end
104+
105+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
106+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
107+
Nothing},
108+
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
109+
args...; kwargs...)
110+
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
111+
end
112+
113+
@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,
114+
proto::ReverseDiff.TrackedArray)
115+
ReverseDiff.track(convert.(eltype(proto.value), randn(rng, size(proto))))
116+
end
117+
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG,
118+
rand_vec::Array{<:ReverseDiff.TrackedReal
119+
})
120+
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
121+
end
122+
@inline function DiffEqNoiseProcess.wiener_randn!(rng::AbstractRNG,
123+
rand_vec::AbstractArray{
124+
<:ReverseDiff.TrackedReal
125+
})
126+
rand_vec .= ReverseDiff.track.(randn.((rng,), typeof.(DiffEqBase.value.(rand_vec))))
127+
end
128+
129+
# Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported!
130+
import DiffEqBase: solve_up
131+
ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...)
132+
out = DiffEqBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0),
133+
ReverseDiff.value(p),
134+
SciMLBase.ReverseDiffOriginator(), args...; kwargs...)
135+
function actual_adjoint(_args...)
136+
original_adjoint = out[2](_args...)
137+
if isempty(args) # alg is missing
138+
tuple(original_adjoint[1:4]..., original_adjoint[6:end]...)
139+
else
140+
original_adjoint
141+
end
142+
end
143+
Array(out[1]), actual_adjoint
144+
end
145+
146+
end

ext/DiffEqBaseTrackerExt.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
module DiffEqBaseTrackerExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)
6+
7+
DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T
8+
DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N}
9+
DiffEqBase.value(x::Tracker.TrackedReal) = x.data
10+
DiffEqBase.value(x::Tracker.TrackedArray) = x.data
11+
12+
DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0
13+
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
14+
p::Tracker.TrackedArray, t0)
15+
u0
16+
end
17+
function DiffEqBase.promote_u0(u0::Tracker.TrackedArray,
18+
p::AbstractArray{<:Tracker.TrackedReal}, t0)
19+
u0
20+
end
21+
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
22+
p::AbstractArray{<:Tracker.TrackedReal}, t0)
23+
u0
24+
end
25+
DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0)
26+
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0)
27+
28+
@inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
29+
@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x))
30+
31+
# Support adaptive with non-tracked time
32+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t)
33+
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
34+
end
35+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
36+
t) where {N}
37+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
38+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
39+
end
40+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
41+
t) where {N}
42+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
43+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
44+
end
45+
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u))
46+
47+
# Support TrackedReal time, don't drop tracking on the adaptivity there
48+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray,
49+
t::Tracker.TrackedReal)
50+
sqrt(sum(abs2, u) / length(u))
51+
end
52+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
53+
t::Tracker.TrackedReal) where {N}
54+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
55+
length(u))
56+
end
57+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
58+
t::Tracker.TrackedReal) where {N}
59+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
60+
length(u))
61+
end
62+
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u)
63+
64+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
65+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
66+
Nothing}, u0::Tracker.TrackedArray,
67+
p::Tracker.TrackedArray, args...; kwargs...)
68+
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
69+
end
70+
71+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
72+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
73+
Nothing}, u0::Tracker.TrackedArray, p, args...;
74+
kwargs...)
75+
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
76+
end
77+
78+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
79+
sensealg::Union{DiffEqBase.AbstractOverloadingSensitivityAlgorithm,
80+
Nothing}, u0, p::Tracker.TrackedArray, args...;
81+
kwargs...)
82+
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
83+
end
84+
85+
Tracker.@grad function DiffEqBase.solve_up(prob,
86+
sensealg::Union{Nothing,
87+
DiffEqBase.AbstractOverloadingSensitivityAlgorithm
88+
},
89+
u0, p, args...;
90+
kwargs...)
91+
DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p),
92+
SciMLBase.TrackerOriginator(), args...; kwargs...)
93+
end
94+
95+
end

ext/UnitfulExt.jl renamed to ext/DiffEqBaseUnitfulExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module UnitfulExt
1+
module DiffEqBaseUnitfulExt
22

33
using DiffEqBase
44
import DiffEqBase: value

ext/DiffEqBaseZygoteExt.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module DiffEqBaseZygoteExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
isdefined(Base, :get_extension) ? (import Zygote) : (import ..Zygote)
6+
7+
function ∇tmap(cx, f, args...)
8+
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
9+
if isempty(ys_and_backs)
10+
ys_and_backs, _ -> (NoTangent(), NoTangent())
11+
else
12+
ys, backs = Zygote.unzip(ys_and_backs)
13+
function ∇tmap_internal(Δ)
14+
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
15+
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
16+
Δf = reduce(Zygote.accum, Δf_and_args[1])
17+
(Δf, Δf_and_args[2:end]...)
18+
end
19+
ys, ∇tmap_internal
20+
end
21+
end
22+
23+
function ∇responsible_map(cx, f, args...)
24+
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
25+
args...)
26+
if isempty(ys_and_backs)
27+
ys_and_backs, _ -> (NoTangent(), NoTangent())
28+
else
29+
ys, backs = Zygote.unzip(ys_and_backs)
30+
ys,
31+
function ∇responsible_map_internal(Δ)
32+
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
33+
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
34+
Zygote._tryreverse(SciMLBase.responsible_map,
35+
backs, Δ)...)
36+
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
37+
Δf_and_args_zipped))
38+
Δf = reduce(Zygote.accum, Δf_and_args[1])
39+
(Δf, Δf_and_args[2:end]...)
40+
end
41+
end
42+
end
43+
44+
ZygoteRules.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
45+
∇tmap(__context__, f, args...)
46+
end
47+
48+
ZygoteRules.@adjoint function SciMLBase.responsible_map(f,
49+
args::Union{AbstractArray, Tuple
50+
}...)
51+
∇responsible_map(__context__, f, args...)
52+
end
53+
54+
end

0 commit comments

Comments
 (0)