Skip to content

Commit 01a7cf9

Browse files
authored
Merge branch 'master' into cleanup_initialization
2 parents 09fda2f + 40b1f7c commit 01a7cf9

16 files changed

+369
-272
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ PrecompileTools = "1"
112112
RecursiveArrayTools = "3.26"
113113
Reexport = "0.2, 1"
114114
RuntimeGeneratedFunctions = "0.5.9"
115-
SciMLBase = "2.52.1"
115+
SciMLBase = "2.55"
116116
SciMLStructures = "1.0"
117117
Serialization = "1"
118118
Setfield = "0.7, 0.8, 1"
119119
SimpleNonlinearSolve = "0.1.0, 1"
120120
SparseArrays = "1"
121121
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
122122
StaticArrays = "0.10, 0.11, 0.12, 1.0"
123-
SymbolicIndexingInterface = "0.3.29"
123+
SymbolicIndexingInterface = "0.3.31"
124124
SymbolicUtils = "3.7"
125125
Symbolics = "6.12"
126126
URIs = "1"

docs/src/basics/MTKLanguage.md

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,13 @@ end
6363
@structural_parameters begin
6464
f = sin
6565
N = 2
66-
M = 3
6766
end
6867
begin
6968
v_var = 1.0
7069
end
7170
@variables begin
7271
v(t) = v_var
73-
v_array(t)[1:N, 1:M]
72+
v_array(t)[1:2, 1:3]
7473
v_for_defaults(t)
7574
end
7675
@extend ModelB(; p1)
@@ -311,10 +310,10 @@ end
311310
- `:defaults`: Dictionary of variables and default values specified in the `@defaults`.
312311
- `:extend`: The list of extended unknowns, name given to the base system, and name of the base system.
313312
- `:structural_parameters`: Dictionary of structural parameters mapped to their metadata.
314-
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. Metadata of
315-
the parameter arrays is, for now, omitted.
316-
- `:variables`: Dictionary of symbolic variables mapped to their metadata. Metadata of
317-
the variable arrays is, for now, omitted.
313+
- `:parameters`: Dictionary of symbolic parameters mapped to their metadata. For
314+
parameter arrays, length is added to the metadata as `:size`.
315+
- `:variables`: Dictionary of symbolic variables mapped to their metadata. For
316+
variable arrays, length is added to the metadata as `:size`.
318317
- `:kwargs`: Dictionary of keyword arguments mapped to their metadata.
319318
- `:independent_variable`: Independent variable, which is added while generating the Model.
320319
- `:equations`: List of equations (represented as strings).
@@ -325,10 +324,10 @@ For example, the structure of `ModelC` is:
325324
julia> ModelC.structure
326325
Dict{Symbol, Any} with 10 entries:
327326
:components => Any[Union{Expr, Symbol}[:model_a, :ModelA], Union{Expr, Symbol}[:model_array_a, :ModelA, :(1:N)], Union{Expr, Symbol}[:model_array_b, :ModelA, :(1:N)]]
328-
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_for_defaults=>Dict(:type=>Real))
327+
:variables => Dict{Symbol, Dict{Symbol, Any}}(:v=>Dict(:default=>:v_var, :type=>Real), :v_array=>Dict(:type=>Real, :size=>(2, 3)), :v_for_defaults=>Dict(:type=>Real))
329328
:icon => URI("https://github.com/SciML/SciMLDocs/blob/main/docs/src/assets/logo.png")
330-
:kwargs => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3), :v => Dict{Symbol, Any}(:value => :v_var, :type => Real), :v_for_defaults => Dict{Symbol, Union{Nothing, DataType}}(:value => nothing, :type => Real), :p1 => Dict(:value => nothing)),
331-
:structural_parameters => Dict{Symbol, Dict}(:f => Dict(:value => :sin), :N => Dict(:value => 2), :M => Dict(:value => 3))
329+
:kwargs => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2), :v=>Dict{Symbol, Any}(:value=>:v_var, :type=>Real), :v_array=>Dict{Symbol, Union{Nothing, UnionAll}}(:value=>nothing, :type=>AbstractArray{Real}), :v_for_defaults=>Dict{Symbol, Union{Nothing, DataType}}(:value=>nothing, :type=>Real), :p1=>Dict(:value=>nothing))
330+
:structural_parameters => Dict{Symbol, Dict}(:f=>Dict(:value=>:sin), :N=>Dict(:value=>2))
332331
:independent_variable => t
333332
:constants => Dict{Symbol, Dict}(:c=>Dict{Symbol, Any}(:value=>1, :type=>Int64, :description=>"Example constant."))
334333
:extend => Any[[:p2, :p1], Symbol("#mtkmodel__anonymous__ModelB"), :ModelB]

src/systems/abstractsystem.jl

Lines changed: 86 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230230
end
231231

232232
function wrap_array_vars(
233-
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
233+
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
234+
inputs = nothing, history = false)
234235
isscalar = !(exprs isa AbstractArray)
235236
array_vars = Dict{Any, AbstractArray{Int}}()
236237
if dvs !== nothing
@@ -328,6 +329,19 @@ function wrap_array_vars(
328329
array_parameters[p] = (idxs, buffer_idx, sz)
329330
end
330331
end
332+
333+
inputind = if history
334+
uind + 2
335+
else
336+
uind + 1
337+
end
338+
params_offset = if history && hasinputs
339+
uind + 2
340+
elseif history || hasinputs
341+
uind + 1
342+
else
343+
uind
344+
end
331345
if isscalar
332346
function (expr)
333347
Func(
@@ -336,10 +350,10 @@ function wrap_array_vars(
336350
Let(
337351
vcat(
338352
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
339-
[k :(view($(expr.args[uind + hasinputs].name), $v))
353+
[k :(view($(expr.args[inputind].name), $v))
340354
for (k, v) in input_vars],
341355
[k :(reshape(
342-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
356+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
343357
$sz))
344358
for (k, (idxs, buffer_idx, sz)) in array_parameters],
345359
[k Code.MakeArray(v, symtype(k))
@@ -358,10 +372,10 @@ function wrap_array_vars(
358372
Let(
359373
vcat(
360374
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
361-
[k :(view($(expr.args[uind + hasinputs].name), $v))
375+
[k :(view($(expr.args[inputind].name), $v))
362376
for (k, v) in input_vars],
363377
[k :(reshape(
364-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
378+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
365379
$sz))
366380
for (k, (idxs, buffer_idx, sz)) in array_parameters],
367381
[k Code.MakeArray(v, symtype(k))
@@ -380,10 +394,10 @@ function wrap_array_vars(
380394
vcat(
381395
[k :(view($(expr.args[uind + 1].name), $v))
382396
for (k, v) in array_vars],
383-
[k :(view($(expr.args[uind + hasinputs + 1].name), $v))
397+
[k :(view($(expr.args[inputind + 1].name), $v))
384398
for (k, v) in input_vars],
385399
[k :(reshape(
386-
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
400+
view($(expr.args[params_offset + buffer_idx + 1].name),
387401
$idxs),
388402
$sz))
389403
for (k, (idxs, buffer_idx, sz)) in array_parameters],
@@ -398,50 +412,76 @@ function wrap_array_vars(
398412
end
399413
end
400414

401-
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool)
415+
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)
416+
417+
"""
418+
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
419+
420+
Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
421+
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
422+
instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
423+
wrapped is for a scalar value. `p_start` is the index of the argument containing
424+
the first parameter vector in the out-of-place version of the function. For example,
425+
if a history function (DDEs) was passed before `p`, then the function before wrapping
426+
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
427+
428+
The returned function is `identity` if the system does not have an `IndexCache`.
429+
"""
430+
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
402431
if has_index_cache(sys) && get_index_cache(sys) !== nothing
403432
offset = Int(is_time_dependent(sys))
404433

405434
if isscalar
406435
function (expr)
407-
p = gensym(:p)
436+
param_args = expr.args[p_start:(end - offset)]
437+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
438+
param_buffer_args = param_args[param_buffer_idxs]
439+
destructured_mtkparams = DestructuredArgs(
440+
[x.name for x in param_buffer_args],
441+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
408442
Func(
409443
[
410-
expr.args[1],
411-
DestructuredArgs(
412-
[arg.name for arg in expr.args[2:(end - offset)]], p),
413-
(isone(offset) ? (expr.args[end],) : ())...
444+
expr.args[begin:(p_start - 1)]...,
445+
destructured_mtkparams,
446+
expr.args[(end - offset + 1):end]...
414447
],
415448
[],
416-
Let(expr.args[2:(end - offset)], expr.body, false)
449+
Let(param_buffer_args, expr.body, false)
417450
)
418451
end
419452
else
420453
function (expr)
421-
p = gensym(:p)
454+
param_args = expr.args[p_start:(end - offset)]
455+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
456+
param_buffer_args = param_args[param_buffer_idxs]
457+
destructured_mtkparams = DestructuredArgs(
458+
[x.name for x in param_buffer_args],
459+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
422460
Func(
423461
[
424-
expr.args[1],
425-
DestructuredArgs(
426-
[arg.name for arg in expr.args[2:(end - offset)]], p),
427-
(isone(offset) ? (expr.args[end],) : ())...
462+
expr.args[begin:(p_start - 1)]...,
463+
destructured_mtkparams,
464+
expr.args[(end - offset + 1):end]...
428465
],
429466
[],
430-
Let(expr.args[2:(end - offset)], expr.body, false)
467+
Let(param_buffer_args, expr.body, false)
431468
)
432469
end,
433470
function (expr)
434-
p = gensym(:p)
471+
param_args = expr.args[(p_start + 1):(end - offset)]
472+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
473+
param_buffer_args = param_args[param_buffer_idxs]
474+
destructured_mtkparams = DestructuredArgs(
475+
[x.name for x in param_buffer_args],
476+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
435477
Func(
436478
[
437-
expr.args[1],
438-
expr.args[2],
439-
DestructuredArgs(
440-
[arg.name for arg in expr.args[3:(end - offset)]], p),
441-
(isone(offset) ? (expr.args[end],) : ())...
479+
expr.args[begin:p_start]...,
480+
destructured_mtkparams,
481+
expr.args[(end - offset + 1):end]...
442482
],
443483
[],
444-
Let(expr.args[3:(end - offset)], expr.body, false)
484+
Let(param_buffer_args, expr.body, false)
445485
)
446486
end
447487
end
@@ -669,25 +709,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
669709
if rawobs isa Tuple
670710
if is_time_dependent(sys)
671711
obsfn = let oop = rawobs[1], iip = rawobs[2]
672-
f1a(p::MTKParameters, t) = oop(p..., t)
673-
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
712+
f1a(p, t) = oop(p, t)
713+
f1a(out, p, t) = iip(out, p, t)
674714
end
675715
else
676716
obsfn = let oop = rawobs[1], iip = rawobs[2]
677-
f1b(p::MTKParameters) = oop(p...)
678-
f1b(out, p::MTKParameters) = iip(out, p...)
717+
f1b(p) = oop(p)
718+
f1b(out, p) = iip(out, p)
679719
end
680720
end
681721
else
682-
if is_time_dependent(sys)
683-
obsfn = let rawobs = rawobs
684-
f2a(p::MTKParameters, t) = rawobs(p..., t)
685-
end
686-
else
687-
obsfn = let rawobs = rawobs
688-
f2b(p::MTKParameters) = rawobs(p...)
689-
end
690-
end
722+
obsfn = rawobs
691723
end
692724
else
693725
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
@@ -802,17 +834,11 @@ function SymbolicIndexingInterface.observed(
802834
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
803835

804836
if is_time_dependent(sys)
805-
return let _fn = _fn
806-
fn1(u, p, t) = _fn(u, p, t)
807-
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
808-
fn1
809-
end
837+
return _fn
810838
else
811839
return let _fn = _fn
812840
fn2(u, p) = _fn(u, p)
813-
fn2(u, p::MTKParameters) = _fn(u, p...)
814841
fn2(::Nothing, p) = _fn([], p)
815-
fn2(::Nothing, p::MTKParameters) = _fn([], p...)
816842
fn2
817843
end
818844
end
@@ -828,6 +854,8 @@ end
828854
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
829855
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false
830856

857+
SymbolicIndexingInterface.is_markovian(sys::AbstractSystem) = !is_dde(sys)
858+
831859
SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true
832860

833861
function SymbolicIndexingInterface.all_variable_symbols(sys::AbstractSystem)
@@ -971,6 +999,7 @@ for prop in [:eqs
971999
:solved_unknowns
9721000
:split_idxs
9731001
:parent
1002+
:is_dde
9741003
:index_cache
9751004
:is_scalar_noise
9761005
:isscheduled]
@@ -2349,8 +2378,8 @@ function linearization_function(sys::AbstractSystem, inputs,
23492378
u_getter = u_getter
23502379

23512380
function (u, p, t)
2352-
p_setter!(oldps, p_getter(u, p..., t))
2353-
newu = u_getter(u, p..., t)
2381+
p_setter!(oldps, p_getter(u, p, t))
2382+
newu = u_getter(u, p, t)
23542383
return newu, oldps
23552384
end
23562385
end
@@ -2361,20 +2390,15 @@ function linearization_function(sys::AbstractSystem, inputs,
23612390

23622391
function (u, p, t)
23632392
state = ProblemState(; u, p, t)
2364-
return u_getter(state), p_getter(state)
2393+
return u_getter(
2394+
state_values(state), parameter_values(state), current_time(state)),
2395+
p_getter(state)
23652396
end
23662397
end
23672398
end
23682399
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
23692400
initprobmap = build_explicit_observed_function(
23702401
initsys, unknowns(sys); eval_expression, eval_module)
2371-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2372-
initprobmap = let inner = initprobmap
2373-
fn(u, p::MTKParameters) = inner(u, p...)
2374-
fn(u, p) = inner(u, p)
2375-
fn
2376-
end
2377-
end
23782402
ps = parameters(sys)
23792403
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
23802404
lin_fun = let diff_idxs = diff_idxs,
@@ -2421,7 +2445,7 @@ function linearization_function(sys::AbstractSystem, inputs,
24212445
fg_xz = ForwardDiff.jacobian(uf, u)
24222446
h_xz = ForwardDiff.jacobian(
24232447
let p = p, t = t
2424-
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
2448+
xz -> h(xz, p, t)
24252449
end, u)
24262450
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
24272451
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
@@ -2433,7 +2457,6 @@ function linearization_function(sys::AbstractSystem, inputs,
24332457
end
24342458
hp = let u = u, t = t
24352459
_hp(p) = h(u, p, t)
2436-
_hp(p::MTKParameters) = h(u, p..., t)
24372460
_hp
24382461
end
24392462
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
@@ -2486,7 +2509,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
24862509
dx = fun(sts, p..., t)
24872510

24882511
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
2489-
y = h(sts, p..., t)
2512+
y = h(sts, p, t)
24902513

24912514
fg_xz = Symbolics.jacobian(dx, sts)
24922515
fg_u = Symbolics.jacobian(dx, inputs)
@@ -2955,6 +2978,9 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
29552978
nsys == 0 && return sys
29562979
@set! sys.name = name
29572980
@set! sys.systems = [get_systems(sys); systems]
2981+
if has_is_dde(sys)
2982+
@set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys))
2983+
end
29582984
return sys
29592985
end
29602986
function compose(syss...; name = nameof(first(syss)))

0 commit comments

Comments
 (0)