Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct SymbolicAffect
end

function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[],
discrete_parameters = Any[], kwargs...)
discrete_parameters = infer_discrete_parameters(affect), kwargs...)
if !(discrete_parameters isa AbstractVector)
discrete_parameters = Any[discrete_parameters]
elseif !(discrete_parameters isa Vector{Any})
Expand All @@ -31,6 +31,38 @@ function Symbolics.fast_substitute(aff::SymbolicAffect, rules)
map(substituter, aff.discrete_parameters))
end

# The discrete parameters (i.e. parameters that are updated in an event) can be inferred as
# those that occur in an affect equation *outside* of a `Pre(...)` operator.
function infer_discrete_parameters(affects)
discrete_parameters = Set()
for affect in affects
if affect isa Equation
infer_discrete_parameters!(discrete_parameters, affect.lhs)
infer_discrete_parameters!(discrete_parameters, affect.rhs)
elseif affect isa NamedTuple
haskey(affect, :modified) && union!(discrete_parameters, affect.modified)
end
end
return collect(discrete_parameters)
end

# Find all `expr`'s parameters that occur *outside* of a Pre(...) statement. Add these to `discrete_parameters`.
function infer_discrete_parameters!(discrete_parameters, expr)
expr_pre_removed = Symbolics.replacenode(expr, precall_to_1)
dynamic_symvars = Symbolics.get_variables(expr_pre_removed)
# Change this coming line to a Symbolic append type of thing.
union!(discrete_parameters, filter(ModelingToolkit.isparameter, dynamic_symvars))
end

# When updating vector variables, the affect side can be a vector.
function infer_discrete_parameters!(discrete_parameters, expr_vec::Vector)
foreach(expr -> infer_discrete_parameters!(discrete_parameters, expr), expr_vec)
end

# Functions for replacing a Pre-call with a `1.0` (removing its content from an expression).
is_precall(expr) = iscall(expr) ? operation(expr) isa Pre : false
precall_to_1(expr) = (is_precall(expr) ? 1.0 : expr)

struct AffectSystem
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
system::AbstractSystem
Expand Down
132 changes: 131 additions & 1 deletion test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ end
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob)
@test SciMLBase.successful_retcode(sol)
@test sol[x, end]≈1.0 atol=1e-6
@test sol[x, end]≈0.75 atol=1e-6
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an error in this test. The discrete_parameters argument was forgotten, so the event is not actually triggered. Unless the intention of the test is to check that everything works as (un)expected when the input is omitted (but the naming and test do not suggest that), this should be changed. With this PR, the event is now triggered, and the correct value 0.75 is achieved.

end

@testset "Symbolic affects are compiled in `complete`" begin
Expand Down Expand Up @@ -1440,3 +1440,133 @@ end
@mtkcompile sys = MWE()
@test_nowarn ODEProblem(sys, [], (0.0, 1.0))
end

@testset "Automatic inference of `discrete_parameters`" begin
# Basic case, checks for both types of events (in combination and isolation).
let
# Creates models with continuous, discrete, or both types of events
@variables X(t) Y(t)
@parameters p1(t) p2(t) d1 d2
eqs = [
D(X) ~ p1 - d1*X,
D(Y) ~ p2 - d2*Y
]
cevent = [[t ~ 1.0] => [p1 ~ 2 * Pre(p1)]]
devent = [[1.0] => [p2 ~ 2 * Pre(p2)]]
@mtkcompile sys_c = System(eqs, t; continuous_events = cevent)
@mtkcompile sys_d = System(eqs, t; discrete_events = devent)
@mtkcompile sys_cd = System(eqs, t; discrete_events = devent, continuous_events = cevent)

# Simulates them all. They should start at steady states (X = Y = 1).
# If event triggers, the variable should become 2 (after some wait).
sim_cond = [X => 1.0, Y => 1.0, p1 => 1.0, p2 => 1.0, d1 => 1.0, d2 => 1.0]
prob_c = ODEProblem(sys_c, sim_cond, 100.0)
prob_d = ODEProblem(sys_d, sim_cond, 100.0)
prob_cd = ODEProblem(sys_cd, sim_cond, 100.0)
sol_c = solve(prob_c, Rosenbrock23())
sol_d = solve(prob_d, Rosenbrock23())
sol_cd = solve(prob_cd, Rosenbrock23())
@test sol_c[X][end]≈2.0 atol=1e-3 rtol=1e-3
@test sol_c[Y][end]≈1.0 atol=1e-3 rtol=1e-3
@test sol_d[X][end]≈1.0 atol=1e-3 rtol=1e-3
@test sol_d[Y][end]≈2.0 atol=1e-3 rtol=1e-3
@test sol_cd[Y][end]≈2.0 atol=1e-3 rtol=1e-3
@test sol_cd[Y][end]≈2.0 atol=1e-3 rtol=1e-3
end

# Complicated and multiple events.
# All should trigger. Modified parameters (and only those) should get non-zero values.
let
# Declares the model. `k` parameters depend on time, but should not actually be updated.
us = @variables X1(t) X2(t) X3(t) X4(t) X5(t)
ps = @parameters p1(t) p2(t) p3(t) p4(t) p5(t) k1(t) k2(t) k3(t) k4(t) k5(t) d1 d2 d3 d4 d5
eqs = [
D(X1) ~ p1 + k1 - d1*X1,
D(X2) ~ p2 + k2 - d2*X2,
D(X3) ~ p3 + k3 - d3*X3,
D(X4) ~ p4 + k4 - d4*X4,
D(X5) ~ p4 + k4 - d4*X5
]
cevents = [[t + d1 + k1 ~
0.5] => [Pre(X1)*(p1 + 5 + Pre(X2)) + Pre(k1) ~ Pre(3X2 + k2)]]
devents = [
2.0 => [exp(p2 + Pre(p2)) ~ 5.0],
[1.0] => [(4 + Pre(k2) + Pre(k4) + Pre(k3))^3 + exp(1 + Pre(k3)) ~
(3 + p3 + Pre(k2))^3],
(t ==
1.5) => [
Pre(k2) + Pre(k3) ~ p4 * (2 + Pre(k1)) + 3,
Pre(p5) + 2 + 3Pre(k4) + Pre(p5) ~ exp(p5)
]
]
@mtkcompile sys = System(
eqs, t, us, ps; continuous_events = cevents, discrete_events = devents)

# Simulates system so that all events trigger.
sim_cond = [
X1 => 1.0, X2 => 1.0, X3 => 1.0, X4 => 1.0, X5 => 1.0,
p1 => 0.0, p2 => 0.0, p3 => 0.0, p4 => 0.0, p5 => 0.0,
k1 => 0.0, k2 => 0.0, k3 => 0.0, k4 => 0.0, k5 => 0.0,
d1 => 0.0, d2 => 0.0, d3 => 0.0, d4 => 0.0, d5 => 0.0
]
prob = ODEProblem(sys, sim_cond, 3.0)
sol = solve(prob, tstops = [1.5])

# Check that the correct parameters have been modified.
@test sol.ps[p1][end] != 0.0
@test sol.ps[p2][end] != 0.0
@test sol.ps[p3][end] != 0.0
@test sol.ps[p4][end] != 0.0
@test sol.ps[p5][end] != 0.0
@test sol.ps[k1] == sol.ps[k2] == sol.ps[k3] == sol.ps[k4] == sol.ps[k5] == 0.0
@test sol.ps[d1] == sol.ps[d2] == sol.ps[d3] == sol.ps[d4] == sol.ps[d5] == 0.0
end

# Checks that everything works for vector-valued parameters and variables.
let
# Creates the model using the macro.
@mtkmodel VectorParams begin
@parameters begin
k(t)[1:2] = [1, 1]
kup = 2.0
end
@variables begin
X(t)[1:2] = [4.0, 4.0]
end
@equations begin
D(X[1]) ~ -k[1]*X[1] + k[2]*X[2]
D(X[2]) ~ k[1]*X[1] - k[2]*X[2]
end
@continuous_events begin
(k[2] ~ t) => [k[1] ~ Pre(k[1] + kup)]
end
end
@mtkcompile model = VectorParams()

# Simulates the model. Checks that the correct values are achieved.
prob = ODEProblem(model, [], (0.0, 100.0))
sol = solve(prob, Rosenbrock23())
@test sol.ps[model.kup] == 2.0
@test sol.ps[model.k[1]] == 3.0
@test sol.ps[model.k[2]] == 1.0
@test sol[model.X[1]][end]≈2.0 atol=1e-8 rtol=1e-8
@test sol[model.X[2]][end]≈6.0 atol=1e-8 rtol=1e-8
end

# Checks for a functional affect.
let
# Creates model.
@variables X(t) = 5.0
@parameters p=2.0 d(t)=1.0
eqs = [D(X) ~ p - d * X]
affect!(mod, obs, ctx, integ) = return (; d = 2.0)
cevent = [t ~ 1.0] => (f = affect!, modified = (; d))
@mtkcompile sys = System(eqs, t; continuous_events = [cevent])

# Simualtes the model and checks that values is correct.
sol = solve(ODEProblem(sys, [], (0.0, 100.0)), Rosenbrock23())
@test sol[X][end]≈1.0 atol=1e-8 rtol=1e-8
@test sol.ps[p] == 2.0
@test sol.ps[d] == [1.0, 2.0]
end
end
Loading