Skip to content

Commit 4c1ae36

Browse files
Merge pull request #84 from AayushSabharwal/as/initialization
feat: add initialization support
2 parents 4912067 + 662e455 commit 4c1ae36

File tree

4 files changed

+15
-18
lines changed

4 files changed

+15
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Reexport = "1.0"
3636
SciMLBase = "2.59.2"
3737
SparseArrays = "1.9"
3838
StaticArrays = "1.0"
39-
StochasticDiffEq = "6.19"
39+
StochasticDiffEq = "6.72.1"
4040
UnPack = "0.1, 1.0"
4141
julia = "1.9"
4242

src/functionwrapper.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99
(g::SDEDiffusionTermWrapper{false})(u, p, t) = g.g(u, g.h, p, t)
1010

1111
struct SDEFunctionWrapper{iip, F, G, H, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, GG,
12-
TCV} <: DiffEqBase.AbstractRODEFunction{iip}
12+
TCV, ID, S} <: DiffEqBase.AbstractRODEFunction{iip}
1313
f::F
1414
g::G
1515
h::H
@@ -26,6 +26,8 @@ struct SDEFunctionWrapper{iip, F, G, H, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, T
2626
paramjac::TPJ
2727
ggprime::GG
2828
colorvec::TCV
29+
initialization_data::ID
30+
sys::S
2931
end
3032

3133
(f::SDEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)
@@ -53,17 +55,9 @@ function wrap_functions_and_history(f::SDDEFunction, g, h)
5355
typeof(f.analytic), typeof(f.tgrad), typeof(jac), typeof(f.jvp),
5456
typeof(f.vjp), typeof(f.jac_prototype), typeof(f.sparsity),
5557
typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.paramjac),
56-
typeof(f.ggprime), typeof(f.colorvec)}(f.f, gwh, h,
57-
f.mass_matrix,
58-
f.analytic,
59-
f.tgrad, jac,
60-
f.jvp, f.vjp,
61-
f.jac_prototype,
62-
f.sparsity,
63-
f.Wfact,
64-
f.Wfact_t,
65-
f.paramjac,
66-
f.ggprime,
67-
f.colorvec),
58+
typeof(f.ggprime), typeof(f.colorvec), typeof(f.initialization_data),
59+
typeof(f.sys)}(f.f, gwh, h, f.mass_matrix, f.analytic, f.tgrad, jac,
60+
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t, f.paramjac,
61+
f.ggprime, f.colorvec, f.initialization_data, f.sys),
6862
gwh
6963
end

src/integrators/type.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
mutable struct SDDEIntegrator{algType, IIP, uType, uEltype, tType, P, eigenType,
2727
tTypeNoUnits, uEltypeNoUnits, randType, randType2, rateType,
2828
solType, cacheType, F, G, F6, OType, noiseType,
29-
EventErrorType, CallbackCacheType, H, IType} <:
29+
EventErrorType, CallbackCacheType, H, IType, IA} <:
3030
AbstractSDDEIntegrator{algType, IIP, uType, tType}
3131
f::F
3232
g::G
@@ -81,6 +81,7 @@ mutable struct SDDEIntegrator{algType, IIP, uType, uEltype, tType, P, eigenType,
8181
history::H
8282
stats::DiffEqBase.Stats
8383
integrator::IType # history integrator
84+
initializealg::IA
8485
end
8586

8687
function (integrator::SDDEIntegrator)(t, deriv::Type = Val{0}; idxs = nothing)

src/solve.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ function DiffEqBase.__init(prob::AbstractSDDEProblem,# TODO DiffEqBasee.Abstract
6666
# Keywords for Delay problems (from DDE)
6767
discontinuity_interp_points::Int = 10,
6868
discontinuity_abstol = eltype(prob.tspan)(1 // Int64(10)^12),
69-
discontinuity_reltol = 0, kwargs...) where {recompile_flag}
69+
discontinuity_reltol = 0,
70+
initializealg = StochasticDiffEq.SDEDefaultInit(), kwargs...) where {recompile_flag}
7071

7172
# alg = getalg(alg0);
7273
if prob.f isa Tuple
@@ -468,7 +469,7 @@ function DiffEqBase.__init(prob::AbstractSDDEProblem,# TODO DiffEqBasee.Abstract
468469
typeof(c),
469470
typeof(opts), typeof(noise), typeof(last_event_error),
470471
typeof(callback_cache), typeof(history),
471-
typeof(sde_integrator)}(f_with_history,
472+
typeof(sde_integrator), typeof(initializealg)}(f_with_history,
472473
g_with_history, c, noise, uprev,
473474
tprev,
474475
order_discontinuity_t0,
@@ -486,9 +487,10 @@ function DiffEqBase.__init(prob::AbstractSDDEProblem,# TODO DiffEqBasee.Abstract
486487
P,
487488
opts, iter, success_iter, eigen_est,
488489
EEst, q, QT(qoldinit), q11, history,
489-
stats, sde_integrator)
490+
stats, sde_integrator, initializealg)
490491

491492
if initialize_integrator
493+
DiffEqBase.initialize_dae!(integrator)
492494
StochasticDiffEq.initialize_callbacks!(integrator, initialize_save)
493495
initialize!(integrator, integrator.cache)
494496

0 commit comments

Comments
 (0)