@@ -231,6 +231,9 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
231231 if isdde
232232 eqs = delay_to_function (sys, eqs)
233233 end
234+ if eqs isa AbstractMatrix && isdiag (eqs)
235+ eqs = diag (eqs)
236+ end
234237 u = map (x -> time_varying_as_func (value (x), sys), dvs)
235238 p = if has_index_cache (sys) && get_index_cache (sys) != = nothing
236239 reorder_parameters (get_index_cache (sys), ps)
@@ -403,14 +406,14 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
403406 checks = false )
404407end
405408
406- function DiffEqBase. SDEFunction {iip} (sys:: SDESystem , dvs = unknowns (sys),
409+ function DiffEqBase. SDEFunction {iip, specialize } (sys:: SDESystem , dvs = unknowns (sys),
407410 ps = parameters (sys),
408411 u0 = nothing ;
409412 version = nothing , tgrad = false , sparse = false ,
410413 jac = false , Wfact = false , eval_expression = false ,
411414 eval_module = @__MODULE__ ,
412415 checkbounds = false ,
413- kwargs... ) where {iip}
416+ kwargs... ) where {iip, specialize }
414417 if ! iscomplete (sys)
415418 error (" A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`" )
416419 end
@@ -480,7 +483,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
480483
481484 observedfun = ObservedFunctionCache (sys; eval_expression, eval_module)
482485
483- SDEFunction {iip} (f, g,
486+ SDEFunction {iip, specialize } (f, g,
484487 sys = sys,
485488 jac = _jac === nothing ? nothing : _jac,
486489 tgrad = _tgrad === nothing ? nothing : _tgrad,
@@ -505,6 +508,16 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
505508 SDEFunction {true} (sys, args... ; kwargs... )
506509end
507510
511+ function DiffEqBase. SDEFunction {true} (sys:: SDESystem , args... ;
512+ kwargs... )
513+ SDEFunction {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
514+ end
515+
516+ function DiffEqBase. SDEFunction {false} (sys:: SDESystem , args... ;
517+ kwargs... )
518+ SDEFunction {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
519+ end
520+
508521"""
509522```julia
510523DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
@@ -583,14 +596,16 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
583596 SDEFunctionExpr {true} (sys, args... ; kwargs... )
584597end
585598
586- function DiffEqBase. SDEProblem {iip} (sys:: SDESystem , u0map = [], tspan = get_tspan (sys),
599+ function DiffEqBase. SDEProblem {iip, specialize} (
600+ sys:: SDESystem , u0map = [], tspan = get_tspan (sys),
587601 parammap = DiffEqBase. NullParameters ();
588602 sparsenoise = nothing , check_length = true ,
589- callback = nothing , kwargs... ) where {iip}
603+ callback = nothing , kwargs... ) where {iip, specialize }
590604 if ! iscomplete (sys)
591605 error (" A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`" )
592606 end
593- f, u0, p = process_DEProblem (SDEFunction{iip}, sys, u0map, parammap; check_length,
607+ f, u0, p = process_DEProblem (
608+ SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
594609 kwargs... )
595610 cbs = process_events (sys; callback, kwargs... )
596611 sparsenoise === nothing && (sparsenoise = get (kwargs, :sparse , false ))
@@ -628,6 +643,21 @@ function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
628643 SDEProblem {true} (sys, args... ; kwargs... )
629644end
630645
646+ function DiffEqBase. SDEProblem (sys:: SDESystem ,
647+ u0map:: StaticArray ,
648+ args... ;
649+ kwargs... )
650+ SDEProblem {false, SciMLBase.FullSpecialize} (sys, u0map, args... ; kwargs... )
651+ end
652+
653+ function DiffEqBase. SDEProblem {true} (sys:: SDESystem , args... ; kwargs... )
654+ SDEProblem {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
655+ end
656+
657+ function DiffEqBase. SDEProblem {false} (sys:: SDESystem , args... ; kwargs... )
658+ SDEProblem {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
659+ end
660+
631661"""
632662```julia
633663DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
0 commit comments