Skip to content

Commit 3e64aa5

Browse files
committed
Add observed equation to SDEfunction
1 parent 9cb6a4d commit 3e64aa5

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
221221
u0 = nothing;
222222
version = nothing, tgrad = false, sparse = false,
223223
jac = false, Wfact = false, eval_expression = true,
224+
checkbounds = false,
224225
kwargs...) where {iip}
225226
dvs = scalarize.(dvs)
226227
ps = scalarize.(ps)
@@ -279,14 +280,25 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
279280
M = calculate_massmatrix(sys)
280281
_M = (u0 === nothing || M == I) ? M : ArrayInterfaceCore.restructure(u0 .* u0', M)
281282

283+
obs = observed(sys)
284+
observedfun = let sys = sys, dict = Dict()
285+
function generated_observed(obsvar, u, p, t)
286+
obs = get!(dict, value(obsvar)) do
287+
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
288+
end
289+
obs(u, p, t)
290+
end
291+
end
292+
282293
sts = states(sys)
283294
SDEFunction{iip}(f, g,
284295
jac = _jac === nothing ? nothing : _jac,
285296
tgrad = _tgrad === nothing ? nothing : _tgrad,
286297
Wfact = _Wfact === nothing ? nothing : _Wfact,
287298
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
288299
mass_matrix = _M,
289-
syms = Symbol.(states(sys)))
300+
syms = Symbol.(states(sys)),
301+
observed = observedfun)
290302
end
291303

292304
function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)

test/sdesystem.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,37 @@ eqs = [D(x) ~ x]
479479
noiseeqs = [0.1 * x]
480480
@named de = SDESystem(eqs, noiseeqs, t, [x], [])
481481
@test nameof(rename(de, :newname)) == :newname
482+
483+
@testset "observed functionality" begin
484+
@parameters α β
485+
@variables t x(t) y(t) z(t)
486+
@variables weight(t)
487+
D = Differential(t)
488+
489+
eqs = [D(x) ~ α * x]
490+
noiseeqs =* x]
491+
dt = 1 // 2^(7)
492+
x0 = 0.1
493+
494+
u0map = [
495+
x => x0,
496+
]
497+
498+
parammap = [
499+
α => 1.5,
500+
β => 1.0,
501+
]
502+
503+
@named de = SDESystem(eqs, noiseeqs, t, [x], [α, β], observed = [weight ~ x * 10])
504+
505+
prob = SDEProblem(de, u0map, (0.0, 1.0), parammap)
506+
sol = solve(prob, EM(), dt = dt)
507+
@test observed(de) == [weight ~ x * 10]
508+
@test sol[weight] == 10 * sol[x]
509+
510+
@named ode = ODESystem(eqs, t, [x], [α, β], observed = [weight ~ x * 10])
511+
odeprob = ODEProblem(ode, u0map, (0.0, 1.0), parammap)
512+
solode = solve(odeprob, Tsit5())
513+
@test observed(ode) == [weight ~ x * 10]
514+
@test solode[weight] == 10 * solode[x]
515+
end

0 commit comments

Comments
 (0)