Skip to content

Commit 9f52080

Browse files
[WIP] Tracker and ReverseDiff custom concrete_solve gradients
MWE: ```julia using DiffEqSensitivity, OrdinaryDiffEq, Zygote using RecursiveArrayTools: DiffEqArray using Test, ForwardDiff using Tracker, ReverseDiff function fiip(du,u,p,t) du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2] du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2] end function foop(u,p,t) dx = p[1]*u[1] - p[2]*u[1]*u[2] dy = -p[3]*u[2] + p[4]*u[1]*u[2] [dx,dy] end function foop(u::Tracker.TrackedArray,p,t) dx = p[1]*u[1] - p[2]*u[1]*u[2] dy = -p[3]*u[2] + p[4]*u[1]*u[2] Tracker.collect([dx,dy]) end p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0] prob = ODEProblem(fiip,u0,(0.0,10.0),p) proboop = ODEProblem(foop,u0,(0.0,10.0),p) sol = concrete_solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) _sol = solve(prob,Tsit5(),abstol=1e-14,reltol=1e-14) ū0,adj = adjoint_sensitivities(_sol,Tsit5(),((out,u,p,t,i) -> out .= -1),0.0:0.1:10,abstol=1e-14, reltol=1e-14,iabstol=1e-14,ireltol=1e-12) du01,dp1 = Zygote.gradient((u0,p)->sum(concrete_solve(prob,Tsit5(),u0,p,abstol=1e-14,reltol=1e-14,saveat=0.1)),u0,p) Tracker.@Grad function concrete_solve(prob,alg,u0,p,args...; sensealg=nothing,kwargs...) @show "here" _concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...) end ReverseDiff.@Grad function concrete_solve(prob,alg,u0,p,args...; sensealg=nothing,kwargs...) @show "here" _concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...) end du01,dp1 = Tracker.gradient((u0,p)->sum(concrete_solve(prob,Tsit5(),u0,p,abstol=1e-14,reltol=1e-14,saveat=0.1)),u0,p) du01,dp1 = ReverseDiff.gradient((u0,p)->sum(concrete_solve(prob,Tsit5(),u0,p,abstol=1e-14,reltol=1e-14,saveat=0.1)),(u0,p)) Tracker.@Grad function concrete_solve(prob,alg,u0,p) @show "here" _concrete_solve_adjoint(prob,alg,sensealg,u0,p) end ReverseDiff.@Grad function concrete_solve(prob,alg,u0,p) @show "here" _concrete_solve_adjoint(prob,alg,sensealg,u0,p) end du01,dp1 = Tracker.gradient((u0,p)->sum(concrete_solve(prob,Tsit5(),u0,p)),u0,p) du01,dp1 = ReverseDiff.gradient((u0,p)->sum(concrete_solve(prob,Tsit5(),u0,p)),(u0,p)) ```
1 parent 200a6dc commit 9f52080

File tree

4 files changed

+46
-32
lines changed

4 files changed

+46
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.34.3"
4+
version = "6.35.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/init.jl

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -145,37 +145,7 @@ function __init__()
145145
end
146146

147147
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
148-
value(x::Type{Tracker.TrackedReal{T}}) where T = T
149-
value(x::Type{Tracker.TrackedArray{T,N,A}}) where {T,N,A} = Array{T,N}
150-
value(x::Tracker.TrackedReal) = x.data
151-
value(x::Tracker.TrackedArray) = x.data
152-
153-
@inline fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
154-
@inline Base.any(f::Function,x::Tracker.TrackedArray) = any(f,Tracker.data(x))
155-
156-
# Support adaptive with non-tracked time
157-
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t) where {N}
158-
sqrt(sum(abs2,value(u)) / length(u))
159-
end
160-
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t) where {N}
161-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
162-
end
163-
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t) where {N}
164-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
165-
end
166-
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t) = abs(value(u))
167-
168-
# Support TrackedReal time, don't drop tracking on the adaptivity there
169-
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t::Tracker.TrackedReal) where {N}
170-
sqrt(sum(abs2,u) / length(u))
171-
end
172-
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
173-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
174-
end
175-
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
176-
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
177-
end
178-
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)
148+
include("tracker.jl")
179149
end
180150

181151
# Piracy, should get upstreamed
@@ -213,6 +183,10 @@ function __init__()
213183
end
214184
end
215185

186+
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
187+
include("reversediff.jl")
188+
end
189+
216190
@require GeneralizedGenerated="6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" begin
217191
numargs(::GeneralizedGenerated.RuntimeFn{Args}) where Args = GeneralizedGenerated.from_type(Args) |> length
218192
end

src/reversediff.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...;
2+
sensealg=nothing,kwargs...)
3+
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
4+
end

src/tracker.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
value(x::Type{Tracker.TrackedReal{T}}) where T = T
2+
value(x::Type{Tracker.TrackedArray{T,N,A}}) where {T,N,A} = Array{T,N}
3+
value(x::Tracker.TrackedReal) = x.data
4+
value(x::Tracker.TrackedArray) = x.data
5+
6+
@inline fastpow(x::Tracker.TrackedReal, y::Tracker.TrackedReal) = x^y
7+
@inline Base.any(f::Function,x::Tracker.TrackedArray) = any(f,Tracker.data(x))
8+
9+
# Support adaptive with non-tracked time
10+
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t) where {N}
11+
sqrt(sum(abs2,value(u)) / length(u))
12+
end
13+
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t) where {N}
14+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
15+
end
16+
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t) where {N}
17+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
18+
end
19+
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t) = abs(value(u))
20+
21+
# Support TrackedReal time, don't drop tracking on the adaptivity there
22+
@inline function ODE_DEFAULT_NORM(u::Tracker.TrackedArray,t::Tracker.TrackedReal) where {N}
23+
sqrt(sum(abs2,u) / length(u))
24+
end
25+
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
26+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
27+
end
28+
@inline function ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal,N},t::Tracker.TrackedReal) where {N}
29+
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip(u,Iterators.repeated(t))) / length(u))
30+
end
31+
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)
32+
33+
Tracker.@grad function concrete_solve(prob,alg,u0,p,args...;
34+
sensealg=nothing,kwargs...)
35+
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
36+
end

0 commit comments

Comments
 (0)