From f944212d5a6708876efe1e68a8481ba5ea566906 Mon Sep 17 00:00:00 2001 From: Torkel Loman Date: Sun, 26 Oct 2025 13:45:20 +0000 Subject: [PATCH 1/6] infer `discrete_parameters` in all events --- src/systems/callbacks.jl | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c94166103b..9c6a577a0e 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -11,7 +11,7 @@ struct SymbolicAffect end function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[], - discrete_parameters = Any[], kwargs...) + discrete_parameters = infer_discrete_parameters(affect), kwargs...) if !(discrete_parameters isa AbstractVector) discrete_parameters = Any[discrete_parameters] elseif !(discrete_parameters isa Vector{Any}) @@ -31,6 +31,28 @@ function Symbolics.fast_substitute(aff::SymbolicAffect, rules) map(substituter, aff.discrete_parameters)) end +# The discrete parameters (i.e. parameters that are updated in an event) can be inferred as +# those that occur in an affect equation *outside* of a `Pre(...)` operator. +function infer_discrete_parameters(affects) + discrete_parameters = Set() + for affect in affects + infer_discrete_parameters!(discrete_parameters, affect.lhs) + infer_discrete_parameters!(discrete_parameters, affect.rhs) + end + return collect(discrete_parameters) +end + +# Find all `expr`'s parameters that occur *outside* of a Pre(...) statement. Add these to `discrete_parameters`. +function infer_discrete_parameters!(discrete_parameters, expr) + expr_pre_removed = Symbolics.replacenode(expr, precall_to_1) + dynamic_symvars = Symbolics.get_variables(expr_pre_removed) + # Change this coming line to a Symbolic append type of thing. + union!(discrete_parameters, filter(ModelingToolkit.isparameter, dynamic_symvars)) +end +# Functions for replacing a Pre-call with a `1.0` (removing its content from an expression). +is_precall(expr) = iscall(expr) ? operation(expr) isa Pre : false +precall_to_1(expr) = (is_precall(expr) ? 1.0 : expr) + struct AffectSystem """The internal implicit discrete system whose equations are solved to obtain values after the affect.""" system::AbstractSystem From 074cde5e4b420e2639e5ceb738e2638c16012393 Mon Sep 17 00:00:00 2001 From: Torkel Loman Date: Sun, 26 Oct 2025 15:25:43 +0000 Subject: [PATCH 2/6] add tests --- test/symbolic_events.jl | 80 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 5f37e524e1..7729098708 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1440,3 +1440,83 @@ end @mtkcompile sys = MWE() @test_nowarn ODEProblem(sys, [], (0.0, 1.0)) end + + +@testset "Automatic inference of `discrete_parameters`" begin + + # Basic case, checks for both types of events (in combination and isolation). + let + # Creates models with continuous, discrete, or both types of events + @variables X(t) Y(t) + @parameters p1(t) p2(t) d1 d2 + eqs = [ + D(X) ~ p1 - d1*X, + D(Y) ~ p2 - d2*Y + ] + cevent = [[t ~ 1.0] => [p1 ~ 2 * Pre(p1)]] + devent = [[1.0] => [p2 ~ 2 * Pre(p2)]] + @mtkcompile sys_c = System(eqs, t; continuous_events = cevent) + @mtkcompile sys_d = System(eqs, t; discrete_events = devent) + @mtkcompile sys_cd = System(eqs, t; discrete_events = devent, continuous_events = cevent) + + # Simulates them all. They should start at steady states (X = Y = 1). + # If event triggers, the variable should become 2 (after some wait). + sim_cond = [X => 1.0, Y => 1.0, p1 => 1.0, p2 => 1.0, d1 => 1.0, d2 => 1.0] + prob_c = ODEProblem(sys_c, sim_cond, 100.0) + prob_d = ODEProblem(sys_d, sim_cond, 100.0) + prob_cd = ODEProblem(sys_cd, sim_cond, 100.0) + sol_c = solve(prob_c, Rosenbrock23()) + sol_d = solve(prob_d, Rosenbrock23()) + sol_cd = solve(prob_cd, Rosenbrock23()) + @test sol_c[X][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 + @test sol_c[Y][end] ≈ 1.0 atol = 1e-3 rtol = 1e-3 + @test sol_d[X][end] ≈ 1.0 atol = 1e-3 rtol = 1e-3 + @test sol_d[Y][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 + @test sol_cd[Y][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 + @test sol_cd[Y][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 + end + + # Complicated and multiple events. + # All should trigger. Modified parameters (and only those) should get non-zero values. + let + # Declares the model. `k` parameters depend on time, but should not actually be updated. + us = @variables X1(t) X2(t) X3(t) X4(t) X5(t) + ps = @parameters p1(t) p2(t) p3(t) p4(t) p5(t) k1(t) k2(t) k3(t) k4(t) k5(t) d1 d2 d3 d4 d5 + eqs = [ + D(X1) ~ p1 + k1 - d1*X1, + D(X2) ~ p2 + k2 - d2*X2, + D(X3) ~ p3 + k3 - d3*X3, + D(X4) ~ p4 + k4 - d4*X4, + D(X5) ~ p4 + k4 - d4*X5, + ] + cevents = [[t + d1 + k1 ~ 0.5] => [Pre(X1)*(p1 + 5 + Pre(X2)) + Pre(k1) ~ Pre(3X2 + k2)]] + devents = [ + 2.0 => [exp(p2 + Pre(p2)) ~ 5.0], + [1.0] => [(4 + Pre(k2) + Pre(k4) + Pre(k3))^3 + exp(1 + Pre(k3)) ~ (3 + p3 + Pre(k2))^3], + (t == 1.5) => [ + Pre(k2) + Pre(k3) ~ p4 * (2 + Pre(k1)) + 3, + Pre(p5) + 2 + 3Pre(k4) + Pre(p5) ~ exp(p5) + ] + ] + @mtkcompile sys = System(eqs, t, us, ps; continuous_events = cevents, discrete_events = devents) + + # Simulates system so that all events trigger. + sim_cond = [ + X1 => 1.0, X2 => 1.0, X3 => 1.0, X4 => 1.0, X5 => 1.0, + p1 => 0.0, p2 => 0.0, p3 => 0.0, p4 => 0.0, p5 => 0.0, + k1 => 0.0, k2 => 0.0, k3 => 0.0, k4 => 0.0, k5 => 0.0, + d1 => 0.0, d2 => 0.0, d3 => 0.0, d4 => 0.0, d5 => 0.0 + ] + prob = ODEProblem(sys, sim_cond, 3.0) + sol = solve(prob, tstops = [1.5]) + + # Check that the correct parameters have been modified. + @test sol.ps[p1][end] != 0.0 + @test sol.ps[p2][end] != 0.0 + @test sol.ps[p3][end] != 0.0 + @test sol.ps[p4][end] != 0.0 + @test sol.ps[p5][end] != 0.0 + @test sol.ps[k1] == sol.ps[k2] == sol.ps[k3] == sol.ps[k4] == sol.ps[k5] == 0.0 + @test sol.ps[d1] == sol.ps[d2] == sol.ps[d3] == sol.ps[d4] == sol.ps[d5] == 0.0 + end +end \ No newline at end of file From 1afe8ca00a2297d219e5052d4ab4d428f6459f25 Mon Sep 17 00:00:00 2001 From: Torkel Loman Date: Sun, 26 Oct 2025 19:56:52 +0000 Subject: [PATCH 3/6] more tests --- src/systems/callbacks.jl | 14 +++++++++-- test/symbolic_events.jl | 50 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 9c6a577a0e..1846885f83 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -36,8 +36,12 @@ end function infer_discrete_parameters(affects) discrete_parameters = Set() for affect in affects - infer_discrete_parameters!(discrete_parameters, affect.lhs) - infer_discrete_parameters!(discrete_parameters, affect.rhs) + if affect isa Equation + infer_discrete_parameters!(discrete_parameters, affect.lhs) + infer_discrete_parameters!(discrete_parameters, affect.rhs) + elseif affect isa NamedTuple + haskey(affect, :modified) && union!(discrete_parameters, affect.modified) + end end return collect(discrete_parameters) end @@ -49,6 +53,12 @@ function infer_discrete_parameters!(discrete_parameters, expr) # Change this coming line to a Symbolic append type of thing. union!(discrete_parameters, filter(ModelingToolkit.isparameter, dynamic_symvars)) end + +# When updating vector variables, the affect side can be a vector. +function infer_discrete_parameters!(discrete_parameters, expr_vec::Vector) + foreach(expr -> infer_discrete_parameters!(discrete_parameters, expr), expr_vec) +end + # Functions for replacing a Pre-call with a `1.0` (removing its content from an expression). is_precall(expr) = iscall(expr) ? operation(expr) isa Pre : false precall_to_1(expr) = (is_precall(expr) ? 1.0 : expr) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 7729098708..cf88cc2448 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1441,9 +1441,7 @@ end @test_nowarn ODEProblem(sys, [], (0.0, 1.0)) end - @testset "Automatic inference of `discrete_parameters`" begin - # Basic case, checks for both types of events (in combination and isolation). let # Creates models with continuous, discrete, or both types of events @@ -1519,4 +1517,52 @@ end @test sol.ps[k1] == sol.ps[k2] == sol.ps[k3] == sol.ps[k4] == sol.ps[k5] == 0.0 @test sol.ps[d1] == sol.ps[d2] == sol.ps[d3] == sol.ps[d4] == sol.ps[d5] == 0.0 end + + # Checks that everything works for vector-valued parameters and variables. + let + # Creates the model using the macro. + @mtkmodel VectorParams begin + @parameters begin + k(t)[1:2] = [1, 1] + kup = 2.0 + end + @variables begin + X(t)[1:2] = [4.0, 4.0] + end + @equations begin + D(X[1]) ~ -k[1]*X[1] + k[2]*X[2] + D(X[2]) ~ k[1]*X[1] - k[2]*X[2] + end + @continuous_events begin + (k[2] ~ t) => [k[1] ~ Pre(k[1] + kup)] + end + end + @mtkcompile model = VectorParams() + + # Simulates the model. Checks that the correct values are achieved. + prob = ODEProblem(model, [], (0.0, 100.0)) + sol = solve(prob, Rosenbrock23()) + @test sol.ps[model.kup] == 2.0 + @test sol.ps[model.k[1]] == 3.0 + @test sol.ps[model.k[2]] == 1.0 + @test sol[model.X[1]][end] ≈ 2.0 atol = 1e-8 rtol = 1e-8 + @test sol[model.X[2]][end] ≈ 6.0 atol = 1e-8 rtol = 1e-8 + end + + # Checks for a functional affect. + let + # Creates model. + @variables X(t) = 5.0 + @parameters p = 2.0 d(t) = 1.0 + eqs = [D(X) ~ p - d * X] + affect!(mod, obs, ctx, integ) = return (; d = 2.0) + cevent = [t ~ 1.0] => (f = affect!, modified = (; d)) + @mtkcompile sys = System(eqs, t; continuous_events = [cevent]) + + # Simualtes the model and checks that values is correct. + sol = solve(ODEProblem(sys, [], (0.0, 100.0)), Rosenbrock23()) + @test sol[X][end] ≈ 1.0 atol = 1e-8 rtol = 1e-8 + @test sol.ps[p] == 2.0 + @test sol.ps[d] == [1.0, 2.0] + end end \ No newline at end of file From b14042327836a9c7f6352262ee291d6e9add05c9 Mon Sep 17 00:00:00 2001 From: Torkel Loman Date: Sun, 26 Oct 2025 19:59:08 +0000 Subject: [PATCH 4/6] fix previous erroneous test --- test/symbolic_events.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index cf88cc2448..bc8f34403a 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1385,7 +1385,7 @@ end prob = ODEProblem(sys, [], (0.0, 1.0)) sol = solve(prob) @test SciMLBase.successful_retcode(sol) - @test sol[x, end]≈1.0 atol=1e-6 + @test sol[x, end]≈0.75 atol=1e-6 end @testset "Symbolic affects are compiled in `complete`" begin From 5dc07f2aad37a195b66c43da4eb5bc16d9988818 Mon Sep 17 00:00:00 2001 From: Torkel Loman Date: Sun, 26 Oct 2025 20:37:10 +0000 Subject: [PATCH 5/6] formatting --- test/symbolic_events.jl | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index bc8f34403a..55de7e1122 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1466,12 +1466,12 @@ end sol_c = solve(prob_c, Rosenbrock23()) sol_d = solve(prob_d, Rosenbrock23()) sol_cd = solve(prob_cd, Rosenbrock23()) - @test sol_c[X][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 - @test sol_c[Y][end] ≈ 1.0 atol = 1e-3 rtol = 1e-3 - @test sol_d[X][end] ≈ 1.0 atol = 1e-3 rtol = 1e-3 - @test sol_d[Y][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 - @test sol_cd[Y][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 - @test sol_cd[Y][end] ≈ 2.0 atol = 1e-3 rtol = 1e-3 + @test sol_c[X][end]≈2.0 atol=1e-3 rtol=1e-3 + @test sol_c[Y][end]≈1.0 atol=1e-3 rtol=1e-3 + @test sol_d[X][end]≈1.0 atol=1e-3 rtol=1e-3 + @test sol_d[Y][end]≈2.0 atol=1e-3 rtol=1e-3 + @test sol_cd[Y][end]≈2.0 atol=1e-3 rtol=1e-3 + @test sol_cd[Y][end]≈2.0 atol=1e-3 rtol=1e-3 end # Complicated and multiple events. @@ -1485,18 +1485,22 @@ end D(X2) ~ p2 + k2 - d2*X2, D(X3) ~ p3 + k3 - d3*X3, D(X4) ~ p4 + k4 - d4*X4, - D(X5) ~ p4 + k4 - d4*X5, + D(X5) ~ p4 + k4 - d4*X5 ] - cevents = [[t + d1 + k1 ~ 0.5] => [Pre(X1)*(p1 + 5 + Pre(X2)) + Pre(k1) ~ Pre(3X2 + k2)]] + cevents = [[t + d1 + k1 ~ + 0.5] => [Pre(X1)*(p1 + 5 + Pre(X2)) + Pre(k1) ~ Pre(3X2 + k2)]] devents = [ 2.0 => [exp(p2 + Pre(p2)) ~ 5.0], - [1.0] => [(4 + Pre(k2) + Pre(k4) + Pre(k3))^3 + exp(1 + Pre(k3)) ~ (3 + p3 + Pre(k2))^3], - (t == 1.5) => [ - Pre(k2) + Pre(k3) ~ p4 * (2 + Pre(k1)) + 3, + [1.0] => [(4 + Pre(k2) + Pre(k4) + Pre(k3))^3 + exp(1 + Pre(k3)) ~ + (3 + p3 + Pre(k2))^3], + (t == + 1.5) => [ + Pre(k2) + Pre(k3) ~ p4 * (2 + Pre(k1)) + 3, Pre(p5) + 2 + 3Pre(k4) + Pre(p5) ~ exp(p5) ] ] - @mtkcompile sys = System(eqs, t, us, ps; continuous_events = cevents, discrete_events = devents) + @mtkcompile sys = System( + eqs, t, us, ps; continuous_events = cevents, discrete_events = devents) # Simulates system so that all events trigger. sim_cond = [ @@ -1542,18 +1546,18 @@ end # Simulates the model. Checks that the correct values are achieved. prob = ODEProblem(model, [], (0.0, 100.0)) sol = solve(prob, Rosenbrock23()) - @test sol.ps[model.kup] == 2.0 + @test sol.ps[model.kup] == 2.0 @test sol.ps[model.k[1]] == 3.0 @test sol.ps[model.k[2]] == 1.0 - @test sol[model.X[1]][end] ≈ 2.0 atol = 1e-8 rtol = 1e-8 - @test sol[model.X[2]][end] ≈ 6.0 atol = 1e-8 rtol = 1e-8 + @test sol[model.X[1]][end]≈2.0 atol=1e-8 rtol=1e-8 + @test sol[model.X[2]][end]≈6.0 atol=1e-8 rtol=1e-8 end # Checks for a functional affect. let # Creates model. @variables X(t) = 5.0 - @parameters p = 2.0 d(t) = 1.0 + @parameters p=2.0 d(t)=1.0 eqs = [D(X) ~ p - d * X] affect!(mod, obs, ctx, integ) = return (; d = 2.0) cevent = [t ~ 1.0] => (f = affect!, modified = (; d)) @@ -1561,8 +1565,8 @@ end # Simualtes the model and checks that values is correct. sol = solve(ODEProblem(sys, [], (0.0, 100.0)), Rosenbrock23()) - @test sol[X][end] ≈ 1.0 atol = 1e-8 rtol = 1e-8 + @test sol[X][end]≈1.0 atol=1e-8 rtol=1e-8 @test sol.ps[p] == 2.0 @test sol.ps[d] == [1.0, 2.0] end -end \ No newline at end of file +end From aa40b5cd47d388a0c81234c38b9fcacd5ce0d391 Mon Sep 17 00:00:00 2001 From: Torkel Loman Date: Sun, 26 Oct 2025 21:58:09 +0000 Subject: [PATCH 6/6] CI fix --- test/symbolic_events.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 55de7e1122..f755794826 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1563,7 +1563,7 @@ end cevent = [t ~ 1.0] => (f = affect!, modified = (; d)) @mtkcompile sys = System(eqs, t; continuous_events = [cevent]) - # Simualtes the model and checks that values is correct. + # Simulates the model and checks that values is correct. sol = solve(ODEProblem(sys, [], (0.0, 100.0)), Rosenbrock23()) @test sol[X][end]≈1.0 atol=1e-8 rtol=1e-8 @test sol.ps[p] == 2.0