Skip to content

Commit e761566

Browse files
authored
Merge branch 'master' into scimlops
2 parents ecd003d + 815a1e5 commit e761566

13 files changed

+327
-24
lines changed

Project.toml

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.114.2"
4+
version = "6.115.3"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
@@ -37,15 +37,21 @@ GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
3737
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
3838
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
3939
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
40+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
41+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4042
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
43+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4144

4245
[extensions]
43-
DistributionsExt = "Distributions"
44-
MeasurementsExt = "Measurements"
45-
MPIExt = "MPI"
46-
MonteCarloMeasurementsExt = "MonteCarloMeasurements"
47-
GeneralizedGeneratedExt = "GeneralizedGenerated"
48-
UnitfulExt = "Unitful"
46+
DiffEqBaseZygoteExt = "Zygote"
47+
DiffEqBaseReverseDiffExt = "ReverseDiff"
48+
DiffEqBaseTrackerExt = "Tracker"
49+
DiffEqBaseDistributionsExt = "Distributions"
50+
DiffEqBaseMeasurementsExt = "Measurements"
51+
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
52+
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
53+
DiffEqBaseUnitfulExt = "Unitful"
54+
DiffEqBaseMPIExt = "MPI"
4955

5056
[compat]
5157
ArrayInterfaceCore = "0.1.26"
@@ -63,7 +69,7 @@ PreallocationTools = "0.4"
6369
RecursiveArrayTools = "2"
6470
Reexport = "1.0"
6571
Requires = "1.0"
66-
SciMLBase = "1.82"
72+
SciMLBase = "1.84"
6773
Setfield = "0.8, 1"
6874
Static = "0.7, 0.8"
6975
StaticArrays = "1.0"
@@ -83,10 +89,13 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
8389
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8490
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
8591
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
92+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
8693
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
8794
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
8895
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
96+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
8997
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
98+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9099

91100
[targets]
92101
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "SafeTestsets", "Statistics", "Test", "Distributions"]

ext/DistributionsExt.jl renamed to ext/DiffEqBaseDistributionsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module DistributionsExt
1+
module DiffEqBaseDistributionsExt
22

33
using Distributions, DiffEqBase
44

ext/GeneralizedGeneratedExt.jl renamed to ext/DiffEqBaseGeneralizedGeneratedExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module GeneralizedGeneratedExt
1+
module DiffEqBaseGeneralizedGeneratedExt
22

33
using DiffEqBase
44
isdefined(Base, :get_extension) ? (using GeneralizedGenerated) :

ext/MPI.jl renamed to ext/DiffEqBaseMPIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module MPIExt
1+
module DiffEqBaseMPIExt
22

33
import DiffEqBase
44
isdefined(Base, :get_extension) ? (import MPI) : (import ..MPI)

ext/MeasurementsExt.jl renamed to ext/DiffEqBaseMeasurementsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module MeasurementsExt
1+
module DiffEqBaseMeasurementsExt
22

33
using DiffEqBase
44
import DiffEqBase: value

ext/MonteCarloMeasurementsExt.jl renamed to ext/DiffEqBaseMonteCarloMeasurementsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module MonteCarloMeasurementsExt
1+
module DiffEqBaseMonteCarloMeasurementsExt
22

33
using DiffEqBase
44
import DiffEqBase: value

ext/DiffEqBaseReverseDiffExt.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
module DiffEqBaseReverseDiffExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff)
6+
7+
DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
8+
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value
9+
10+
DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
11+
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
12+
p::ReverseDiff.TrackedArray, t0)
13+
u0
14+
end
15+
function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray,
16+
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
17+
u0
18+
end
19+
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
20+
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
21+
u0
22+
end
23+
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
24+
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
25+
26+
# Support adaptive with non-tracked time
27+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t)
28+
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
29+
end
30+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
31+
t) where {N}
32+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
33+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
34+
end
35+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
36+
t) where {N}
37+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
38+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
39+
end
40+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, t)
41+
abs(DiffEqBase.value(u))
42+
end
43+
44+
# Support TrackedReal time, don't drop tracking on the adaptivity there
45+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray,
46+
t::ReverseDiff.TrackedReal)
47+
sqrt(sum(abs2, u) / length(u))
48+
end
49+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
50+
t::ReverseDiff.TrackedReal) where {N}
51+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
52+
length(u))
53+
end
54+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
55+
t::ReverseDiff.TrackedReal) where {N}
56+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
57+
length(u))
58+
end
59+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal,
60+
t::ReverseDiff.TrackedReal)
61+
abs(u)
62+
end
63+
64+
# `ReverseDiff.TrackedArray`
65+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
66+
sensealg::Union{
67+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
68+
Nothing}, u0::ReverseDiff.TrackedArray,
69+
p::ReverseDiff.TrackedArray, args...; kwargs...)
70+
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
71+
end
72+
73+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
74+
sensealg::Union{
75+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
76+
Nothing}, u0, p::ReverseDiff.TrackedArray,
77+
args...; kwargs...)
78+
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
79+
end
80+
81+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
82+
sensealg::Union{
83+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
84+
Nothing}, u0::ReverseDiff.TrackedArray, p,
85+
args...; kwargs...)
86+
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
87+
end
88+
89+
# `AbstractArray{<:ReverseDiff.TrackedReal}`
90+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
91+
sensealg::Union{
92+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
93+
Nothing},
94+
u0::AbstractArray{<:ReverseDiff.TrackedReal},
95+
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
96+
kwargs...)
97+
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...;
98+
kwargs...)
99+
end
100+
101+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
102+
sensealg::Union{
103+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
104+
Nothing}, u0,
105+
p::AbstractArray{<:ReverseDiff.TrackedReal},
106+
args...; kwargs...)
107+
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
108+
end
109+
110+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
111+
sensealg::Union{
112+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
113+
Nothing},
114+
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
115+
args...; kwargs...)
116+
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
117+
end
118+
119+
# Required becase ReverseDiff.@grad function DiffEqBase.solve_up is not supported!
120+
import DiffEqBase: solve_up
121+
ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...)
122+
out = DiffEqBase._solve_adjoint(prob, sensealg, ReverseDiff.value(u0),
123+
ReverseDiff.value(p),
124+
SciMLBase.ReverseDiffOriginator(), args...; kwargs...)
125+
function actual_adjoint(_args...)
126+
original_adjoint = out[2](_args...)
127+
if isempty(args) # alg is missing
128+
tuple(original_adjoint[1:4]..., original_adjoint[6:end]...)
129+
else
130+
original_adjoint
131+
end
132+
end
133+
Array(out[1]), actual_adjoint
134+
end
135+
136+
end

ext/DiffEqBaseTrackerExt.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
module DiffEqBaseTrackerExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)
6+
7+
DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T
8+
DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N}
9+
DiffEqBase.value(x::Tracker.TrackedReal) = x.data
10+
DiffEqBase.value(x::Tracker.TrackedArray) = x.data
11+
12+
DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0
13+
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
14+
p::Tracker.TrackedArray, t0)
15+
u0
16+
end
17+
function DiffEqBase.promote_u0(u0::Tracker.TrackedArray,
18+
p::AbstractArray{<:Tracker.TrackedReal}, t0)
19+
u0
20+
end
21+
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
22+
p::AbstractArray{<:Tracker.TrackedReal}, t0)
23+
u0
24+
end
25+
DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0)
26+
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0)
27+
28+
@inline DiffEqBase.fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
29+
@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x))
30+
31+
# Support adaptive with non-tracked time
32+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t)
33+
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
34+
end
35+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
36+
t) where {N}
37+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
38+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
39+
end
40+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
41+
t) where {N}
42+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
43+
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
44+
end
45+
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t) = abs(DiffEqBase.value(u))
46+
47+
# Support TrackedReal time, don't drop tracking on the adaptivity there
48+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray,
49+
t::Tracker.TrackedReal)
50+
sqrt(sum(abs2, u) / length(u))
51+
end
52+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N},
53+
t::Tracker.TrackedReal) where {N}
54+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
55+
length(u))
56+
end
57+
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N},
58+
t::Tracker.TrackedReal) where {N}
59+
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
60+
length(u))
61+
end
62+
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u)
63+
64+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
65+
sensealg::Union{
66+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
67+
Nothing}, u0::Tracker.TrackedArray,
68+
p::Tracker.TrackedArray, args...; kwargs...)
69+
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
70+
end
71+
72+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
73+
sensealg::Union{
74+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
75+
Nothing}, u0::Tracker.TrackedArray, p, args...;
76+
kwargs...)
77+
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
78+
end
79+
80+
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
81+
sensealg::Union{
82+
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
83+
Nothing}, u0, p::Tracker.TrackedArray, args...;
84+
kwargs...)
85+
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
86+
end
87+
88+
Tracker.@grad function DiffEqBase.solve_up(prob,
89+
sensealg::Union{Nothing,
90+
SciMLBase.AbstractOverloadingSensitivityAlgorithm
91+
},
92+
u0, p, args...;
93+
kwargs...)
94+
DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p),
95+
SciMLBase.TrackerOriginator(), args...; kwargs...)
96+
end
97+
98+
end

ext/UnitfulExt.jl renamed to ext/DiffEqBaseUnitfulExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module UnitfulExt
1+
module DiffEqBaseUnitfulExt
22

33
using DiffEqBase
44
import DiffEqBase: value

ext/DiffEqBaseZygoteExt.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module DiffEqBaseZygoteExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
isdefined(Base, :get_extension) ? (import Zygote) : (import ..Zygote)
6+
7+
function ∇tmap(cx, f, args...)
8+
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
9+
if isempty(ys_and_backs)
10+
ys_and_backs, _ -> (NoTangent(), NoTangent())
11+
else
12+
ys, backs = Zygote.unzip(ys_and_backs)
13+
function ∇tmap_internal(Δ)
14+
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
15+
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
16+
Δf = reduce(Zygote.accum, Δf_and_args[1])
17+
(Δf, Δf_and_args[2:end]...)
18+
end
19+
ys, ∇tmap_internal
20+
end
21+
end
22+
23+
function ∇responsible_map(cx, f, args...)
24+
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
25+
args...)
26+
if isempty(ys_and_backs)
27+
ys_and_backs, _ -> (NoTangent(), NoTangent())
28+
else
29+
ys, backs = Zygote.unzip(ys_and_backs)
30+
ys,
31+
function ∇responsible_map_internal(Δ)
32+
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
33+
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
34+
Zygote._tryreverse(SciMLBase.responsible_map,
35+
backs, Δ)...)
36+
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
37+
Δf_and_args_zipped))
38+
Δf = reduce(Zygote.accum, Δf_and_args[1])
39+
(Δf, Δf_and_args[2:end]...)
40+
end
41+
end
42+
end
43+
44+
Zygote.@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
45+
∇tmap(__context__, f, args...)
46+
end
47+
48+
Zygote.@adjoint function SciMLBase.responsible_map(f,
49+
args::Union{AbstractArray, Tuple
50+
}...)
51+
∇responsible_map(__context__, f, args...)
52+
end
53+
54+
end

0 commit comments

Comments
 (0)