Skip to content

Commit fc605c3

Browse files
Merge pull request #1463 from isaacsas/observed_with_jumps
add observed to JumpProblems
2 parents 7e9e832 + 8d93bbc commit fc605c3

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Reexport = "0.2, 1"
6868
Requires = "1.0"
6969
RuntimeGeneratedFunctions = "0.4.3, 0.5"
7070
SafeTestsets = "0.0.1"
71-
SciMLBase = "1.3"
71+
SciMLBase = "1.26.2"
7272
Setfield = "0.7, 0.8"
7373
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7474
StaticArrays = "0.10, 0.11, 0.12, 1.0"

src/systems/jumps/jumpsystem.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,24 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
219219
```
220220
"""
221221
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,Nothing},
222-
parammap=DiffEqBase.NullParameters(); kwargs...)
222+
parammap=DiffEqBase.NullParameters(); checkbounds=false, kwargs...)
223223
defs = defaults(sys)
224224
u0 = varmap_to_vars(u0map, states(sys); defaults=defs)
225225
p = varmap_to_vars(parammap, parameters(sys); defaults=defs)
226226
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
227-
df = DiscreteFunction{true,true}(f, syms=Symbol.(states(sys)))
227+
228+
# just taken from abstractodesystem.jl for ODEFunction def
229+
obs = observed(sys)
230+
observedfun = let sys = sys, dict = Dict()
231+
function generated_observed(obsvar, u, p, t)
232+
obs = get!(dict, value(obsvar)) do
233+
build_explicit_observed_function(sys, obsvar; checkbounds=checkbounds)
234+
end
235+
obs(u, p, t)
236+
end
237+
end
238+
239+
df = DiscreteFunction{true,true}(f, syms=Symbol.(states(sys)), observed=observedfun)
228240
DiscreteProblem(df, u0, tspan, p; kwargs...)
229241
end
230242

test/jumpsystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ function getmean(jprob,Nsims)
7171
end
7272
m = getmean(jprob,Nsims)
7373

74+
@variables S2(t)
75+
obs = [S2 ~ 2*S]
76+
@named js2b = JumpSystem([j₁,j₃], t, [S,I,R], [β,γ], observed=obs)
77+
dprob = DiscreteProblem(js2b, u₀map, tspan, parammap)
78+
jprob = JumpProblem(js2b, dprob, Direct(), save_positions=(false,false))
79+
sol = solve(jprob, SSAStepper(), saveat=tspan[2]/10)
80+
@test all(2 .* sol[S] .== sol[S2])
81+
7482
# test save_positions is working
7583
jprob = JumpProblem(js2, dprob, Direct(), save_positions=(false,false))
7684
sol = solve(jprob, SSAStepper(), saveat=1.0)

0 commit comments

Comments
 (0)