Skip to content

Commit 1861727

Browse files
refactor: use process_SciMLProblem in jumpsystem.jl
1 parent 03e294f commit 1861727

File tree

4 files changed

+29
-37
lines changed

4 files changed

+29
-37
lines changed

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2914,7 +2914,7 @@ function Base.eltype(::Type{<:TreeIterator{ModelingToolkit.AbstractSystem}})
29142914
end
29152915

29162916
function check_array_equations_unknowns(eqs, dvs)
2917-
if any(eq -> Symbolics.isarraysymbolic(eq.lhs), eqs)
2917+
if any(eq -> eq isa Equation && Symbolics.isarraysymbolic(eq.lhs), eqs)
29182918
throw(ArgumentError("The system has array equations. Call `structural_simplify` to handle such equations or scalarize them manually."))
29192919
end
29202920
if any(x -> Symbolics.isarraysymbolic(x), dvs)

src/systems/jumps/jumpsystem.jl

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -348,20 +348,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
348348
if !iscomplete(sys)
349349
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
350350
end
351-
dvs = unknowns(sys)
352-
ps = parameters(sys)
353-
354-
defs = defaults(sys)
355-
defs = mergedefaults(defs, parammap, ps)
356-
defs = mergedefaults(defs, u0map, dvs)
357-
358-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
359-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
360-
p = MTKParameters(sys, parammap, u0map)
361-
else
362-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
363-
end
364-
351+
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
352+
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
365353
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
366354

367355
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
@@ -399,16 +387,9 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
399387
if !iscomplete(sys)
400388
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
401389
end
402-
dvs = unknowns(sys)
403-
ps = parameters(sys)
404-
defs = defaults(sys)
405390

406-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
407-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
408-
p = MTKParameters(sys, parammap, u0map)
409-
else
410-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
411-
end
391+
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
392+
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
412393
# identity function to make syms works
413394
quote
414395
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
@@ -454,19 +435,9 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
454435
if !iscomplete(sys)
455436
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
456437
end
457-
dvs = unknowns(sys)
458-
ps = parameters(sys)
459-
460-
defs = defaults(sys)
461-
defs = mergedefaults(defs, parammap, ps)
462-
defs = mergedefaults(defs, u0map, dvs)
463438

464-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
465-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
466-
p = MTKParameters(sys, parammap, u0map)
467-
else
468-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
469-
end
439+
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
440+
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
470441

471442
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
472443

src/systems/problem_utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,18 @@ function get_temporary_value(p)
332332
end
333333
end
334334

335+
"""
336+
$(TYPEDEF)
337+
338+
A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in
339+
case constructing a SciMLFunction is not required.
340+
"""
341+
struct EmptySciMLFunction end
342+
343+
function EmptySciMLFunction(args...; kwargs...)
344+
return nothing
345+
end
346+
335347
"""
336348
$(TYPEDSIGNATURES)
337349

src/utils.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,15 @@ function collect_constants!(constants, expr::Symbolic)
610610
end
611611
end
612612

613+
function collect_constants!(constants, expr::Union{ConstantRateJump, VariableRateJump})
614+
collect_constants!(constants, expr.rate)
615+
collect_constants!(constants, expr.affect!)
616+
end
617+
618+
function collect_constants!(constants, ::MassActionJump)
619+
return constants
620+
end
621+
613622
"""
614623
Replace symbolic constants with their literal values
615624
"""
@@ -667,7 +676,7 @@ end
667676

668677
function get_cmap(sys, exprs = nothing)
669678
#Inject substitutions for constants => values
670-
cs = collect_constants([get_eqs(sys); get_observed(sys)]) #ctrls? what else?
679+
cs = collect_constants([collect(get_eqs(sys)); get_observed(sys)]) #ctrls? what else?
671680
if !empty_substitutions(sys)
672681
cs = [cs; collect_constants(get_substitutions(sys).subs)]
673682
end

0 commit comments

Comments
 (0)