Skip to content

Commit 911333d

Browse files
committed
fix: fix several tests
1 parent 8aeff55 commit 911333d

File tree

8 files changed

+59
-58
lines changed

8 files changed

+59
-58
lines changed

src/systems/callbacks.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,12 @@ struct SymbolicContinuousCallback <: AbstractCallback
217217
function SymbolicContinuousCallback(
218218
conditions::Union{Equation, Vector{Equation}},
219219
affect = nothing;
220-
discrete_parameters = Any[],
221220
affect_neg = affect,
222221
initialize = nothing,
223222
finalize = nothing,
224223
rootfind = SciMLBase.LeftRootFind,
225224
reinitializealg = nothing,
226-
iv = nothing,
227-
alg_eqs = Equation[])
225+
kwargs...)
228226
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
229227

230228
if isnothing(reinitializealg)
@@ -233,11 +231,12 @@ struct SymbolicContinuousCallback <: AbstractCallback
233231
reinitializealg = SciMLBase.CheckInit() :
234232
reinitializealg = SciMLBase.NoInit()
235233
end
234+
@show kwargs
236235

237-
new(conditions, make_affect(affect; iv, alg_eqs, discrete_parameters),
238-
make_affect(affect_neg; iv, alg_eqs, discrete_parameters),
239-
make_affect(initialize; iv, alg_eqs, discrete_parameters), make_affect(
240-
finalize; iv, alg_eqs, discrete_parameters),
236+
new(conditions, make_affect(affect; kwargs...),
237+
make_affect(affect_neg; kwargs...),
238+
make_affect(initialize; kwargs...), make_affect(
239+
finalize; kwargs...),
241240
rootfind, reinitializealg)
242241
end # Default affect to nothing
243242
end
@@ -247,16 +246,23 @@ function SymbolicContinuousCallback(p::Pair, args...; kwargs...)
247246
end
248247
SymbolicContinuousCallback(cb::SymbolicContinuousCallback, args...; kwargs...) = cb
249248
SymbolicContinuousCallback(cb::Nothing, args...; kwargs...) = nothing
249+
function SymbolicContinuousCallback(cb::Tuple, args...; kwargs...)
250+
if length(cb) == 2
251+
SymbolicContinuousCallback(cb[1]; kwargs..., cb[2]...)
252+
else
253+
error("Malformed tuple specifying callback. Should be a condition => affect pair, followed by a vector of kwargs.")
254+
end
255+
end
250256

251257
make_affect(affect::Nothing; kwargs...) = nothing
252258
make_affect(affect::Tuple; kwargs...) = FunctionalAffect(affect...)
253259
make_affect(affect::NamedTuple; kwargs...) = FunctionalAffect(; affect...)
254260
make_affect(affect::Affect; kwargs...) = affect
255261

256262
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
257-
iv = nothing, alg_eqs::Vector{Equation} = Equation[])
263+
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
258264
isempty(affect) && return nothing
259-
isempty(alg_eqs) &&
265+
isempty(alg_eqs) && warn_no_algebraic &&
260266
@warn "No algebraic equations were found for the callback defined by $(join(affect, ", ")). If the system has no algebraic equations, this can be disregarded. Otherwise pass in `alg_eqs` to the SymbolicContinuousCallback constructor."
261267
if isnothing(iv)
262268
iv = t_nounits
@@ -423,7 +429,7 @@ struct SymbolicDiscreteCallback <: AbstractCallback
423429
function SymbolicDiscreteCallback(
424430
condition, affect = nothing;
425431
initialize = nothing, finalize = nothing, iv = nothing,
426-
alg_eqs = Equation[], discrete_parameters = Any[], reinitializealg = nothing)
432+
reinitializealg = nothing, kwargs...)
427433
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
428434

429435
if isnothing(reinitializealg)
@@ -432,9 +438,9 @@ struct SymbolicDiscreteCallback <: AbstractCallback
432438
reinitializealg = SciMLBase.CheckInit() :
433439
reinitializealg = SciMLBase.NoInit()
434440
end
435-
new(c, make_affect(affect; iv, alg_eqs, discrete_parameters),
436-
make_affect(initialize; iv, alg_eqs, discrete_parameters),
437-
make_affect(finalize; iv, alg_eqs, discrete_parameters), reinitializealg)
441+
new(c, make_affect(affect; kwargs...),
442+
make_affect(initialize; kwargs...),
443+
make_affect(finalize; kwargs...), reinitializealg)
438444
end # Default affect to nothing
439445
end
440446

@@ -443,6 +449,13 @@ function SymbolicDiscreteCallback(p::Pair, args...; kwargs...)
443449
end
444450
SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback, args...; kwargs...) = cb
445451
SymbolicDiscreteCallback(cb::Nothing, args...; kwargs...) = nothing
452+
function SymbolicDiscreteCallback(cb::Tuple, args...; kwargs...)
453+
if length(cb) == 2
454+
SymbolicDiscreteCallback(cb[1]; cb[2]...)
455+
else
456+
error("Malformed tuple specifying callback. Should be a condition => affect pair, followed by a vector of kwargs.")
457+
end
458+
end
446459

447460
function is_timed_condition(condition::T) where {T}
448461
if T === Num
@@ -861,7 +874,7 @@ Compile an affect defined by a set of equations. Systems with algebraic equation
861874
function compile_equational_affect(
862875
aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false, kwargs...)
863876
if aff isa AbstractVector
864-
aff = make_affect(aff; iv = get_iv(sys))
877+
aff = make_affect(aff; iv = get_iv(sys), warn_no_algebraic = false)
865878
end
866879
affsys = system(aff)
867880
ps_to_update = discretes(aff)

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,9 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
321321
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
322322
deqs)
323323
cont_callbacks = to_cb_vector(SymbolicContinuousCallback.(
324-
continuous_events; alg_eqs, iv))
325-
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(discrete_events; alg_eqs, iv))
324+
continuous_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
325+
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(
326+
discrete_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
326327

327328
if is_dde === nothing
328329
is_dde = _check_if_dde(deqs, iv′, systems)

src/systems/diffeqs/sdesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
273273
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
274274
deqs)
275275
cont_callbacks = to_cb_vector(SymbolicContinuousCallback.(
276-
continuous_events; alg_eqs, iv))
277-
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(discrete_events; alg_eqs, iv))
276+
continuous_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
277+
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(
278+
discrete_events; alg_eqs = alg_eqs, iv = iv, warn_no_algebraic = false))
278279

279280
if is_dde === nothing
280281
is_dde = _check_if_dde(deqs, iv′, systems)

src/systems/jumps/jumpsystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,10 @@ function JumpSystem(eqs, iv, unknowns, ps;
212212
end
213213
end
214214

215-
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(discrete_events; iv))
216-
cont_callbacks = to_cb_vector(SymbolicContinuousCallback.(continuous_events; iv))
215+
disc_callbacks = to_cb_vector(SymbolicDiscreteCallback.(
216+
discrete_events; iv = iv, warn_no_algebraic = false))
217+
cont_callbacks = to_cb_vector(SymbolicContinuousCallback.(
218+
continuous_events; iv = iv, warn_no_algebraic = false))
217219

218220
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
219221
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,

src/systems/model_parsing.jl

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
127127

128128
sys = :($ODESystem($(flatten_equations)(equations), $iv, variables, parameters;
129129
name, description = $description, systems, gui_metadata = $gui_metadata,
130+
continuous_events = [$(c_evts...)], discrete_events = [$(d_evts...)],
130131
defaults))
131132

132133
if length(ext) == 0
@@ -138,25 +139,6 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
138139
isconnector && push!(exprs.args,
139140
:($Setfield.@set!(var"#___sys___".connector_type=$connector_type(var"#___sys___"))))
140141

141-
if !isempty(d_evts) || !isempty(c_evts)
142-
push!(exprs.args, :(alg_eqs = $(alg_equations)(var"#___sys___")))
143-
!isempty(d_evts) && begin
144-
d_exprs = [:($(SymbolicDiscreteCallback)(
145-
$(evt.args[1]); iv = $iv, alg_eqs, $(evt.args[2])...))
146-
for evt in d_evts]
147-
push!(exprs.args,
148-
:($Setfield.@set!(var"#___sys___".discrete_events=[$(d_exprs...)])))
149-
end
150-
151-
!isempty(c_evts) && begin
152-
c_exprs = [:($(SymbolicContinuousCallback)(
153-
$(evt.args[1]); iv = $iv, alg_eqs, $(evt.args[2])...))
154-
for evt in c_evts]
155-
push!(exprs.args,
156-
:($Setfield.@set!(var"#___sys___".continuous_events=[$(c_exprs...)])))
157-
end
158-
end
159-
160142
f = if length(where_types) == 0
161143
:($(Symbol(:__, name, :__))(; name, $(kwargs...)) = $exprs)
162144
else

test/extensions/ad.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ end
5959
@parameters a b[1:3] c(t) d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
6060
@named sys = ODESystem(
6161
Equation[], t, [], [a, b, c, d, e, f, g, h],
62-
continuous_events = [[a ~ 0] => [c ~ 0]])
62+
continuous_events = [ModelingToolkit.SymbolicContinuousCallback(
63+
[a ~ 0] => [c ~ 0], discrete_parameters = c)])
6364
sys = complete(sys)
6465

6566
ivs = Dict(c => 3a, b => ones(3), a => 1.0, d => 4, e => [5.0, 6.0, 7.0],

test/jumpsystem.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function getmean(jprob, Nsims; use_stepper = true)
8080
end
8181
m / Nsims
8282
end
83-
@btime m = $getmean($jprob, $Nsims)
83+
m = getmean(jprob, Nsims)
8484

8585
# test auto-alg selection works
8686
jprobb = JumpProblem(js2, dprob; save_positions = (false, false), rng)
@@ -248,7 +248,7 @@ end
248248
rate = k
249249
affect = [X ~ X - 1]
250250

251-
crj = ConstantRateJump(1.0, [X ~ X - 1])
251+
crj = ConstantRateJump(1.0, [X ~ Pre(X) - 1])
252252
js1 = complete(JumpSystem([crj], t, [X], [k]; name = :js1))
253253
js2 = complete(JumpSystem([crj], t, [X], []; name = :js2))
254254

@@ -275,18 +275,18 @@ dp4 = DiscreteProblem(js4, u0, tspan)
275275
@parameters k
276276
@variables X(t)
277277
rate = k
278-
affect = [X ~ X - 1]
278+
affect = [X ~ Pre(X) - 1]
279279

280-
j1 = ConstantRateJump(k, [X ~ X - 1])
280+
j1 = ConstantRateJump(k, [X ~ Pre(X) - 1])
281281
@test_nowarn @mtkbuild js1 = JumpSystem([j1], t, [X], [k])
282282

283283
# test correct autosolver is selected, which implies appropriate dep graphs are available
284284
let
285285
@parameters k
286286
@variables X(t)
287287
rate = k
288-
affect = [X ~ X - 1]
289-
j1 = ConstantRateJump(k, [X ~ X - 1])
288+
affect = [X ~ Pre(X) - 1]
289+
j1 = ConstantRateJump(k, [X ~ Pre(X) - 1])
290290

291291
Nv = [1, JumpProcesses.USE_DIRECT_THRESHOLD + 1, JumpProcesses.USE_RSSA_THRESHOLD + 1]
292292
algtypes = [Direct, RSSA, RSSACR]
@@ -305,7 +305,7 @@ let
305305
Random.seed!(rng, 1111)
306306
@variables A(t) B(t) C(t)
307307
@parameters k
308-
vrj = VariableRateJump(k * (sin(t) + 1), [A ~ A + 1, C ~ C + 2])
308+
vrj = VariableRateJump(k * (sin(t) + 1), [A ~ Pre(A) + 1, C ~ Pre(C) + 2])
309309
js = complete(JumpSystem([vrj], t, [A, C], [k]; name = :js, observed = [B ~ C * A]))
310310
oprob = ODEProblem(js, [A => 0, C => 0], (0.0, 10.0), [k => 1.0])
311311
jprob = JumpProblem(js, oprob, Direct(); rng)
@@ -346,9 +346,9 @@ end
346346
let
347347
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
348348
@parameters p1 p2 p3 p4 p5
349-
j1 = ConstantRateJump(p1, [x1 ~ x1 + 1])
349+
j1 = ConstantRateJump(p1, [x1 ~ Pre(x1) + 1])
350350
j2 = MassActionJump(p2, [x2 => 1], [x3 => -1])
351-
j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1])
351+
j3 = VariableRateJump(p3, [x3 ~ Pre(x3) + 1, x4 ~ Pre(x4) + 1])
352352
j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1])
353353
us = Set()
354354
ps = Set()
@@ -390,9 +390,9 @@ let
390390
p4 = DelayParentScope(p4)
391391
p5 = GlobalScope(p5)
392392

393-
j1 = ConstantRateJump(p1, [x1 ~ x1 + 1])
393+
j1 = ConstantRateJump(p1, [x1 ~ Pre(x1) + 1])
394394
j2 = MassActionJump(p2, [x2 => 1], [x3 => -1])
395-
j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1])
395+
j3 = VariableRateJump(p3, [x3 ~ Pre(x3) + 1, x4 ~ Pre(x4) + 1])
396396
j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1])
397397
@named js = JumpSystem([j1, j2, j3, j4], t, [x1, x2, x3, x4, x5], [p1, p2, p3, p4, p5])
398398

@@ -430,8 +430,8 @@ let
430430
Random.seed!(rng, seed)
431431
@variables X(t) Y(t)
432432
@parameters k1 k2
433-
vrj1 = VariableRateJump(k1 * X, [X ~ X - 1]; save_positions = (false, false))
434-
vrj2 = VariableRateJump(k1, [Y ~ Y + 1]; save_positions = (false, false))
433+
vrj1 = VariableRateJump(k1 * X, [X ~ Pre(X) - 1]; save_positions = (false, false))
434+
vrj2 = VariableRateJump(k1, [Y ~ Pre(Y) + 1]; save_positions = (false, false))
435435
eqs = [D(X) ~ k2, D(Y) ~ -k2 / 10 * Y]
436436
@named jsys = JumpSystem([vrj1, vrj2, eqs[1], eqs[2]], t, [X, Y], [k1, k2])
437437
jsys = complete(jsys)
@@ -472,8 +472,8 @@ let
472472
Random.seed!(rng, seed)
473473
@variables X(t) Y(t)
474474
@parameters α β
475-
vrj = VariableRateJump* X, [X ~ X - 1]; save_positions = (false, false))
476-
crj = ConstantRateJump* Y, [Y ~ Y - 1])
475+
vrj = VariableRateJump* X, [X ~ Pre(X) - 1]; save_positions = (false, false))
476+
crj = ConstantRateJump* Y, [Y ~ Pre(Y) - 1])
477477
maj = MassActionJump(α, [0 => 1], [Y => 1])
478478
eqs = [D(X) ~ α * (1 + Y)]
479479
@named jsys = JumpSystem([maj, crj, vrj, eqs[1]], t, [X, Y], [α, β])
@@ -540,8 +540,8 @@ end
540540
@variables X(t)
541541
rate1 = p
542542
rate2 = X * d
543-
affect1 = [X ~ X + 1]
544-
affect2 = [X ~ X - 1]
543+
affect1 = [X ~ Pre(X) + 1]
544+
affect2 = [X ~ Pre(X) - 1]
545545
j1 = ConstantRateJump(rate1, affect1)
546546
j2 = ConstantRateJump(rate2, affect2)
547547

test/mtkparameters.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ using JET
1010
@parameters a b c(t) d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
1111
@named sys = ODESystem(
1212
Equation[], t, [], [a, c, d, e, f, g, h], parameter_dependencies = [b ~ 2a],
13-
continuous_events = [[a ~ 0] => [c ~ 0]], defaults = Dict(a => 0.0))
13+
continuous_events = [ModelingToolkit.SymbolicContinuousCallback(
14+
[a ~ 0] => [c ~ 0], discrete_parameters = c)], defaults = Dict(a => 0.0))
1415
sys = complete(sys)
1516

1617
ivs = Dict(c => 3a, d => 4, e => [5.0, 6.0, 7.0],

0 commit comments

Comments
 (0)