diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c94166103b..1846885f83 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -11,7 +11,7 @@ struct SymbolicAffect end function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[], - discrete_parameters = Any[], kwargs...) + discrete_parameters = infer_discrete_parameters(affect), kwargs...) if !(discrete_parameters isa AbstractVector) discrete_parameters = Any[discrete_parameters] elseif !(discrete_parameters isa Vector{Any}) @@ -31,6 +31,38 @@ function Symbolics.fast_substitute(aff::SymbolicAffect, rules) map(substituter, aff.discrete_parameters)) end +# The discrete parameters (i.e. parameters that are updated in an event) can be inferred as +# those that occur in an affect equation *outside* of a `Pre(...)` operator. +function infer_discrete_parameters(affects) + discrete_parameters = Set() + for affect in affects + if affect isa Equation + infer_discrete_parameters!(discrete_parameters, affect.lhs) + infer_discrete_parameters!(discrete_parameters, affect.rhs) + elseif affect isa NamedTuple + haskey(affect, :modified) && union!(discrete_parameters, affect.modified) + end + end + return collect(discrete_parameters) +end + +# Find all `expr`'s parameters that occur *outside* of a Pre(...) statement. Add these to `discrete_parameters`. +function infer_discrete_parameters!(discrete_parameters, expr) + expr_pre_removed = Symbolics.replacenode(expr, precall_to_1) + dynamic_symvars = Symbolics.get_variables(expr_pre_removed) + # Change this coming line to a Symbolic append type of thing. + union!(discrete_parameters, filter(ModelingToolkit.isparameter, dynamic_symvars)) +end + +# When updating vector variables, the affect side can be a vector. +function infer_discrete_parameters!(discrete_parameters, expr_vec::Vector) + foreach(expr -> infer_discrete_parameters!(discrete_parameters, expr), expr_vec) +end + +# Functions for replacing a Pre-call with a `1.0` (removing its content from an expression). +is_precall(expr) = iscall(expr) ? operation(expr) isa Pre : false +precall_to_1(expr) = (is_precall(expr) ? 1.0 : expr) + struct AffectSystem """The internal implicit discrete system whose equations are solved to obtain values after the affect.""" system::AbstractSystem diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 5f37e524e1..f755794826 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1385,7 +1385,7 @@ end prob = ODEProblem(sys, [], (0.0, 1.0)) sol = solve(prob) @test SciMLBase.successful_retcode(sol) - @test sol[x, end]≈1.0 atol=1e-6 + @test sol[x, end]≈0.75 atol=1e-6 end @testset "Symbolic affects are compiled in `complete`" begin @@ -1440,3 +1440,133 @@ end @mtkcompile sys = MWE() @test_nowarn ODEProblem(sys, [], (0.0, 1.0)) end + +@testset "Automatic inference of `discrete_parameters`" begin + # Basic case, checks for both types of events (in combination and isolation). + let + # Creates models with continuous, discrete, or both types of events + @variables X(t) Y(t) + @parameters p1(t) p2(t) d1 d2 + eqs = [ + D(X) ~ p1 - d1*X, + D(Y) ~ p2 - d2*Y + ] + cevent = [[t ~ 1.0] => [p1 ~ 2 * Pre(p1)]] + devent = [[1.0] => [p2 ~ 2 * Pre(p2)]] + @mtkcompile sys_c = System(eqs, t; continuous_events = cevent) + @mtkcompile sys_d = System(eqs, t; discrete_events = devent) + @mtkcompile sys_cd = System(eqs, t; discrete_events = devent, continuous_events = cevent) + + # Simulates them all. They should start at steady states (X = Y = 1). + # If event triggers, the variable should become 2 (after some wait). + sim_cond = [X => 1.0, Y => 1.0, p1 => 1.0, p2 => 1.0, d1 => 1.0, d2 => 1.0] + prob_c = ODEProblem(sys_c, sim_cond, 100.0) + prob_d = ODEProblem(sys_d, sim_cond, 100.0) + prob_cd = ODEProblem(sys_cd, sim_cond, 100.0) + sol_c = solve(prob_c, Rosenbrock23()) + sol_d = solve(prob_d, Rosenbrock23()) + sol_cd = solve(prob_cd, Rosenbrock23()) + @test sol_c[X][end]≈2.0 atol=1e-3 rtol=1e-3 + @test sol_c[Y][end]≈1.0 atol=1e-3 rtol=1e-3 + @test sol_d[X][end]≈1.0 atol=1e-3 rtol=1e-3 + @test sol_d[Y][end]≈2.0 atol=1e-3 rtol=1e-3 + @test sol_cd[Y][end]≈2.0 atol=1e-3 rtol=1e-3 + @test sol_cd[Y][end]≈2.0 atol=1e-3 rtol=1e-3 + end + + # Complicated and multiple events. + # All should trigger. Modified parameters (and only those) should get non-zero values. + let + # Declares the model. `k` parameters depend on time, but should not actually be updated. + us = @variables X1(t) X2(t) X3(t) X4(t) X5(t) + ps = @parameters p1(t) p2(t) p3(t) p4(t) p5(t) k1(t) k2(t) k3(t) k4(t) k5(t) d1 d2 d3 d4 d5 + eqs = [ + D(X1) ~ p1 + k1 - d1*X1, + D(X2) ~ p2 + k2 - d2*X2, + D(X3) ~ p3 + k3 - d3*X3, + D(X4) ~ p4 + k4 - d4*X4, + D(X5) ~ p4 + k4 - d4*X5 + ] + cevents = [[t + d1 + k1 ~ + 0.5] => [Pre(X1)*(p1 + 5 + Pre(X2)) + Pre(k1) ~ Pre(3X2 + k2)]] + devents = [ + 2.0 => [exp(p2 + Pre(p2)) ~ 5.0], + [1.0] => [(4 + Pre(k2) + Pre(k4) + Pre(k3))^3 + exp(1 + Pre(k3)) ~ + (3 + p3 + Pre(k2))^3], + (t == + 1.5) => [ + Pre(k2) + Pre(k3) ~ p4 * (2 + Pre(k1)) + 3, + Pre(p5) + 2 + 3Pre(k4) + Pre(p5) ~ exp(p5) + ] + ] + @mtkcompile sys = System( + eqs, t, us, ps; continuous_events = cevents, discrete_events = devents) + + # Simulates system so that all events trigger. + sim_cond = [ + X1 => 1.0, X2 => 1.0, X3 => 1.0, X4 => 1.0, X5 => 1.0, + p1 => 0.0, p2 => 0.0, p3 => 0.0, p4 => 0.0, p5 => 0.0, + k1 => 0.0, k2 => 0.0, k3 => 0.0, k4 => 0.0, k5 => 0.0, + d1 => 0.0, d2 => 0.0, d3 => 0.0, d4 => 0.0, d5 => 0.0 + ] + prob = ODEProblem(sys, sim_cond, 3.0) + sol = solve(prob, tstops = [1.5]) + + # Check that the correct parameters have been modified. + @test sol.ps[p1][end] != 0.0 + @test sol.ps[p2][end] != 0.0 + @test sol.ps[p3][end] != 0.0 + @test sol.ps[p4][end] != 0.0 + @test sol.ps[p5][end] != 0.0 + @test sol.ps[k1] == sol.ps[k2] == sol.ps[k3] == sol.ps[k4] == sol.ps[k5] == 0.0 + @test sol.ps[d1] == sol.ps[d2] == sol.ps[d3] == sol.ps[d4] == sol.ps[d5] == 0.0 + end + + # Checks that everything works for vector-valued parameters and variables. + let + # Creates the model using the macro. + @mtkmodel VectorParams begin + @parameters begin + k(t)[1:2] = [1, 1] + kup = 2.0 + end + @variables begin + X(t)[1:2] = [4.0, 4.0] + end + @equations begin + D(X[1]) ~ -k[1]*X[1] + k[2]*X[2] + D(X[2]) ~ k[1]*X[1] - k[2]*X[2] + end + @continuous_events begin + (k[2] ~ t) => [k[1] ~ Pre(k[1] + kup)] + end + end + @mtkcompile model = VectorParams() + + # Simulates the model. Checks that the correct values are achieved. + prob = ODEProblem(model, [], (0.0, 100.0)) + sol = solve(prob, Rosenbrock23()) + @test sol.ps[model.kup] == 2.0 + @test sol.ps[model.k[1]] == 3.0 + @test sol.ps[model.k[2]] == 1.0 + @test sol[model.X[1]][end]≈2.0 atol=1e-8 rtol=1e-8 + @test sol[model.X[2]][end]≈6.0 atol=1e-8 rtol=1e-8 + end + + # Checks for a functional affect. + let + # Creates model. + @variables X(t) = 5.0 + @parameters p=2.0 d(t)=1.0 + eqs = [D(X) ~ p - d * X] + affect!(mod, obs, ctx, integ) = return (; d = 2.0) + cevent = [t ~ 1.0] => (f = affect!, modified = (; d)) + @mtkcompile sys = System(eqs, t; continuous_events = [cevent]) + + # Simulates the model and checks that values is correct. + sol = solve(ODEProblem(sys, [], (0.0, 100.0)), Rosenbrock23()) + @test sol[X][end]≈1.0 atol=1e-8 rtol=1e-8 + @test sol.ps[p] == 2.0 + @test sol.ps[d] == [1.0, 2.0] + end +end