Skip to content

Commit 3f96b01

Browse files
authored
Add preface for DiscreteSystem (#1429)
1 parent 72a3978 commit 3f96b01

File tree

3 files changed

+124
-4
lines changed

3 files changed

+124
-4
lines changed

src/systems/discrete_system/discrete_system.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,25 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
5656
"""
5757
structure::Any
5858
"""
59-
type: type of the system
59+
preface: inject assignment statements before the evaluation of the RHS function.
60+
"""
61+
preface::Any
62+
"""
63+
connector_type: type of the system
6064
"""
6165
connector_type::Any
6266
"""
6367
substitutions: substitutions generated by tearing.
6468
"""
6569
substitutions::Any
6670

67-
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, connector_type, substitutions=nothing; checks::Bool = true)
71+
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, preface, connector_type, substitutions=nothing; checks::Bool = true)
6872
if checks
6973
check_variables(dvs, iv)
7074
check_parameters(ps, iv)
7175
all_dimensionless([dvs;ps;iv;ctrls]) || check_units(discreteEqs)
7276
end
73-
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, connector_type, substitutions)
77+
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, preface, connector_type, substitutions)
7478
end
7579
end
7680

@@ -88,6 +92,7 @@ function DiscreteSystem(
8892
default_u0=Dict(),
8993
default_p=Dict(),
9094
defaults=_merge(Dict(default_u0), Dict(default_p)),
95+
preface=nothing,
9196
connector_type=nothing,
9297
kwargs...,
9398
)
@@ -113,7 +118,7 @@ function DiscreteSystem(
113118
if length(unique(sysnames)) != length(sysnames)
114119
throw(ArgumentError("System names must be unique."))
115120
end
116-
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, nothing, connector_type, kwargs...)
121+
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, nothing, preface, connector_type, kwargs...)
117122
end
118123

119124

test/discretesystem.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,60 @@ linearized_eqs = [
123123
RHS2 = RHS
124124
@unpack RHS = fol
125125
@test isequal(RHS, RHS2)
126+
127+
@testset "Preface tests" begin
128+
@parameters t
129+
using OrdinaryDiffEq
130+
using Symbolics
131+
using DiffEqBase: isinplace
132+
using ModelingToolkit
133+
using SymbolicUtils.Code
134+
using SymbolicUtils: Sym
135+
136+
c = [0]
137+
f = function f(c, d::Vector{Float64}, u::Vector{Float64}, p, t::Float64, dt::Float64)
138+
c .= [c[1] + 1]
139+
d .= randn(length(u))
140+
nothing
141+
end
142+
143+
dummy_identity(x, _) = x
144+
@register dummy_identity(x, y)
145+
146+
u0 = ones(5)
147+
p0 = Float64[]
148+
syms = [Symbol(:a, i) for i in 1:5]
149+
syms_p = Symbol[]
150+
dt = 0.1
151+
@assert isinplace(f, 6)
152+
wf = let c=c, buffer = similar(u0), u=similar(u0), p=similar(p0), dt=dt
153+
t -> (f(c, buffer, u, p, t, dt); buffer)
154+
end
155+
156+
num = hash(f) length(u0) length(p0)
157+
buffername = Symbol(:fmi_buffer_, num)
158+
159+
Δ = DiscreteUpdate(t; dt=dt)
160+
us = map(s->(@variables $s(t))[1], syms)
161+
ps = map(s->(@variables $s(t))[1], syms_p)
162+
buffer, = @variables $buffername[1:length(u0)]
163+
dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia
164+
165+
ss = Iterators.flatten((us, ps))
166+
vv = Iterators.flatten((u0, p0))
167+
defs = Dict{Any, Any}(s=>v for (s, v) in zip(ss, vv))
168+
169+
preface = [
170+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:u)), us))
171+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:p)), ps))
172+
Assignment(buffer, term(wf, t))
173+
]
174+
eqs = map(1:length(us)) do i
175+
Δ(us[i]) ~ dummy_identity(buffer[i], us[i])
176+
end
177+
178+
@named sys = DiscreteSystem(eqs, t, us, ps; defaults=defs, preface=preface)
179+
prob = DiscreteProblem(sys, [], (0.0, 1.0))
180+
sol = solve(prob, FunctionMap(); dt=dt)
181+
@test c[1]+1 == length(sol)
182+
end

test/odesystem.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,3 +535,61 @@ eqs = copy(eqs)
535535
eqs[end] = D(D(z)) ~ α*x - β*y
536536
@named sys = ODESystem(eqs, t, [x,y,z],[α,β])
537537
@test_throws Any ODEFunction(sys)
538+
539+
540+
@testset "Preface tests" begin
541+
using OrdinaryDiffEq
542+
using Symbolics
543+
using DiffEqBase: isinplace
544+
using ModelingToolkit
545+
using SymbolicUtils.Code
546+
using SymbolicUtils: Sym
547+
548+
c = [0]
549+
function f(c, du::AbstractVector{Float64}, u::AbstractVector{Float64}, p, t::Float64)
550+
c .= [c[1]+1]
551+
du .= randn(length(u))
552+
nothing
553+
end
554+
555+
dummy_identity(x, _) = x
556+
@register dummy_identity(x, y)
557+
558+
u0 = ones(5)
559+
p0 = Float64[]
560+
syms = [Symbol(:a, i) for i in 1:5]
561+
syms_p = Symbol[]
562+
563+
@assert isinplace(f, 5)
564+
wf = let buffer=similar(u0), u=similar(u0), p=similar(p0), c=c
565+
t -> (f(c, buffer, u, p, t); buffer)
566+
end
567+
568+
num = hash(f) length(u0) length(p0)
569+
buffername = Symbol(:fmi_buffer_, num)
570+
571+
D = Differential(t)
572+
us = map(s->(@variables $s(t))[1], syms)
573+
ps = map(s->(@variables $s(t))[1], syms_p)
574+
buffer, = @variables $buffername[1:length(u0)]
575+
dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia
576+
577+
ss = Iterators.flatten((us, ps))
578+
vv = Iterators.flatten((u0, p0))
579+
defs = Dict{Any, Any}(s=>v for (s, v) in zip(ss, vv))
580+
581+
preface = [
582+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:u)), us))
583+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:p)), ps))
584+
Assignment(buffer, term(wf, t))
585+
]
586+
eqs = map(1:length(us)) do i
587+
D(us[i]) ~ dummy_identity(buffer[i], us[i])
588+
end
589+
590+
@named sys = ODESystem(eqs, t, us, ps; defaults=defs, preface=preface)
591+
prob = ODEProblem(sys, [], (0.0, 1.0))
592+
sol = solve(prob, Euler(); dt=0.1)
593+
594+
@test c[1] == length(sol)
595+
end

0 commit comments

Comments
 (0)