@@ -93,6 +93,19 @@ struct SDESystem <: AbstractODESystem
9393 """
9494 defaults:: Dict
9595 """
96+ The guesses to use as the initial conditions for the
97+ initialization system.
98+ """
99+ guesses:: Dict
100+ """
101+ The system for performing the initialization.
102+ """
103+ initializesystem:: Union{Nothing, NonlinearSystem}
104+ """
105+ Extra equations to be enforced during the initialization sequence.
106+ """
107+ initialization_eqs:: Vector{Equation}
108+ """
96109 Type of the system.
97110 """
98111 connector_type:: Any
@@ -144,9 +157,8 @@ struct SDESystem <: AbstractODESystem
144157 isscheduled:: Bool
145158
146159 function SDESystem (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
147- tgrad,
148- jac,
149- ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
160+ tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
161+ guesses, initializesystem, initialization_eqs, connector_type,
150162 cevents, devents, parameter_dependencies, metadata = nothing , gui_metadata = nothing ,
151163 complete = false , index_cache = nothing , parent = nothing , is_scalar_noise = false ,
152164 is_dde = false ,
@@ -171,9 +183,9 @@ struct SDESystem <: AbstractODESystem
171183 check_units (u, deqs, neqs)
172184 end
173185 new (tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
174- ctrl_jac,
175- Wfact, Wfact_t, name, description, systems ,
176- defaults, connector_type, cevents, devents,
186+ ctrl_jac, Wfact, Wfact_t, name, description, systems,
187+ defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents ,
188+ devents,
177189 parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
178190 is_dde, isscheduled)
179191 end
@@ -187,6 +199,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
187199 default_u0 = Dict (),
188200 default_p = Dict (),
189201 defaults = _merge (Dict (default_u0), Dict (default_p)),
202+ guesses = Dict (),
203+ initializesystem = nothing ,
204+ initialization_eqs = Equation[],
190205 name = nothing ,
191206 description = " " ,
192207 connector_type = nothing ,
@@ -207,6 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
207222 dvs′ = value .(dvs)
208223 ps′ = value .(ps)
209224 ctrl′ = value .(controls)
225+ parameter_dependencies, ps′ = process_parameter_dependencies (
226+ parameter_dependencies, ps′)
210227
211228 sysnames = nameof .(systems)
212229 if length (unique (sysnames)) != length (sysnames)
@@ -217,13 +234,21 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
217234 " `default_u0` and `default_p` are deprecated. Use `defaults` instead." ,
218235 :SDESystem , force = true )
219236 end
220- defaults = todict (defaults)
221- defaults = Dict (value (k) => value (v)
222- for (k, v) in pairs (defaults) if value (v) != = nothing )
223237
238+ defaults = Dict {Any, Any} (todict (defaults))
239+ guesses = Dict {Any, Any} (todict (guesses))
224240 var_to_name = Dict ()
225- process_variables! (var_to_name, defaults, dvs′)
226- process_variables! (var_to_name, defaults, ps′)
241+ process_variables! (var_to_name, defaults, guesses, dvs′)
242+ process_variables! (var_to_name, defaults, guesses, ps′)
243+ process_variables! (
244+ var_to_name, defaults, guesses, [eq. lhs for eq in parameter_dependencies])
245+ process_variables! (
246+ var_to_name, defaults, guesses, [eq. rhs for eq in parameter_dependencies])
247+ defaults = Dict {Any, Any} (value (k) => value (v)
248+ for (k, v) in pairs (defaults) if v != = nothing )
249+ guesses = Dict {Any, Any} (value (k) => value (v)
250+ for (k, v) in pairs (guesses) if v != = nothing )
251+
227252 isempty (observed) || collect_var_to_name! (var_to_name, (eq. lhs for eq in observed))
228253
229254 tgrad = RefValue (EMPTY_TGRAD)
@@ -233,14 +258,13 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
233258 Wfact_t = RefValue (EMPTY_JAC)
234259 cont_callbacks = SymbolicContinuousCallbacks (continuous_events)
235260 disc_callbacks = SymbolicDiscreteCallbacks (discrete_events)
236- parameter_dependencies, ps′ = process_parameter_dependencies (
237- parameter_dependencies, ps′)
238261 if is_dde === nothing
239262 is_dde = _check_if_dde (deqs, iv′, systems)
240263 end
241264 SDESystem (Threads. atomic_add! (SYSTEM_COUNT, UInt (1 )),
242265 deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
243- ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
266+ ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
267+ initializesystem, initialization_eqs, connector_type,
244268 cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
245269 complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
246270end
@@ -520,7 +544,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
520544 version = nothing , tgrad = false , sparse = false ,
521545 jac = false , Wfact = false , eval_expression = false ,
522546 eval_module = @__MODULE__ ,
523- checkbounds = false ,
547+ checkbounds = false , initialization_data = nothing ,
524548 kwargs... ) where {iip, specialize}
525549 if ! iscomplete (sys)
526550 error (" A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`" )
@@ -591,13 +615,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
591615
592616 observedfun = ObservedFunctionCache (sys; eval_expression, eval_module)
593617
594- SDEFunction {iip, specialize} (f, g,
618+ SDEFunction {iip, specialize} (f, g;
595619 sys = sys,
596620 jac = _jac === nothing ? nothing : _jac,
597621 tgrad = _tgrad === nothing ? nothing : _tgrad,
598622 Wfact = _Wfact === nothing ? nothing : _Wfact,
599623 Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
600- mass_matrix = _M,
624+ mass_matrix = _M, initialization_data,
601625 observed = observedfun)
602626end
603627
@@ -714,7 +738,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
714738 end
715739 f, u0, p = process_SciMLProblem (
716740 SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
717- kwargs... )
741+ t = tspan === nothing ? nothing : tspan[ 1 ], kwargs... )
718742 cbs = process_events (sys; callback, kwargs... )
719743 sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
720744
@@ -736,6 +760,8 @@ function DiffEqBase.SDEProblem{iip, specialize}(
736760 noise = nothing
737761 end
738762
763+ kwargs = filter_kwargs (kwargs)
764+
739765 SDEProblem {iip} (f, u0, tspan, p; callback = cbs, noise,
740766 noise_rate_prototype = noise_rate_prototype, kwargs... )
741767end
0 commit comments