Skip to content

Commit 2e65685

Browse files
Merge pull request #49 from SciML/nondiag
fix wrapper handling of non-diagonal sparse noise
2 parents 584c82d + 1f4a5f7 commit 2e65685

File tree

3 files changed

+85
-2
lines changed

3 files changed

+85
-2
lines changed

src/integrators/interface.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ function DiffEqBase.auto_dt_reset!(integrator::SDDEIntegrator)
107107
end
108108

109109
# determine initial time step
110-
sde_prob = SDEProblem(f,g, prob.u0, prob.tspan, prob.p)
110+
sde_prob = SDEProblem(f,g, prob.u0, prob.tspan, prob.p;
111+
noise_rate_prototype = prob.noise_rate_prototype,
112+
noise = prob.noise,
113+
seed = prob.seed,
114+
prob.kwargs...)
111115
integrator.dt = StochasticDiffEq.sde_determine_initdt(integrator.u, integrator.t,
112116
integrator.tdir, integrator.opts.dtmax, integrator.opts.abstol, integrator.opts.reltol,
113117
integrator.opts.internalnorm, sde_prob, StochasticDiffEq.get_current_alg_order(getalg(integrator.alg), integrator.cache), integrator)
@@ -244,7 +248,7 @@ function DiffEqBase.u_modified!(integrator::SDDEIntegrator, bool::Bool)
244248
end
245249

246250
get_proposed_dt(integrator::SDDEIntegrator) = integrator.dtpropose
247-
set_proposed_dt!(integrator::SDDEIntegrator,dt::Number) = (integrator.dtpropose = dt)
251+
set_proposed_dt!(integrator::SDDEIntegrator,dt::Number) = (integrator.dtpropose = dt; integrator.dtcache = dt)
248252

249253
function set_proposed_dt!(integrator::SDDEIntegrator,integrator2::SDDEIntegrator)
250254
integrator.dtpropose = integrator2.dtpropose
@@ -337,3 +341,15 @@ end
337341
!isnothing(integrator.W) && DiffEqNoiseProcess.save_noise!(integrator.W)
338342
!isnothing(integrator.P) && DiffEqNoiseProcess.save_noise!(integrator.P)
339343
end
344+
345+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator) =
346+
get_tmp_cache(integrator, integrator.alg, integrator.cache)
347+
# avoid method ambiguity
348+
for typ in (StochasticDiffEq.StochasticDiffEqAlgorithm,StochasticDiffEq.StochasticDiffEqNewtonAdaptiveAlgorithm)
349+
@eval @inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg::$typ, cache::StochasticDiffEq.StochasticDiffEqConstantCache) = nothing
350+
end
351+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg, cache) = (cache.tmp,)
352+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg::StochasticDiffEq.StochasticDiffEqNewtonAdaptiveAlgorithm, cache) =
353+
(cache.nlsolver.tmp, cache.nlsolver.ztmp)
354+
@inline DiffEqBase.get_tmp_cache(integrator::SDDEIntegrator, alg::StochasticDiffEq.StochasticCompositeAlgorithm, cache) =
355+
get_tmp_cache(integrator, alg.algs[1], cache.caches[1])

test/nondiagonal_sparse_noise.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using StochasticDelayDiffEq
2+
using Random
3+
using SparseArrays
4+
5+
function sir_dde!(du,u,h,p,t)
6+
(S,I,R) = u
7+
(β,c,τ) = p
8+
N = S+I+R
9+
infection = β*c*I/N*S
10+
(Sd,Id,Rd) = h(p, t-τ) # Time delayed variables
11+
Nd = Sd+Id+Rd
12+
recovery = β*c*Id/Nd*Sd
13+
@inbounds begin
14+
du[1] = -infection
15+
du[2] = infection - recovery
16+
du[3] = recovery
17+
end
18+
nothing
19+
end;
20+
21+
# Define a sparse matrix by making a dense matrix and setting some values as not zero
22+
A = zeros(3,2)
23+
A[1,1] = 1
24+
A[2,1] = 1
25+
A[2,2] = 1
26+
A[3,2] = 1
27+
A = SparseArrays.sparse(A);
28+
29+
# Make `g` write the sparse matrix values
30+
function sir_delayed_noise!(du,u,h,p,t)
31+
(S,I,R) = u
32+
(β,c,τ) = p
33+
N = S+I+R
34+
infection = β*c*I/N*S
35+
(Sd,Id,Rd) = h(p, t-τ) # Time delayed variables
36+
Nd = Sd+Id+Rd
37+
recovery = β*c*Id/Nd*Sd
38+
du[1,1] = -sqrt(infection)
39+
du[2,1] = sqrt(infection)
40+
du[2,2] = -sqrt(recovery)
41+
du[3,2] = sqrt(recovery)
42+
end;
43+
44+
function condition(u,t,integrator) # Event when event_f(u,t) == 0
45+
u[2]
46+
end;
47+
function affect!(integrator)
48+
integrator.u[2] = 0.0
49+
end;
50+
cb = ContinuousCallback(condition,affect!);
51+
52+
δt = 0.1
53+
tmax = 40.0
54+
tspan = (0.0,tmax)
55+
t = 0.0:δt:tmax;
56+
u0 = [990.0,10.0,0.0]; # S,I,R
57+
58+
function sir_history(p, t)
59+
[1000.0, 0.0, 0.0]
60+
end;
61+
62+
p = [0.05,10.0,4.0]; # β,c,τ
63+
Random.seed!(1234);
64+
65+
prob_sdde = SDDEProblem(sir_dde!,sir_delayed_noise!,u0,sir_history,tspan,p;noise_rate_prototype=A);
66+
sol_sdde = solve(prob_sdde,LambaEM(),callback=cb);

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ using SafeTestsets
33
@safetestset "SDDEProblem, solve" begin include("test_prob_sol.jl") end
44
@safetestset "Analyticless Convergence Tests" begin include("analyticless_convergence_tests.jl") end
55
@safetestset "Event handling" begin include("events.jl") end
6+
@safetestset "Non-Diagonal Sparse Noise" begin include("nondiagonal_sparse_noise.jl") end

0 commit comments

Comments
 (0)