Skip to content

Commit 76e2b22

Browse files
fix wrapper handling of non-diagonal sparse noise
Fixes #48
1 parent 584c82d commit 76e2b22

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

src/integrators/interface.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ function DiffEqBase.auto_dt_reset!(integrator::SDDEIntegrator)
107107
end
108108

109109
# determine initial time step
110-
sde_prob = SDEProblem(f,g, prob.u0, prob.tspan, prob.p)
110+
sde_prob = SDEProblem(f,g, prob.u0, prob.tspan, prob.p;
111+
noise_rate_prototype = prob.noise_rate_prototype,
112+
noise = prob.noise,
113+
seed = prob.seed,
114+
prob.kwargs...)
111115
integrator.dt = StochasticDiffEq.sde_determine_initdt(integrator.u, integrator.t,
112116
integrator.tdir, integrator.opts.dtmax, integrator.opts.abstol, integrator.opts.reltol,
113117
integrator.opts.internalnorm, sde_prob, StochasticDiffEq.get_current_alg_order(getalg(integrator.alg), integrator.cache), integrator)
@@ -244,7 +248,7 @@ function DiffEqBase.u_modified!(integrator::SDDEIntegrator, bool::Bool)
244248
end
245249

246250
get_proposed_dt(integrator::SDDEIntegrator) = integrator.dtpropose
247-
set_proposed_dt!(integrator::SDDEIntegrator,dt::Number) = (integrator.dtpropose = dt)
251+
set_proposed_dt!(integrator::SDDEIntegrator,dt::Number) = (integrator.dtpropose = dt; integrator.dtcache = dt)
248252

249253
function set_proposed_dt!(integrator::SDDEIntegrator,integrator2::SDDEIntegrator)
250254
integrator.dtpropose = integrator2.dtpropose
@@ -337,3 +341,15 @@ end
337341
!isnothing(integrator.W) && DiffEqNoiseProcess.save_noise!(integrator.W)
338342
!isnothing(integrator.P) && DiffEqNoiseProcess.save_noise!(integrator.P)
339343
end
344+
345+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator) =
346+
get_tmp_cache(integrator, integrator.alg, integrator.cache)
347+
# avoid method ambiguity
348+
for typ in (StochasticDiffEq.StochasticDiffEqAlgorithm,StochasticDiffEq.StochasticDiffEqNewtonAdaptiveAlgorithm)
349+
@eval @inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg::$typ, cache::StochasticDiffEq.StochasticDiffEqConstantCache) = nothing
350+
end
351+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg, cache) = (cache.tmp,)
352+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg::StochasticDiffEq.StochasticDiffEqNewtonAdaptiveAlgorithm, cache) =
353+
(cache.nlsolver.tmp, cache.nlsolver.ztmp)
354+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg::StochasticDiffEq.StochasticCompositeAlgorithm, cache) =
355+
get_tmp_cache(integrator, alg.algs[1], cache.caches[1])

test/nondiagonal_sparse_noise.jl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
using StochasticDelayDiffEq
2+
using DiffEqCallbacks
3+
using Random
4+
using SparseArrays
5+
6+
function sir_dde!(du,u,h,p,t)
7+
(S,I,R) = u
8+
(β,c,τ) = p
9+
N = S+I+R
10+
infection = β*c*I/N*S
11+
(Sd,Id,Rd) = h(p, t-τ) # Time delayed variables
12+
Nd = Sd+Id+Rd
13+
recovery = β*c*Id/Nd*Sd
14+
@inbounds begin
15+
du[1] = -infection
16+
du[2] = infection - recovery
17+
du[3] = recovery
18+
end
19+
nothing
20+
end;
21+
22+
# Define a sparse matrix by making a dense matrix and setting some values as not zero
23+
A = zeros(3,2)
24+
A[1,1] = 1
25+
A[2,1] = 1
26+
A[2,2] = 1
27+
A[3,2] = 1
28+
A = SparseArrays.sparse(A);
29+
30+
# Make `g` write the sparse matrix values
31+
function sir_delayed_noise!(du,u,h,p,t)
32+
(S,I,R) = u
33+
(β,c,τ) = p
34+
N = S+I+R
35+
infection = β*c*I/N*S
36+
(Sd,Id,Rd) = h(p, t-τ) # Time delayed variables
37+
Nd = Sd+Id+Rd
38+
recovery = β*c*Id/Nd*Sd
39+
du[1,1] = -sqrt(infection)
40+
du[2,1] = sqrt(infection)
41+
du[2,2] = -sqrt(recovery)
42+
du[3,2] = sqrt(recovery)
43+
end;
44+
45+
function condition(u,t,integrator) # Event when event_f(u,t) == 0
46+
u[2]
47+
end;
48+
function affect!(integrator)
49+
integrator.u[2] = 0.0
50+
end;
51+
cb = ContinuousCallback(condition,affect!);
52+
53+
δt = 0.1
54+
tmax = 40.0
55+
tspan = (0.0,tmax)
56+
t = 0.0:δt:tmax;
57+
u0 = [990.0,10.0,0.0]; # S,I,R
58+
59+
function sir_history(p, t)
60+
[1000.0, 0.0, 0.0]
61+
end;
62+
63+
p = [0.05,10.0,4.0]; # β,c,τ
64+
Random.seed!(1234);
65+
66+
prob_sdde = SDDEProblem(sir_dde!,sir_delayed_noise!,u0,sir_history,tspan,p;noise_rate_prototype=A);
67+
sol_sdde = solve(prob_sdde,LambaEM(),callback=cb);

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ using SafeTestsets
33
@safetestset "SDDEProblem, solve" begin include("test_prob_sol.jl") end
44
@safetestset "Analyticless Convergence Tests" begin include("analyticless_convergence_tests.jl") end
55
@safetestset "Event handling" begin include("events.jl") end
6+
@safetestset "Non-Diagonal Sparse Noise" begin include("nondiagonal_sparse_noise.jl") end

0 commit comments

Comments
 (0)