Skip to content

Commit a69144b

Browse files
Merge pull request #490 from SciML/kg/cbfix3
Callback hotfix
2 parents 6b86f42 + ee0c100 commit a69144b

File tree

4 files changed

+160
-3
lines changed

4 files changed

+160
-3
lines changed

src/callbacks.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,12 +765,12 @@ end
765765

766766
function apply_callback!(integrator,callback::Union{ContinuousCallback,VectorContinuousCallback},cb_time,prev_sign,event_idx)
767767

768-
change_t_via_interpolation!(integrator,integrator.tprev+cb_time)
769-
770768
if integrator.opts.adaptive
771-
set_proposed_dt!(integrator, max(integrator.opts.dtmin+eps(integrator.dt), callback.dtrelax * integrator.dt))
769+
set_proposed_dt!(integrator, max(nextfloat(integrator.opts.dtmin), callback.dtrelax * integrator.dt))
772770
end
773771

772+
change_t_via_interpolation!(integrator,integrator.tprev+cb_time)
773+
774774
# handle saveat
775775
_, savedexactly = savevalues!(integrator)
776776
saved_in_cb = true

test/downstream/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
[deps]
2+
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
23
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
34
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
45
DiffEqProblemLibrary = "a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9"
6+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
57
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
68
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
79
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

test/downstream/ad_tests.jl

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
using Test
2+
using OrdinaryDiffEq, Calculus, ForwardDiff
3+
4+
function f(du,u,p,t)
5+
du[1] = -p[1]
6+
du[2] = p[2]
7+
end
8+
9+
for x in 0:0.001:5
10+
called = false
11+
function test_f(p)
12+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(called=true;integrator.p[2]=zero(integrator.p[2])))
13+
prob = ODEProblem(f,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
14+
integrator = init(prob,Tsit5(),abstol=1e-14,reltol=1e-14,callback=cb)
15+
step!(integrator)
16+
solve!(integrator).u[end]
17+
end
18+
p = [2.0, x]
19+
called = false
20+
findiff = Calculus.finite_difference_jacobian(test_f,p)
21+
@test called
22+
called = false
23+
fordiff = ForwardDiff.jacobian(test_f,p)
24+
@test called
25+
@test findiff fordiff
26+
end
27+
28+
function f2(du,u,p,t)
29+
du[1] = -u[2]
30+
du[2] = p[2]
31+
end
32+
33+
for x in 2.1:0.001:5
34+
called = false
35+
function test_f2(p)
36+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(called=true;integrator.p[2]=zero(integrator.p[2])))
37+
prob = ODEProblem(f2,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
38+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
39+
step!(integrator)
40+
solve!(integrator).u[end]
41+
end
42+
p = [2.0, x]
43+
findiff = Calculus.finite_difference_jacobian(test_f2,p)
44+
@test called
45+
called = false
46+
fordiff = ForwardDiff.jacobian(test_f2,p)
47+
@test called
48+
@test findiff fordiff
49+
end
50+
51+
#=
52+
#x = 2.0 is an interesting case
53+
54+
x = 2.0
55+
56+
function test_f2(p)
57+
cb = ContinuousCallback((u,t,i) -> u[1], (integrator)->(@show(x,integrator.t);called=true;integrator.p[2]=zero(integrator.p[2])))
58+
prob = ODEProblem(f2,eltype(p).([1.0,0.0]),eltype(p).((0.0,1.0)),copy(p))
59+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
60+
step!(integrator)
61+
solve!(integrator).u[end]
62+
end
63+
64+
p = [2.0, x]
65+
findiff = Calculus.finite_difference_jacobian(test_f2,p)
66+
@test called
67+
called = false
68+
fordiff = ForwardDiff.jacobian(test_f2,p)
69+
@test called
70+
71+
# At that value, it shouldn't be called, but a small perturbation will make it called, so finite difference is wrong!
72+
=#
73+
74+
for x in 1.0:0.001:2.5
75+
function lotka_volterra(du,u,p,t)
76+
x, y = u
77+
α, β, δ, γ = p
78+
du[1] = dx = α*x - β*x*y
79+
du[2] = dy = -δ*y + γ*x*y
80+
end
81+
u0 = [1.0,1.0]
82+
tspan = (0.0,10.0)
83+
p = [x,1.0,3.0,1.0]
84+
prob = ODEProblem(lotka_volterra,u0,tspan,p)
85+
sol = solve(prob,Tsit5())
86+
87+
called=false
88+
function test_lotka(p)
89+
cb = ContinuousCallback((u,t,i) -> u[1]-2.5, (integrator)->(called=true;integrator.p[4]=1.5))
90+
prob = ODEProblem(lotka_volterra,eltype(p).([1.0,1.0]),eltype(p).((0.0,10.0)),copy(p))
91+
integrator = init(prob,Tsit5(),abstol=1e-12,reltol=1e-12,callback=cb)
92+
step!(integrator)
93+
solve!(integrator).u[end]
94+
end
95+
96+
findiff = Calculus.finite_difference_jacobian(test_lotka,p)
97+
@test called
98+
called = false
99+
fordiff = ForwardDiff.jacobian(test_lotka,p)
100+
@test called
101+
@test findiff fordiff
102+
end
103+
104+
# Gradients and Hessians
105+
106+
function myobj(θ)
107+
f(u,p,t) = -θ[1]*u
108+
u0, _ = promote(10.0, θ[1])
109+
prob = ODEProblem(f, u0, (0.0, 1.0))
110+
sol = solve(prob, Tsit5())
111+
diff = sol.u - 10*exp.(-sol.t)
112+
return diff'diff
113+
end
114+
115+
ForwardDiff.gradient(myobj, [1.0])
116+
ForwardDiff.hessian(myobj, [1.0])
117+
118+
function myobj2(θ)
119+
f(du,u,p,t) = (du[1]=-θ[1]*u[1])
120+
u0, _ = promote(10.0, θ[1])
121+
prob = ODEProblem(f, [u0], (0.0, 1.0))
122+
sol = solve(prob, Tsit5())
123+
diff = sol[:,1] .- 10 .*exp.(-sol.t)
124+
return diff'diff
125+
end
126+
127+
ForwardDiff.gradient(myobj2, [1.0])
128+
ForwardDiff.hessian(myobj2, [1.0])
129+
130+
function myobj3(θ)
131+
f(u,p,t) = -θ[1]*u
132+
u0, _ = promote(10.0, θ[1])
133+
tspan_end, _ = promote(1.0, θ[1])
134+
prob = ODEProblem(f, u0, (0.0, tspan_end))
135+
sol = solve(prob, Tsit5())
136+
diff = sol.u - 10*exp.(-sol.t)
137+
return diff'diff
138+
end
139+
140+
ForwardDiff.gradient(myobj3, [1.0])
141+
ForwardDiff.hessian(myobj3, [1.0])
142+
143+
function myobj4(θ)
144+
f(du,u,p,t) = (du[1] = -θ[1]*u[1])
145+
u0, _ = promote(10.0, θ[1])
146+
tspan_end, _ = promote(1.0, θ[1])
147+
prob = ODEProblem(f, [u0], (0.0, tspan_end))
148+
sol = solve(prob, Tsit5())
149+
diff = sol[:,1] .- 10 .* exp.(-sol.t)
150+
return diff'diff
151+
end
152+
153+
ForwardDiff.gradient(myobj4, [1.0])
154+
ForwardDiff.hessian(myobj4, [1.0])

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
4747
@time @safetestset "DE stats" begin include("downstream/destats_tests.jl") end
4848
@time @safetestset "DEDataArray" begin include("downstream/data_array_regression_tests.jl") end
4949
@time @safetestset "Concrete_solve Tests" begin include("downstream/concrete_solve_tests.jl") end
50+
@time @safetestset "AD Tests" begin include("downstream/ad_tests.jl") end
5051
end
5152

5253
if !is_APPVEYOR && GROUP == "GPU"

0 commit comments

Comments
 (0)