Skip to content

Commit fabef51

Browse files
committed
callback hotfix
1 parent d1c3f04 commit fabef51

File tree

3 files changed

+158
-3
lines changed

3 files changed

+158
-3
lines changed

src/callbacks.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,12 +765,12 @@ end
765765

766766
function apply_callback!(integrator,callback::Union{ContinuousCallback,VectorContinuousCallback},cb_time,prev_sign,event_idx)
767767

768-
change_t_via_interpolation!(integrator,integrator.tprev+cb_time)
769-
770768
if integrator.opts.adaptive
771-
set_proposed_dt!(integrator, max(integrator.opts.dtmin+eps(integrator.dt), callback.dtrelax * integrator.dt))
769+
set_proposed_dt!(integrator, max(nextfloat(integrator.opts.dtmin), callback.dtrelax * integrator.dt))
772770
end
773771

772+
change_t_via_interpolation!(integrator,integrator.tprev+cb_time)
773+
774774
# handle saveat
775775
_, savedexactly = savevalues!(integrator)
776776
saved_in_cb = true

test/downstream/ad_tests.jl

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
using Test
2+
using OrdinaryDiffEq, Calculus, ForwardDiff
3+
4+
function f(du,u,p,t)
5+
du[1] = -p[1]
6+
du[2] = p[2]
7+
end
8+
9+
for x in 0:0.001:5
10+
called = false
11+
function test_f(p)
12+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(called=true;integrator.p[2]=zero(integrator.p[2])))
13+
prob = ODEProblem(f,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
14+
integrator = init(prob,Tsit5(),abstol=1e-14,reltol=1e-14,callback=cb)
15+
step!(integrator)
16+
solve!(integrator).u[end]
17+
end
18+
p = [2.0, x]
19+
called = false
20+
findiff = Calculus.finite_difference_jacobian(test_f,p)
21+
@test called
22+
called = false
23+
fordiff = ForwardDiff.jacobian(test_f,p)
24+
@test called
25+
@test findiff fordiff
26+
end
27+
28+
function f2(du,u,p,t)
29+
du[1] = -u[2]
30+
du[2] = p[2]
31+
end
32+
33+
for x in 2.1:0.001:5
34+
called = false
35+
function test_f2(p)
36+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(called=true;integrator.p[2]=zero(integrator.p[2])))
37+
prob = ODEProblem(f2,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
38+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
39+
step!(integrator)
40+
solve!(integrator).u[end]
41+
end
42+
p = [2.0, x]
43+
findiff = Calculus.finite_difference_jacobian(test_f2,p)
44+
@test called
45+
called = false
46+
fordiff = ForwardDiff.jacobian(test_f2,p)
47+
@test called
48+
@test findiff fordiff
49+
end
50+
51+
#=
52+
#x = 2.0 is an interesting case
53+
54+
x = 2.0
55+
56+
function test_f2(p)
57+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(@show(x,integrator.t);called=true;integrator.p[2]=zero(integrator.p[2])))
58+
prob = ODEProblem(f2,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
59+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
60+
step!(integrator)
61+
solve!(integrator).u[end]
62+
end
63+
64+
p = [2.0, x]
65+
findiff = Calculus.finite_difference_jacobian(test_f2,p)
66+
@test called
67+
called = false
68+
fordiff = ForwardDiff.jacobian(test_f2,p)
69+
@test called
70+
71+
# At that value, it shouldn't be called, but a small perturbation will make it called, so finite difference is wrong!
72+
=#
73+
74+
for x in 1.0:0.001:2.5
75+
function lotka_volterra(du,u,p,t)
76+
x, y = u
77+
α, β, δ, γ = p
78+
du[1] = dx = α*x - β*x*y
79+
du[2] = dy = -δ*y + γ*x*y
80+
end
81+
u0 = [1.0,1.0]
82+
tspan = (0.0,10.0)
83+
p = [x,1.0,3.0,1.0]
84+
prob = ODEProblem(lotka_volterra,u0,tspan,p)
85+
sol = solve(prob,Tsit5())
86+
87+
called=false
88+
function test_lotka(p)
89+
cb = ContinuousCallback((u,t,i) -> u[1]-2.5, (integrator)->(called=true;integrator.p[4]=1.5))
90+
prob = ODEProblem(lotka_volterra,eltype(p).([1.0,1.0]),eltype(p).((0.0,10.0)),copy(p))
91+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
92+
step!(integrator)
93+
solve!(integrator).u[end]
94+
end
95+
96+
findiff = Calculus.finite_difference_jacobian(test_lotka,p)
97+
@test called
98+
called = false
99+
fordiff = ForwardDiff.jacobian(test_lotka,p)
100+
@test called
101+
@test findiff fordiff
102+
end
103+
104+
# Gradients and Hessians
105+
106+
function myobj(θ)
107+
f(u,p,t) = -θ[1]*u
108+
u0, _ = promote(10.0, θ[1])
109+
prob = ODEProblem(f, u0, (0.0, 1.0))
110+
sol = solve(prob, Tsit5())
111+
diff = sol.u - 10*exp.(-sol.t)
112+
return diff'diff
113+
end
114+
115+
ForwardDiff.gradient(myobj, [1.0])
116+
ForwardDiff.hessian(myobj, [1.0])
117+
118+
function myobj2(θ)
119+
f(du,u,p,t) = (du[1]=-θ[1]*u[1])
120+
u0, _ = promote(10.0, θ[1])
121+
prob = ODEProblem(f, [u0], (0.0, 1.0))
122+
sol = solve(prob, Tsit5())
123+
diff = sol[:,1] .- 10 .*exp.(-sol.t)
124+
return diff'diff
125+
end
126+
127+
ForwardDiff.gradient(myobj2, [1.0])
128+
ForwardDiff.hessian(myobj2, [1.0])
129+
130+
function myobj3(θ)
131+
f(u,p,t) = -θ[1]*u
132+
u0, _ = promote(10.0, θ[1])
133+
tspan_end, _ = promote(1.0, θ[1])
134+
prob = ODEProblem(f, u0, (0.0, tspan_end))
135+
sol = solve(prob, Tsit5())
136+
diff = sol.u - 10*exp.(-sol.t)
137+
return diff'diff
138+
end
139+
140+
ForwardDiff.gradient(myobj3, [1.0])
141+
ForwardDiff.hessian(myobj3, [1.0])
142+
143+
function myobj4(θ)
144+
f(du,u,p,t) = (du[1] = -θ[1]*u[1])
145+
u0, _ = promote(10.0, θ[1])
146+
tspan_end, _ = promote(1.0, θ[1])
147+
prob = ODEProblem(f, [u0], (0.0, tspan_end))
148+
sol = solve(prob, Tsit5())
149+
diff = sol[:,1] .- 10 .* exp.(-sol.t)
150+
return diff'diff
151+
end
152+
153+
ForwardDiff.gradient(myobj4, [1.0])
154+
ForwardDiff.hessian(myobj4, [1.0])

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
4646
@time @safetestset "PSOS and Energy Conservation Event Detection" begin include("downstream/psos_and_energy_conservation.jl") end
4747
@time @safetestset "DE stats" begin include("downstream/destats_tests.jl") end
4848
@time @safetestset "DEDataArray" begin include("downstream/data_array_regression_tests.jl") end
49+
@time @safetestset "AD Tests" begin include("downstream/ad_tests.jl") end
4950
end
5051

5152
if !is_APPVEYOR && GROUP == "GPU"

0 commit comments

Comments
 (0)