Skip to content

Commit 728d28d

Browse files
committed
Fix after rebase
1 parent 426e6bb commit 728d28d

27 files changed

+222
-55
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
- {user: SciML, repo: MethodOfLines.jl, group: DAE}
4040
- {user: SciML, repo: ModelingToolkitNeuralNets.jl, group: All}
4141

42-
- {user: Neuroblox, repo: Neuroblox.jl, group: NNPDE}
42+
- {user: Neuroblox, repo: Neuroblox.jl, group: All}
4343
steps:
4444
- uses: actions/checkout@v4
4545
- uses: julia-actions/setup-julia@v1

docs/src/basics/InputOutput.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Now we can test the generated function `f` with random input and state values
7070
p = [1]
7171
x = [rand()]
7272
u = [rand()]
73-
@test f[1](x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u)
73+
@test f(x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u)
7474
```
7575

7676
## Generating an output function, ``g``

docs/src/tutorials/disturbance_modeling.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ disturbance_inputs = [ssys.d1, ssys.d2]
184184
P = ssys.system_model
185185
outputs = [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w]
186186
187-
(f_oop, f_ip), x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
187+
f, x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
188188
model_with_disturbance, inputs, disturbance_inputs; disturbance_argument = true)
189189
190190
g = ModelingToolkit.build_explicit_observed_function(
@@ -195,12 +195,12 @@ x0, _ = ModelingToolkit.get_u0_p(io_sys, op, op)
195195
p = MTKParameters(io_sys, op)
196196
u = zeros(1) # Control input
197197
w = zeros(length(disturbance_inputs)) # Disturbance input
198-
@test f_oop(x0, u, p, t, w) == zeros(5)
198+
@test f(x0, u, p, t, w) == zeros(5)
199199
@test g(x0, u, p, 0.0) == [0, 0, 0, 0]
200200
201201
# Non-zero disturbance inputs should result in non-zero state derivatives. We call `sort` since we do not generally know the order of the state variables
202202
w = [1.0, 2.0]
203-
@test sort(f_oop(x0, u, p, t, w)) == [0, 0, 0, 1, 2]
203+
@test sort(f(x0, u, p, t, w)) == [0, 0, 0, 1, 2]
204204
```
205205

206206
## Input signal library

ext/MTKInfiniteOptExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ function constructDefault(T::Type = Float64)
340340
A = map(T, A)
341341
α = map(T, α)
342342
c = map(T, c)
343-
343+
344344
DiffEqBase.ImplicitRKTableau(A, c, α, 5)
345345
end
346346

@@ -422,7 +422,6 @@ function _solve(prob::AbstractDynamicOptProblem, jump_solver, solver)
422422
DynamicOptSolution(model, sol, input_sol)
423423
end
424424

425-
426425
import InfiniteOpt: JuMP, GeneralVariableRef
427426

428427
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]

src/discretedomain.jl

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,13 @@ SymbolicUtils.isbinop(::Shift) = false
3939

4040
function (D::Shift)(x, allow_zero = false)
4141
!allow_zero && D.steps == 0 && return x
42-
Term{symtype(x)}(D, Any[x])
42+
if Symbolics.isarraysymbolic(x)
43+
Symbolics.array_term(D, x)
44+
else
45+
term(D, x)
46+
end
4347
end
44-
function (D::Shift)(x::Num, allow_zero = false)
48+
function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false)
4549
!allow_zero && D.steps == 0 && return x
4650
vt = value(x)
4751
if iscall(vt)
@@ -52,11 +56,11 @@ function (D::Shift)(x::Num, allow_zero = false)
5256
if D.t === nothing || isequal(D.t, op.t)
5357
arg = arguments(vt)[1]
5458
newsteps = D.steps + op.steps
55-
return Num(newsteps == 0 ? arg : Shift(D.t, newsteps)(arg))
59+
return wrap(newsteps == 0 ? arg : Shift(D.t, newsteps)(arg))
5660
end
5761
end
5862
end
59-
Num(D(vt, allow_zero))
63+
wrap(D(vt, allow_zero))
6064
end
6165
SymbolicUtils.promote_symtype(::Shift, t) = t
6266

@@ -202,11 +206,19 @@ function (xn::Num)(k::ShiftIndex)
202206
x = value(xn)
203207
# Verify that the independent variables of k and x match and that the expression doesn't have multiple variables
204208
vars = Symbolics.get_variables(x)
205-
length(vars) == 1 ||
209+
if length(vars) != 1
206210
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
207-
args = Symbolics.arguments(vars[]) # args should be one element vector with the t in x(t)
208-
length(args) == 1 ||
211+
end
212+
var = only(vars)
213+
if !iscall(var)
214+
throw(ArgumentError("Cannot shift time-independent variable $var"))
215+
end
216+
if operation(var) == getindex
217+
var = first(arguments(var))
218+
end
219+
if length(arguments(var)) != 1
209220
error("Cannot shift an expression with multiple independent variables $x.")
221+
end
210222

211223
# d, _ = propagate_time_domain(xn)
212224
# if d != clock # this is only required if the variable has another clock
@@ -220,6 +232,34 @@ function (xn::Num)(k::ShiftIndex)
220232
Shift(t, steps)(xn) # a shift of k steps
221233
end
222234

235+
function (xn::Symbolics.Arr)(k::ShiftIndex)
236+
@unpack clock, steps = k
237+
x = value(xn)
238+
# Verify that the independent variables of k and x match and that the expression doesn't have multiple variables
239+
vars = ModelingToolkit.vars(x)
240+
if length(vars) != 1
241+
error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.")
242+
end
243+
var = only(vars)
244+
if !iscall(var)
245+
throw(ArgumentError("Cannot shift time-independent variable $var"))
246+
end
247+
if length(arguments(var)) != 1
248+
error("Cannot shift an expression with multiple independent variables $x.")
249+
end
250+
251+
# d, _ = propagate_time_domain(xn)
252+
# if d != clock # this is only required if the variable has another clock
253+
# xn = Sample(t, clock)(xn)
254+
# end
255+
# QUESTION: should we return a variable with time domain set to k.clock?
256+
xn = wrap(setmetadata(unwrap(xn), VariableTimeDomain, k.clock))
257+
if steps == 0
258+
return xn # x(k) needs no shift operator if the step of k is 0
259+
end
260+
Shift(t, steps)(xn) # a shift of k steps
261+
end
262+
223263
Base.:+(k::ShiftIndex, i::Int) = ShiftIndex(k.clock, k.steps + i)
224264
Base.:-(k::ShiftIndex, i::Int) = k + (-i)
225265

src/inputoutput.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,17 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
160160
# Build control function
161161

162162
"""
163-
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
163+
f, x_sym, p_sym, io_sys = generate_control_function(
164164
sys::AbstractODESystem,
165165
inputs = unbound_inputs(sys),
166166
disturbance_inputs = disturbances(sys);
167167
implicit_dae = false,
168168
simplify = false,
169169
)
170170
171-
For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate a function with additional input argument `in`
171+
For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate functions with additional input argument `u`
172172
173+
The returned function `f` can be called in the out-of-place or in-place form:
173174
```
174175
f_oop : (x,u,p,t) -> rhs
175176
f_ip : (xout,x,u,p,t) -> nothing
@@ -187,7 +188,7 @@ f, x_sym, ps = generate_control_function(sys, expression=Val{false}, simplify=fa
187188
p = varmap_to_vars(defaults(sys), ps)
188189
x = varmap_to_vars(defaults(sys), x_sym)
189190
t = 0
190-
f[1](x, inputs, p, t)
191+
f(x, inputs, p, t)
191192
```
192193
"""
193194
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys),
@@ -249,9 +250,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
249250
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
250251
p_end = length(p) + 2 + implicit_dae, kwargs...)
251252
f = eval_or_rgf.(f; eval_expression, eval_module)
252-
f = GeneratedFunctionWrapper{(3, length(args) - length(p) + 1, is_split(sys))}(f...)
253+
f = GeneratedFunctionWrapper{(3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
253254
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
254-
(; f, dvs, ps, io_sys = sys)
255+
(; f = (f, f), dvs, ps, io_sys = sys)
255256
end
256257

257258
"""
@@ -418,7 +419,7 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = Any[]; kwar
418419
augmented_sys = extend(augmented_sys, sys)
419420
ssys = structural_simplify(augmented_sys, inputs = all_inputs, disturbance_inputs = [d])
420421

421-
f, dvs, p, io_sys = generate_control_function(ssys, all_inputs,
422+
(f_oop, f_ip), dvs, p, io_sys = generate_control_function(ssys, all_inputs,
422423
[d]; kwargs...)
423-
f, augmented_sys, dvs, p, io_sys
424+
(f_oop, f_ip), augmented_sys, dvs, p, io_sys
424425
end

src/linearization.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,17 @@ function eq_idxs(sys::AbstractSystem)
143143
diff_idxs, alge_idxs
144144
end
145145

146+
"""
147+
Return the set of indexes of differential equations and algebraic equations in the simplified system.
148+
"""
149+
function eq_idxs(sys::AbstractSystem)
150+
eqs = equations(sys)
151+
alge_idxs = findall(!isdiffeq, eqs)
152+
diff_idxs = setdiff(1:length(eqs), alge_idxs)
153+
154+
diff_idxs, alge_idxs
155+
end
156+
146157
"""
147158
$(TYPEDEF)
148159
@@ -607,7 +618,7 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true
607618
end
608619
(all(values(outputset)) || error(
609620
"Some specified outputs were not found in system. The following Dict indicates the found variables ",
610-
outputset))
621+
outputset))
611622
end
612623
state, orig_inputs
613624
end

src/structural_transformation/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,13 @@ function shift2term(var)
472472
op = operation(var)
473473
iv = op.t
474474
arg = only(arguments(var))
475+
if operation(arg) === getindex
476+
idxs = arguments(arg)[2:end]
477+
newvar = shift2term(op(first(arguments(arg))))[idxs...]
478+
unshifted = ModelingToolkit.getunshifted(newvar)[idxs...]
479+
newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, unshifted)
480+
return newvar
481+
end
475482
is_lowered = !isnothing(ModelingToolkit.getunshifted(arg))
476483

477484
backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps

src/systems/discrete_system/discrete_system.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ function SciMLBase.DiscreteProblem(
319319
iv = get_iv(sys)
320320

321321
u0map = to_varmap(u0map, dvs)
322+
scalarize_varmap!(u0map)
322323
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
323324
f, u0, p = process_SciMLProblem(
324325
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module, build_initializeprob = false)

src/systems/if_lifting.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ function (cw::CondRewriter)(expr, dep)
111111
# and ELSE branch is true
112112
# similarly for expression being false
113113
return (ifelse(rw_cond, rw_conda, rw_condb),
114-
implies(ctrue, truea) | implies(cfalse, trueb),
115-
implies(ctrue, falsea) | implies(cfalse, falseb))
114+
ctrue & truea | cfalse & trueb,
115+
ctrue & falsea | cfalse & falseb)
116116
elseif operation(expr) == Base.:(!) # NOT of expression
117117
(a,) = arguments(expr)
118118
(rw, ctrue, cfalse) = cw(a, dep)

0 commit comments

Comments
 (0)