Skip to content

Commit 562205d

Browse files
fix Tracker and ReverseDiff
1 parent 1e15b33 commit 562205d

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

src/reversediff.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;
2+
sensealg=nothing,kwargs...)
3+
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
4+
end
5+
6+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;
7+
sensealg=nothing,kwargs...)
8+
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
9+
end
10+
11+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;
12+
sensealg=nothing,kwargs...)
13+
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
14+
end
15+
116
ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...;
217
sensealg=nothing,kwargs...)
3-
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
18+
out = _concrete_solve_adjoint(prob,alg,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
19+
Array(out[1]),out[2]
420
end

src/solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,18 +220,18 @@ end
220220

221221
function _concrete_solve end
222222

223-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,
223+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
224224
u0=prob.u0,p=prob.p,args...;kwargs...)
225225
_concrete_solve(prob,alg,u0,p,args...;kwargs...)
226226
end
227227

228-
function _concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,
228+
function _concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
229229
u0=prob.u0,p=prob.p,args...;kwargs...)
230230
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
231231
RecursiveArrayTools.DiffEqArray(sol.u,sol.t)
232232
end
233233

234-
function _concrete_solve(prob::DiffEqBase.SteadyStateProblem,alg::DiffEqBase.DEAlgorithm,
234+
function _concrete_solve(prob::DiffEqBase.SteadyStateProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
235235
u0=prob.u0,p=prob.p,args...;kwargs...)
236236
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
237237
RecursiveArrayTools.VectorOfArray(sol.u)

src/tracker.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@ end
3030
end
3131
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)
3232

33-
function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;
33+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;
3434
sensealg=nothing,kwargs...)
3535
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
3636
end
3737

38-
function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,u0,p::Tracker.TrackedArray,args...;
38+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...;
3939
sensealg=nothing,kwargs...)
4040
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
4141
end
4242

43-
function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,u0::Tracker.TrackedArray,p,args...;
43+
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...;
4444
sensealg=nothing,kwargs...)
4545
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
4646
end

0 commit comments

Comments
 (0)