Skip to content

Commit a6f781d

Browse files
committed
First pass at MutatingFunctionalAffect
1 parent c1b1af1 commit a6f781d

File tree

5 files changed

+215
-7
lines changed

5 files changed

+215
-7
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1010
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1111
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
12+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1213
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1314
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1415
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ using Reexport
5454
using RecursiveArrayTools
5555
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
5656
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
57+
import ComponentArrays
5758

5859
using RuntimeGeneratedFunctions
5960
using RuntimeGeneratedFunctions: drop_expr

src/systems/callbacks.jl

Lines changed: 139 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ function Base.hash(a::FunctionalAffect, s::UInt)
6060
hash(a.ctx, s)
6161
end
6262

63-
has_functional_affect(cb) = affects(cb) isa FunctionalAffect
64-
6563
namespace_affect(affect, s) = namespace_equation(affect, s)
6664
function namespace_affect(affect::FunctionalAffect, s)
6765
FunctionalAffect(func(affect),
@@ -73,6 +71,67 @@ function namespace_affect(affect::FunctionalAffect, s)
7371
context(affect))
7472
end
7573

74+
"""
75+
`MutatingFunctionalAffect` differs from `FunctionalAffect` in two key ways:
76+
* First, insetad of the `u` vector passed to `f` being a vector of indices into `integ.u` it's instead the result of evaluating `obs` at the current state, named as specified in `obs_syms`. This allows affects to easily access observed states and decouples affect inputs from the system structure.
77+
* Second, it abstracts the assignment back to system states away. Instead of writing `integ.u[u.myvar] = [whatever]`, you instead declare in `mod_params` that you want to modify `myvar` and then either (out of place) return a named tuple with `myvar` or (in place) modify the associated element in the ComponentArray that's given.
78+
Initially, we only support "flat" states in `modified`; these states will be marked as irreducible in the overarching system and they will simply be bulk assigned at mutation. In the future, this will be extended to perform a nonlinear solve to further decouple the affect from the system structure.
79+
"""
80+
@kwdef struct MutatingFunctionalAffect
81+
f::Any
82+
obs::Vector
83+
obs_syms::Vector{Symbol}
84+
modified::Vector
85+
mod_syms::Vector{Symbol}
86+
ctx::Any
87+
end
88+
89+
MutatingFunctionalAffect(f::Function;
90+
observed::NamedTuple = NamedTuple{()}(()),
91+
modified::NamedTuple = NamedTuple{()}(()),
92+
ctx=nothing) = MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), collect(values(modified)), collect(keys(modified)), ctx)
93+
MutatingFunctionalAffect(f::Function, observed::NamedTuple; modified::NamedTuple = NamedTuple{()}(()), ctx=nothing) =
94+
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
95+
MutatingFunctionalAffect(f::Function, observed::NamedTuple, modified::NamedTuple; ctx=nothing) =
96+
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
97+
MutatingFunctionalAffect(f::Function, observed::NamedTuple, modified::NamedTuple, ctx) =
98+
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
99+
100+
func(f::MutatingFunctionalAffect) = f.f
101+
context(a::MutatingFunctionalAffect) = a.ctx
102+
observed(a::MutatingFunctionalAffect) = a.obs
103+
observed_syms(a::MutatingFunctionalAffect) = a.obs_syms
104+
discretes(a::MutatingFunctionalAffect) = filter(ModelingToolkit.isparameter, a.modified)
105+
modified(a::MutatingFunctionalAffect) = a.modified
106+
modified_syms(a::MutatingFunctionalAffect) = a.mod_syms
107+
108+
function Base.:(==)(a1::MutatingFunctionalAffect, a2::MutatingFunctionalAffect)
109+
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
110+
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms)&& isequal(a1.ctx, a2.ctx)
111+
end
112+
113+
function Base.hash(a::MutatingFunctionalAffect, s::UInt)
114+
s = hash(a.f, s)
115+
s = hash(a.obs, s)
116+
s = hash(a.obs_syms, s)
117+
s = hash(a.modified, s)
118+
s = hash(a.mod_syms, s)
119+
hash(a.ctx, s)
120+
end
121+
122+
function namespace_affect(affect::MutatingFunctionalAffect, s)
123+
MutatingFunctionalAffect(func(affect),
124+
renamespace.((s,), observed(affect)),
125+
observed_syms(affect),
126+
renamespace.((s,), modified(affect)),
127+
modified_syms(affect),
128+
context(affect))
129+
end
130+
131+
function has_functional_affect(cb)
132+
(affects(cb) isa FunctionalAffect || affects(cb) isa MutatingFunctionalAffect)
133+
end
134+
76135
#################################### continuous events #####################################
77136

78137
const NULL_AFFECT = Equation[]
@@ -109,8 +168,8 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
109168
"""
110169
struct SymbolicContinuousCallback
111170
eqs::Vector{Equation}
112-
affect::Union{Vector{Equation}, FunctionalAffect}
113-
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
171+
affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
172+
affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing}
114173
rootfind::SciMLBase.RootfindOpt
115174
function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT,
116175
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
@@ -250,14 +309,15 @@ scalarize_affects(affects) = scalarize(affects)
250309
scalarize_affects(affects::Tuple) = FunctionalAffect(affects...)
251310
scalarize_affects(affects::NamedTuple) = FunctionalAffect(; affects...)
252311
scalarize_affects(affects::FunctionalAffect) = affects
312+
scalarize_affects(affects::MutatingFunctionalAffect) = affects
253313

254314
SymbolicDiscreteCallback(p::Pair) = SymbolicDiscreteCallback(p[1], p[2])
255315
SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough
256316

257317
function Base.show(io::IO, db::SymbolicDiscreteCallback)
258318
println(io, "condition: ", db.condition)
259319
println(io, "affects:")
260-
if db.affects isa FunctionalAffect
320+
if db.affects isa FunctionalAffect || db.affects isa MutatingFunctionalAffect
261321
# TODO
262322
println(io, " ", db.affects)
263323
else
@@ -749,6 +809,80 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs.
749809
end
750810
end
751811

812+
invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), vars(expr))
813+
function unassignable_variables(sys, expr)
814+
assignable_syms = vcat(unknowns(sys), parameters(sys))
815+
return filter(x -> !any(isequal(x), assignable_syms), vars(expr))
816+
end
817+
818+
function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...)
819+
#=
820+
Implementation sketch:
821+
generate observed function (oop), should save to a component array under obs_syms
822+
do the same stuff as the normal FA for pars_syms
823+
call the affect method - test if it's OOP or IP using applicable
824+
unpack and apply the resulting values
825+
=#
826+
obs_exprs = observed(affect)
827+
for oexpr in obs_exprs
828+
invalid_vars = invalid_variables(sys, oexpr)
829+
if length(invalid_vars) > 0
830+
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
831+
end
832+
end
833+
obs_syms = observed_syms(affect)
834+
obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size
835+
836+
mod_exprs = modified(affect)
837+
for mexpr in mod_exprs
838+
if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing
839+
error("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
840+
end
841+
invalid_vars = unassignable_variables(sys, mexpr)
842+
if length(invalid_vars) > 0
843+
error("Observed equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
844+
end
845+
end
846+
mod_syms = modified_syms(affect)
847+
_, mod_og_val_fun = build_explicit_observed_function(sys, mod_exprs; return_inplace=true)
848+
849+
# sanity checks done! now build the data and update function for observed values
850+
mkzero(sz) = if sz === () 0.0 else zeros(sz) end
851+
_, obs_fun = build_explicit_observed_function(sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); return_inplace=true)
852+
obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms..., )}(mkzero.(obs_size)))
853+
854+
# okay so now to generate the stuff to assign it back into the system
855+
# note that we reorder the componentarray to make the views coherent wrt the base array
856+
mod_pairs = mod_exprs .=> mod_syms
857+
mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs)
858+
mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs)
859+
_, mod_og_val_fun = build_explicit_observed_function(sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []); return_inplace=true)
860+
upd_params_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = []))
861+
upd_unk_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = []))
862+
863+
upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs); last.(mod_unk_pairs)]...,)}(
864+
[collect(mkzero(size(e)) for e in first.(mod_param_pairs));
865+
collect(mkzero(size(e)) for e in first.(mod_unk_pairs))]))
866+
upd_params_view = view(upd_component_array, last.(mod_param_pairs))
867+
upd_unks_view = view(upd_component_array, last.(mod_unk_pairs))
868+
let user_affect = func(affect), ctx = context(affect)
869+
function (integ)
870+
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
871+
mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t)
872+
873+
# update the observed values
874+
obs_fun(obs_component_array, integ.u, integ.p..., integ.t)
875+
876+
# let the user do their thing
877+
user_affect(upd_component_array, obs_component_array, integ, ctx)
878+
879+
# write the new values back to the integrator
880+
upd_params_fun(integ, upd_params_view)
881+
upd_unk_fun(integ, upd_unks_view)
882+
end
883+
end
884+
end
885+
752886
function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
753887
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
754888
end

src/systems/diffeqs/odesystem.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,31 @@ ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...
404404
"""
405405
$(SIGNATURES)
406406
407-
Build the observed function assuming the observed equations are all explicit,
408-
i.e. there are no cycles.
407+
Generates a function that computes the observed value(s) `ts` in the system `sys` assuming that there are no cycles in the equations.
408+
409+
The return value will be either:
410+
* a single function if the input is a scalar or if the input is a Vector but `return_inplace` is false
411+
* the out of place and in-place functions `(ip, oop)` if `return_inplace` is true and the input is a `Vector`
412+
413+
The function(s) will be:
414+
* `RuntimeGeneratedFunction`s by default,
415+
* A Julia `Expr` if `expression` is true,
416+
* A directly evaluated Julia function in the module `eval_module` if `eval_expression` is true
417+
418+
The signatures will be of the form `g(...)` with arguments:
419+
* `output` for in-place functions
420+
* `unknowns` if `params_only` is `false`
421+
* `inputs` if `inputs` is an array of symbolic inputs that should be available in `ts`
422+
* `p...` unconditionally; note that in the case of `MTKParameters` more than one parameters argument may be present, so it must be splatted
423+
* `t` if the system is time-dependent; for example `NonlinearSystem` will not have `t`
424+
For example, a function `g(op, unknowns, p, inputs, t)` will be the in-place function generated if `return_inplace` is true, `ts` is a vector, an array of inputs `inputs` is given, and `params_only` is false for a time-dependent system.
425+
426+
Options not otherwise specified are:
427+
* `output_type = Array` the type of the array generated by the out-of-place vector-valued function
428+
* `checkbounds = true` checks bounds if true when destructuring parameters
429+
* `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
430+
* `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist
431+
* `drop_expr` is deprecated.
409432
"""
410433
function build_explicit_observed_function(sys, ts;
411434
inputs = nothing,

test/symbolic_events.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,3 +887,52 @@ end
887887
@test sol[b] == [2.0, 5.0, 5.0]
888888
@test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
889889
end
890+
@testset "Heater" begin
891+
@variables temp(t)
892+
params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on::Bool=false
893+
eqs = [
894+
D(temp) ~ furnace_on * furnace_power - temp^2 * leakage
895+
]
896+
897+
furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold],
898+
ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i, c
899+
x.furnace_on = false
900+
end)
901+
furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold],
902+
ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i, c
903+
x.furnace_on = true
904+
end)
905+
906+
@named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
907+
ss = structural_simplify(sys)
908+
prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))
909+
sol = solve(prob, Tsit5(); dtmax=0.01)
910+
@test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49)
911+
end
912+
913+
@testset "Quadrature" begin
914+
@variables theta(t) omega(t)
915+
params = @parameters qA=0 qB=0
916+
eqs = [
917+
D(theta) ~ omega
918+
omega ~ sin(0.5*t)
919+
]
920+
qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(1000 * theta) ~ 0],
921+
ModelingToolkit.MutatingFunctionalAffect(modified=(; qA)) do x, o, i, c
922+
x.qA = 1
923+
end,
924+
affect_neg = ModelingToolkit.MutatingFunctionalAffect(modified=(; qA)) do x, o, i, c
925+
x.qA = 0
926+
end)
927+
qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(1000 * theta + π/2) ~ 0],
928+
ModelingToolkit.MutatingFunctionalAffect(modified=(; qB)) do x, o, i, c
929+
x.qB = 1
930+
end,
931+
affect_neg = ModelingToolkit.MutatingFunctionalAffect(modified=(; qB)) do x, o, i, c
932+
x.qB = 0
933+
end)
934+
@named sys = ODESystem(eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt])
935+
ss = structural_simplify(sys)
936+
prob = ODEProblem(ss, [theta => 0.0], (0.0, 1.0))
937+
sol = solve(prob, Tsit5(); dtmax=0.01)
938+
end

0 commit comments

Comments
 (0)