Skip to content

Commit a8567fc

Browse files
authored
Merge branch 'master' into casadi2
2 parents 262d4dc + ebd3131 commit a8567fc

19 files changed

+221
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
4-
version = "9.76.0"
4+
version = "9.78.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/basics/InputOutput.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Now we can test the generated function `f` with random input and state values
7070
p = [1]
7171
x = [rand()]
7272
u = [rand()]
73-
@test f[1](x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u)
73+
@test f(x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u)
7474
```
7575

7676
## Generating an output function, ``g``

docs/src/tutorials/disturbance_modeling.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ disturbance_inputs = [ssys.d1, ssys.d2]
184184
P = ssys.system_model
185185
outputs = [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w]
186186
187-
(f_oop, f_ip), x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
187+
f, x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
188188
model_with_disturbance, inputs, disturbance_inputs; disturbance_argument = true)
189189
190190
g = ModelingToolkit.build_explicit_observed_function(
@@ -195,12 +195,12 @@ x0, _ = ModelingToolkit.get_u0_p(io_sys, op, op)
195195
p = MTKParameters(io_sys, op)
196196
u = zeros(1) # Control input
197197
w = zeros(length(disturbance_inputs)) # Disturbance input
198-
@test f_oop(x0, u, p, t, w) == zeros(5)
198+
@test f(x0, u, p, t, w) == zeros(5)
199199
@test g(x0, u, p, 0.0) == [0, 0, 0, 0]
200200
201201
# Non-zero disturbance inputs should result in non-zero state derivatives. We call `sort` since we do not generally know the order of the state variables
202202
w = [1.0, 2.0]
203-
@test sort(f_oop(x0, u, p, t, w)) == [0, 0, 0, 1, 2]
203+
@test sort(f(x0, u, p, t, w)) == [0, 0, 0, 1, 2]
204204
```
205205

206206
## Input signal library

ext/MTKInfiniteOptExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,6 @@ function _solve(prob::AbstractDynamicOptProblem, jump_solver, solver)
403403
DynamicOptSolution(model, sol, input_sol)
404404
end
405405

406-
407406
import InfiniteOpt: JuMP, GeneralVariableRef
408407

409408
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]

src/discretedomain.jl

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,13 @@ SymbolicUtils.isbinop(::Shift) = false
3939

4040
function (D::Shift)(x, allow_zero = false)
4141
!allow_zero && D.steps == 0 && return x
42-
Term{symtype(x)}(D, Any[x])
42+
if Symbolics.isarraysymbolic(x)
43+
Symbolics.array_term(D, x)
44+
else
45+
term(D, x)
46+
end
4347
end
44-
function (D::Shift)(x::Num, allow_zero = false)
48+
function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false)
4549
!allow_zero && D.steps == 0 && return x
4650
vt = value(x)
4751
if iscall(vt)
@@ -52,11 +56,11 @@ function (D::Shift)(x::Num, allow_zero = false)
5256
if D.t === nothing || isequal(D.t, op.t)
5357
arg = arguments(vt)[1]
5458
newsteps = D.steps + op.steps
55-
return Num(newsteps == 0 ? arg : Shift(D.t, newsteps)(arg))
59+
return wrap(newsteps == 0 ? arg : Shift(D.t, newsteps)(arg))
5660
end
5761
end
5862
end
59-
Num(D(vt, allow_zero))
63+
wrap(D(vt, allow_zero))
6064
end
6165
SymbolicUtils.promote_symtype(::Shift, t) = t
6266

@@ -202,11 +206,19 @@ function (xn::Num)(k::ShiftIndex)
202206
x = value(xn)
203207
# Verify that the independent variables of k and x match and that the expression doesn't have multiple variables
204208
vars = Symbolics.get_variables(x)
205-
length(vars) == 1 ||
209+
if length(vars) != 1
206210
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
207-
args = Symbolics.arguments(vars[]) # args should be one element vector with the t in x(t)
208-
length(args) == 1 ||
211+
end
212+
var = only(vars)
213+
if !iscall(var)
214+
throw(ArgumentError("Cannot shift time-independent variable $var"))
215+
end
216+
if operation(var) == getindex
217+
var = first(arguments(var))
218+
end
219+
if length(arguments(var)) != 1
209220
error("Cannot shift an expression with multiple independent variables $x.")
221+
end
210222

211223
# d, _ = propagate_time_domain(xn)
212224
# if d != clock # this is only required if the variable has another clock
@@ -220,6 +232,34 @@ function (xn::Num)(k::ShiftIndex)
220232
Shift(t, steps)(xn) # a shift of k steps
221233
end
222234

235+
function (xn::Symbolics.Arr)(k::ShiftIndex)
236+
@unpack clock, steps = k
237+
x = value(xn)
238+
# Verify that the independent variables of k and x match and that the expression doesn't have multiple variables
239+
vars = ModelingToolkit.vars(x)
240+
if length(vars) != 1
241+
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
242+
end
243+
var = only(vars)
244+
if !iscall(var)
245+
throw(ArgumentError("Cannot shift time-independent variable $var"))
246+
end
247+
if length(arguments(var)) != 1
248+
error("Cannot shift an expression with multiple independent variables $x.")
249+
end
250+
251+
# d, _ = propagate_time_domain(xn)
252+
# if d != clock # this is only required if the variable has another clock
253+
# xn = Sample(t, clock)(xn)
254+
# end
255+
# QUESTION: should we return a variable with time domain set to k.clock?
256+
xn = wrap(setmetadata(unwrap(xn), VariableTimeDomain, k.clock))
257+
if steps == 0
258+
return xn # x(k) needs no shift operator if the step of k is 0
259+
end
260+
Shift(t, steps)(xn) # a shift of k steps
261+
end
262+
223263
Base.:+(k::ShiftIndex, i::Int) = ShiftIndex(k.clock, k.steps + i)
224264
Base.:-(k::ShiftIndex, i::Int) = k + (-i)
225265

src/inputoutput.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,17 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
160160
# Build control function
161161

162162
"""
163-
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
163+
f, x_sym, p_sym, io_sys = generate_control_function(
164164
sys::AbstractODESystem,
165165
inputs = unbound_inputs(sys),
166166
disturbance_inputs = nothing;
167167
implicit_dae = false,
168168
simplify = false,
169169
)
170170
171-
For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate a function with additional input argument `in`
171+
For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate a function with additional input argument `u`
172172
173+
The returned function `f` can be called in the out-of-place or in-place form:
173174
```
174175
f_oop : (x,u,p,t) -> rhs
175176
f_ip : (xout,x,u,p,t) -> nothing
@@ -190,7 +191,7 @@ f, x_sym, ps = generate_control_function(sys, expression=Val{false}, simplify=fa
190191
p = varmap_to_vars(defaults(sys), ps)
191192
x = varmap_to_vars(defaults(sys), x_sym)
192193
t = 0
193-
f[1](x, inputs, p, t)
194+
f(x, inputs, p, t)
194195
```
195196
"""
196197
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys),

src/structural_transformation/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,13 @@ function shift2term(var)
471471
op = operation(var)
472472
iv = op.t
473473
arg = only(arguments(var))
474+
if operation(arg) === getindex
475+
idxs = arguments(arg)[2:end]
476+
newvar = shift2term(op(first(arguments(arg))))[idxs...]
477+
unshifted = ModelingToolkit.getunshifted(newvar)[idxs...]
478+
newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, unshifted)
479+
return newvar
480+
end
474481
is_lowered = !isnothing(ModelingToolkit.getunshifted(arg))
475482

476483
backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps

src/systems/abstractsystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,6 +2116,8 @@ function n_expanded_connection_equations(sys::AbstractSystem)
21162116
nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs
21172117
end
21182118

2119+
Base.show(io::IO, sys::AbstractSystem; kws...) = show(io, MIME"text/plain"(), sys; kws...)
2120+
21192121
function Base.show(
21202122
io::IO, mime::MIME"text/plain", sys::AbstractSystem; hint = true, bold = true)
21212123
limit = get(io, :limit, false) # if output should be limited,

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ function SciMLBase.DiscreteProblem(
319319
iv = get_iv(sys)
320320

321321
u0map = to_varmap(u0map, dvs)
322+
scalarize_varmap!(u0map)
322323
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
323324
f, u0, p = process_SciMLProblem(
324325
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module, build_initializeprob = false)

src/systems/imperative_affect.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,22 @@ end
101101

102102
namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s)
103103
function namespace_affect(affect::ImperativeAffect, s)
104+
rmn = []
105+
for modded in modified(affect)
106+
if symbolic_type(modded) == NotSymbolic() && modded isa AbstractArray
107+
res = []
108+
for m in modded
109+
push!(res, renamespace(s, m))
110+
end
111+
push!(rmn, res)
112+
else
113+
push!(rmn, renamespace(s, modded))
114+
end
115+
end
104116
ImperativeAffect(func(affect),
105117
namespace_expr.(observed(affect), (s,)),
106118
observed_syms(affect),
107-
renamespace.((s,), modified(affect)),
119+
rmn,
108120
modified_syms(affect),
109121
context(affect),
110122
affect.skip_checks)

0 commit comments

Comments
 (0)