Skip to content

Commit dcfe6eb

Browse files
Merge pull request #517 from SciML/tracker_rd
Tracker and ReverseDiff custom concrete_solve gradients
2 parents 200a6dc + 562205d commit dcfe6eb

File tree

5 files changed

+80
-35
lines changed

5 files changed

+80
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.34.3"
4+
version = "6.35.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/init.jl

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -145,37 +145,7 @@ function __init__()
145145
end
146146

147147
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
148-
value(x::Type{Tracker.TrackedReal{T}}) where T = T
149-
value(x::Type{Tracker.TrackedArray{T,N,A}}) where {T,N,A} = Array{T,N}
150-
value(x::Tracker.TrackedReal) = x.data
151-
value(x::Tracker.TrackedArray) = x.data
152-
153-
@inline fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
154-
@inline Base.any(f::Function,x::Tracker.TrackedArray) = any(f,Tracker.data(x))
155-
156-
# Support adaptive with non-tracked time
157-
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t) where {N}
158-
sqrt(sum(abs2,value(u)) / length(u))
159-
end
160-
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t) where {N}
161-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
162-
end
163-
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t) where {N}
164-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
165-
end
166-
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t) = abs(value(u))
167-
168-
# Support TrackedReal time, don't drop tracking on the adaptivity there
169-
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t::Tracker.TrackedReal) where {N}
170-
sqrt(sum(abs2,u) / length(u))
171-
end
172-
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
173-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
174-
end
175-
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
176-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
177-
end
178-
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)
148+
include("tracker.jl")
179149
end
180150

181151
# Piracy, should get upstreamed
@@ -213,6 +183,10 @@ function __init__()
213183
end
214184
end
215185

186+
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
187+
include("reversediff.jl")
188+
end
189+
216190
@require GeneralizedGenerated="6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" begin
217191
numargs(::GeneralizedGenerated.RuntimeFn{Args}) where Args = GeneralizedGenerated.from_type(Args) |> length
218192
end

src/reversediff.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;
2+
sensealg=nothing,kwargs...)
3+
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
4+
end
5+
6+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;
7+
sensealg=nothing,kwargs...)
8+
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
9+
end
10+
11+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;
12+
sensealg=nothing,kwargs...)
13+
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
14+
end
15+
16+
ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...;
17+
sensealg=nothing,kwargs...)
18+
out = _concrete_solve_adjoint(prob,alg,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
19+
Array(out[1]),out[2]
20+
end

src/solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,18 +220,18 @@ end
220220

221221
function _concrete_solve end
222222

223-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,
223+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
224224
u0=prob.u0,p=prob.p,args...;kwargs...)
225225
_concrete_solve(prob,alg,u0,p,args...;kwargs...)
226226
end
227227

228-
function _concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,
228+
function _concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
229229
u0=prob.u0,p=prob.p,args...;kwargs...)
230230
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
231231
RecursiveArrayTools.DiffEqArray(sol.u,sol.t)
232232
end
233233

234-
function _concrete_solve(prob::DiffEqBase.SteadyStateProblem,alg::DiffEqBase.DEAlgorithm,
234+
function _concrete_solve(prob::DiffEqBase.SteadyStateProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
235235
u0=prob.u0,p=prob.p,args...;kwargs...)
236236
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
237237
RecursiveArrayTools.VectorOfArray(sol.u)

src/tracker.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
value(x::Type{Tracker.TrackedReal{T}}) where T = T
2+
value(x::Type{Tracker.TrackedArray{T,N,A}}) where {T,N,A} = Array{T,N}
3+
value(x::Tracker.TrackedReal) = x.data
4+
value(x::Tracker.TrackedArray) = x.data
5+
6+
@inline fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
7+
@inline Base.any(f::Function,x::Tracker.TrackedArray) = any(f,Tracker.data(x))
8+
9+
# Support adaptive with non-tracked time
10+
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t) where {N}
11+
sqrt(sum(abs2,value(u)) / length(u))
12+
end
13+
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t) where {N}
14+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
15+
end
16+
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t) where {N}
17+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
18+
end
19+
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t) = abs(value(u))
20+
21+
# Support TrackedReal time, don't drop tracking on the adaptivity there
22+
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t::Tracker.TrackedReal) where {N}
23+
sqrt(sum(abs2,u) / length(u))
24+
end
25+
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
26+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
27+
end
28+
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
29+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
30+
end
31+
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)
32+
33+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;
34+
sensealg=nothing,kwargs...)
35+
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
36+
end
37+
38+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...;
39+
sensealg=nothing,kwargs...)
40+
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
41+
end
42+
43+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...;
44+
sensealg=nothing,kwargs...)
45+
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
46+
end
47+
48+
Tracker.@grad function concrete_solve(prob,alg,u0,p,args...;
49+
sensealg=nothing,kwargs...)
50+
_concrete_solve_adjoint(prob,alg,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...)
51+
end

0 commit comments

Comments
 (0)