Skip to content

Commit 1e15b33

Browse files
Tracker works
1 parent 9f52080 commit 1e15b33

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

src/tracker.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,22 @@ end
3030
end
3131
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)
3232

33+
function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;
34+
sensealg=nothing,kwargs...)
35+
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
36+
end
37+
38+
function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,u0,p::Tracker.TrackedArray,args...;
39+
sensealg=nothing,kwargs...)
40+
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
41+
end
42+
43+
function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem,alg::DiffEqBase.DEAlgorithm,u0::Tracker.TrackedArray,p,args...;
44+
sensealg=nothing,kwargs...)
45+
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
46+
end
47+
3348
Tracker.@grad function concrete_solve(prob,alg,u0,p,args...;
3449
sensealg=nothing,kwargs...)
35-
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
50+
_concrete_solve_adjoint(prob,alg,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...)
3651
end

0 commit comments

Comments
 (0)