Skip to content

Commit b96cd2e

Browse files
committed
More sanity checking
1 parent 49d48b8 commit b96cd2e

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

src/systems/callbacks.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -818,10 +818,10 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs.
818818
end
819819
end
820820

821-
invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), vars(expr))
821+
invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init=[]))
822822
function unassignable_variables(sys, expr)
823823
assignable_syms = vcat(unknowns(sys), parameters(sys))
824-
return filter(x -> !any(isequal(x), assignable_syms), vars(expr))
824+
return filter(x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init=[]))
825825
end
826826

827827
function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...)
@@ -832,6 +832,21 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
832832
call the affect method - test if it's OOP or IP using applicable
833833
unpack and apply the resulting values
834834
=#
835+
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
836+
seen = Set{Symbol}()
837+
syms_dedup = []; exprs_dedup = []
838+
for (sym, exp) in Iterators.zip(syms, exprs)
839+
if !in(sym, seen)
840+
push!(syms_dedup, sym)
841+
push!(exprs_dedup, exp)
842+
push!(seen, sym)
843+
else
844+
@warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used."
845+
end
846+
end
847+
return (syms_dedup, exprs_dedup)
848+
end
849+
835850
obs_exprs = observed(affect)
836851
for oexpr in obs_exprs
837852
invalid_vars = invalid_variables(sys, oexpr)
@@ -840,6 +855,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
840855
end
841856
end
842857
obs_syms = observed_syms(affect)
858+
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)
843859
obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size
844860

845861
mod_exprs = modified(affect)
@@ -849,12 +865,18 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
849865
end
850866
invalid_vars = unassignable_variables(sys, mexpr)
851867
if length(invalid_vars) > 0
852-
error("Observed equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
868+
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
853869
end
854870
end
855871
mod_syms = modified_syms(affect)
872+
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
856873
_, mod_og_val_fun = build_explicit_observed_function(sys, mod_exprs; return_inplace=true)
857874

875+
overlapping_syms = intersect(mod_syms, obs_syms)
876+
if length(overlapping_syms) > 0
877+
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
878+
end
879+
858880
# sanity checks done! now build the data and update function for observed values
859881
mkzero(sz) = if sz === () 0.0 else zeros(sz) end
860882
_, obs_fun = build_explicit_observed_function(sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); return_inplace=true)

test/symbolic_events.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,46 @@ end
10211021
@test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49)
10221022
end
10231023

1024+
@testset "MutatingFunctionalAffect errors and warnings" begin
1025+
@variables temp(t)
1026+
params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on::Bool=false
1027+
eqs = [
1028+
D(temp) ~ furnace_on * furnace_power - temp^2 * leakage
1029+
]
1030+
1031+
furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold],
1032+
ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on)) do x, o, i, c
1033+
x.furnace_on = false
1034+
end)
1035+
@named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off])
1036+
ss = structural_simplify(sys)
1037+
@test_logs (:warn, "The symbols Any[:furnace_on] are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value.") prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))
1038+
1039+
@variables tempsq(t) # trivially eliminated
1040+
eqs = [
1041+
tempsq ~ temp^2
1042+
D(temp) ~ furnace_on * furnace_power - temp^2 * leakage
1043+
]
1044+
1045+
furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold],
1046+
ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on, tempsq), observed=(; furnace_on)) do x, o, i, c
1047+
x.furnace_on = false
1048+
end)
1049+
@named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
1050+
ss = structural_simplify(sys)
1051+
@test_throws "refers to missing variable(s)" prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))
1052+
1053+
1054+
@parameters not_actually_here
1055+
furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold],
1056+
ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on, not_actually_here)) do x, o, i, c
1057+
x.furnace_on = false
1058+
end)
1059+
@named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
1060+
ss = structural_simplify(sys)
1061+
@test_throws "refers to missing variable(s)" prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))
1062+
end
1063+
10241064
@testset "Quadrature" begin
10251065
@variables theta(t) omega(t)
10261066
params = @parameters qA=0 qB=0 hA=0 hB=0 cnt=0
@@ -1040,8 +1080,6 @@ end
10401080
return 0 # err is interpreted as no movement
10411081
end
10421082
end
1043-
# todo: warn about dups
1044-
# todo: warn if a variable appears in both observed and modified
10451083
qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0],
10461084
ModelingToolkit.MutatingFunctionalAffect((; qB), (; qA, hA, hB, cnt)) do x, o, i, c
10471085
x.hA = x.qA

0 commit comments

Comments
 (0)