Skip to content

Commit 5a832ec

Browse files
committed
Merge branch 'master' into kf/state_selection
2 parents 36aa04a + 2f3a382 commit 5a832ec

File tree

4 files changed

+130
-6
lines changed

4 files changed

+130
-6
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "8.3.1"
4+
version = "8.3.2"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -72,7 +72,7 @@ SciMLBase = "1.3"
7272
Setfield = "0.7, 0.8"
7373
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7474
StaticArrays = "0.10, 0.11, 0.12, 1.0"
75-
SymbolicUtils = "0.18, 0.19"
75+
SymbolicUtils = "0.19"
7676
Symbolics = "4.0.0"
7777
UnPack = "0.1, 1.0"
7878
Unitful = "1.1"

src/systems/discrete_system/discrete_system.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,29 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
5252
"""
5353
defaults::Dict
5454
"""
55-
type: type of the system
55+
structure: structural information of the system
56+
"""
57+
structure::Any
58+
"""
59+
preface: inject assignment statements before the evaluation of the RHS function.
60+
"""
61+
preface::Any
62+
"""
63+
connector_type: type of the system
5664
"""
5765
connector_type::Any
5866
"""
5967
substitutions: substitutions generated by tearing.
6068
"""
6169
substitutions::Any
6270

63-
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connector_type, substitutions=nothing; checks::Bool = true)
71+
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, preface, connector_type, substitutions=nothing; checks::Bool = true)
6472
if checks
6573
check_variables(dvs, iv)
6674
check_parameters(ps, iv)
6775
all_dimensionless([dvs;ps;iv;ctrls]) || check_units(discreteEqs)
6876
end
69-
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, connector_type, substitutions)
77+
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults, structure, preface, connector_type, substitutions)
7078
end
7179
end
7280

@@ -84,6 +92,7 @@ function DiscreteSystem(
8492
default_u0=Dict(),
8593
default_p=Dict(),
8694
defaults=_merge(Dict(default_u0), Dict(default_p)),
95+
preface=nothing,
8796
connector_type=nothing,
8897
kwargs...,
8998
)
@@ -109,7 +118,7 @@ function DiscreteSystem(
109118
if length(unique(sysnames)) != length(sysnames)
110119
throw(ArgumentError("System names must be unique."))
111120
end
112-
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, connector_type, kwargs...)
121+
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, defaults, nothing, preface, connector_type, kwargs...)
113122
end
114123

115124

test/discretesystem.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,60 @@ linearized_eqs = [
123123
RHS2 = RHS
124124
@unpack RHS = fol
125125
@test isequal(RHS, RHS2)
126+
127+
@testset "Preface tests" begin
128+
@parameters t
129+
using OrdinaryDiffEq
130+
using Symbolics
131+
using DiffEqBase: isinplace
132+
using ModelingToolkit
133+
using SymbolicUtils.Code
134+
using SymbolicUtils: Sym
135+
136+
c = [0]
137+
f = function f(c, d::Vector{Float64}, u::Vector{Float64}, p, t::Float64, dt::Float64)
138+
c .= [c[1] + 1]
139+
d .= randn(length(u))
140+
nothing
141+
end
142+
143+
dummy_identity(x, _) = x
144+
@register dummy_identity(x, y)
145+
146+
u0 = ones(5)
147+
p0 = Float64[]
148+
syms = [Symbol(:a, i) for i in 1:5]
149+
syms_p = Symbol[]
150+
dt = 0.1
151+
@assert isinplace(f, 6)
152+
wf = let c=c, buffer = similar(u0), u=similar(u0), p=similar(p0), dt=dt
153+
t -> (f(c, buffer, u, p, t, dt); buffer)
154+
end
155+
156+
num = hash(f) length(u0) length(p0)
157+
buffername = Symbol(:fmi_buffer_, num)
158+
159+
Δ = DiscreteUpdate(t; dt=dt)
160+
us = map(s->(@variables $s(t))[1], syms)
161+
ps = map(s->(@variables $s(t))[1], syms_p)
162+
buffer, = @variables $buffername[1:length(u0)]
163+
dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia
164+
165+
ss = Iterators.flatten((us, ps))
166+
vv = Iterators.flatten((u0, p0))
167+
defs = Dict{Any, Any}(s=>v for (s, v) in zip(ss, vv))
168+
169+
preface = [
170+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:u)), us))
171+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:p)), ps))
172+
Assignment(buffer, term(wf, t))
173+
]
174+
eqs = map(1:length(us)) do i
175+
Δ(us[i]) ~ dummy_identity(buffer[i], us[i])
176+
end
177+
178+
@named sys = DiscreteSystem(eqs, t, us, ps; defaults=defs, preface=preface)
179+
prob = DiscreteProblem(sys, [], (0.0, 1.0))
180+
sol = solve(prob, FunctionMap(); dt=dt)
181+
@test c[1]+1 == length(sol)
182+
end

test/odesystem.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,61 @@ eqs = copy(eqs)
534534
eqs[end] = D(D(z)) ~ α*x - β*y
535535
@named sys = ODESystem(eqs, t, [x,y,z],[α,β])
536536
@test_throws Any ODEFunction(sys)
537+
538+
539+
@testset "Preface tests" begin
540+
using OrdinaryDiffEq
541+
using Symbolics
542+
using DiffEqBase: isinplace
543+
using ModelingToolkit
544+
using SymbolicUtils.Code
545+
using SymbolicUtils: Sym
546+
547+
c = [0]
548+
function f(c, du::AbstractVector{Float64}, u::AbstractVector{Float64}, p, t::Float64)
549+
c .= [c[1]+1]
550+
du .= randn(length(u))
551+
nothing
552+
end
553+
554+
dummy_identity(x, _) = x
555+
@register dummy_identity(x, y)
556+
557+
u0 = ones(5)
558+
p0 = Float64[]
559+
syms = [Symbol(:a, i) for i in 1:5]
560+
syms_p = Symbol[]
561+
562+
@assert isinplace(f, 5)
563+
wf = let buffer=similar(u0), u=similar(u0), p=similar(p0), c=c
564+
t -> (f(c, buffer, u, p, t); buffer)
565+
end
566+
567+
num = hash(f) length(u0) length(p0)
568+
buffername = Symbol(:fmi_buffer_, num)
569+
570+
D = Differential(t)
571+
us = map(s->(@variables $s(t))[1], syms)
572+
ps = map(s->(@variables $s(t))[1], syms_p)
573+
buffer, = @variables $buffername[1:length(u0)]
574+
dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia
575+
576+
ss = Iterators.flatten((us, ps))
577+
vv = Iterators.flatten((u0, p0))
578+
defs = Dict{Any, Any}(s=>v for (s, v) in zip(ss, vv))
579+
580+
preface = [
581+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:u)), us))
582+
Assignment(dummy_var, SetArray(true, term(getfield, wf, Meta.quot(:p)), ps))
583+
Assignment(buffer, term(wf, t))
584+
]
585+
eqs = map(1:length(us)) do i
586+
D(us[i]) ~ dummy_identity(buffer[i], us[i])
587+
end
588+
589+
@named sys = ODESystem(eqs, t, us, ps; defaults=defs, preface=preface)
590+
prob = ODEProblem(sys, [], (0.0, 1.0))
591+
sol = solve(prob, Euler(); dt=0.1)
592+
593+
@test c[1] == length(sol)
594+
end

0 commit comments

Comments
 (0)