Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ PrecompileTools = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.52.1"
SciMLBase = "2.55"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.29"
SymbolicIndexingInterface = "0.3.31"
SymbolicUtils = "3.7"
Symbolics = "6.12"
URIs = "1"
Expand Down
146 changes: 86 additions & 60 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
end

function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
inputs = nothing, history = false)
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
if dvs !== nothing
Expand Down Expand Up @@ -328,6 +329,19 @@ function wrap_array_vars(
array_parameters[p] = (idxs, buffer_idx, sz)
end
end

inputind = if history
uind + 2
else
uind + 1
end
params_offset = if history && hasinputs
uind + 2
elseif history || hasinputs
uind + 1
else
uind
end
if isscalar
function (expr)
Func(
Expand All @@ -336,10 +350,10 @@ function wrap_array_vars(
Let(
vcat(
[k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k ← :(view($(expr.args[uind + hasinputs].name), $v))
[k ← :(view($(expr.args[inputind].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
view($(expr.args[params_offset + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
Expand All @@ -358,10 +372,10 @@ function wrap_array_vars(
Let(
vcat(
[k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k ← :(view($(expr.args[uind + hasinputs].name), $v))
[k ← :(view($(expr.args[inputind].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
view($(expr.args[params_offset + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
Expand All @@ -380,10 +394,10 @@ function wrap_array_vars(
vcat(
[k ← :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k ← :(view($(expr.args[uind + hasinputs + 1].name), $v))
[k ← :(view($(expr.args[inputind + 1].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
view($(expr.args[params_offset + buffer_idx + 1].name),
$idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
Expand All @@ -398,50 +412,76 @@ function wrap_array_vars(
end
end

function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool)
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)

"""
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)

Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
wrapped is for a scalar value. `p_start` is the index of the argument containing
the first parameter vector in the out-of-place version of the function. For example,
if a history function (DDEs) was passed before `p`, then the function before wrapping
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.

The returned function is `identity` if the system does not have an `IndexCache`.
"""
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
offset = Int(is_time_dependent(sys))

if isscalar
function (expr)
p = gensym(:p)
param_args = expr.args[p_start:(end - offset)]
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
param_buffer_args = param_args[param_buffer_idxs]
destructured_mtkparams = DestructuredArgs(
[x.name for x in param_buffer_args],
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
Func(
[
expr.args[1],
DestructuredArgs(
[arg.name for arg in expr.args[2:(end - offset)]], p),
(isone(offset) ? (expr.args[end],) : ())...
expr.args[begin:(p_start - 1)]...,
destructured_mtkparams,
expr.args[(end - offset + 1):end]...
],
[],
Let(expr.args[2:(end - offset)], expr.body, false)
Let(param_buffer_args, expr.body, false)
)
end
else
function (expr)
p = gensym(:p)
param_args = expr.args[p_start:(end - offset)]
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
param_buffer_args = param_args[param_buffer_idxs]
destructured_mtkparams = DestructuredArgs(
[x.name for x in param_buffer_args],
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
Func(
[
expr.args[1],
DestructuredArgs(
[arg.name for arg in expr.args[2:(end - offset)]], p),
(isone(offset) ? (expr.args[end],) : ())...
expr.args[begin:(p_start - 1)]...,
destructured_mtkparams,
expr.args[(end - offset + 1):end]...
],
[],
Let(expr.args[2:(end - offset)], expr.body, false)
Let(param_buffer_args, expr.body, false)
)
end,
function (expr)
p = gensym(:p)
param_args = expr.args[(p_start + 1):(end - offset)]
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
param_buffer_args = param_args[param_buffer_idxs]
destructured_mtkparams = DestructuredArgs(
[x.name for x in param_buffer_args],
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
Func(
[
expr.args[1],
expr.args[2],
DestructuredArgs(
[arg.name for arg in expr.args[3:(end - offset)]], p),
(isone(offset) ? (expr.args[end],) : ())...
expr.args[begin:p_start]...,
destructured_mtkparams,
expr.args[(end - offset + 1):end]...
],
[],
Let(expr.args[3:(end - offset)], expr.body, false)
Let(param_buffer_args, expr.body, false)
)
end
end
Expand Down Expand Up @@ -669,25 +709,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
if rawobs isa Tuple
if is_time_dependent(sys)
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1a(p::MTKParameters, t) = oop(p..., t)
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
f1a(p, t) = oop(p, t)
f1a(out, p, t) = iip(out, p, t)
end
else
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1b(p::MTKParameters) = oop(p...)
f1b(out, p::MTKParameters) = iip(out, p...)
f1b(p) = oop(p)
f1b(out, p) = iip(out, p)
end
end
else
if is_time_dependent(sys)
obsfn = let rawobs = rawobs
f2a(p::MTKParameters, t) = rawobs(p..., t)
end
else
obsfn = let rawobs = rawobs
f2b(p::MTKParameters) = rawobs(p...)
end
end
obsfn = rawobs
end
else
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
Expand Down Expand Up @@ -802,17 +834,11 @@ function SymbolicIndexingInterface.observed(
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)

if is_time_dependent(sys)
return let _fn = _fn
fn1(u, p, t) = _fn(u, p, t)
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
fn1
end
return _fn
else
return let _fn = _fn
fn2(u, p) = _fn(u, p)
fn2(u, p::MTKParameters) = _fn(u, p...)
fn2(::Nothing, p) = _fn([], p)
fn2(::Nothing, p::MTKParameters) = _fn([], p...)
fn2
end
end
Expand All @@ -828,6 +854,8 @@ end
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false

SymbolicIndexingInterface.is_markovian(sys::AbstractSystem) = !is_dde(sys)

SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true

function SymbolicIndexingInterface.all_variable_symbols(sys::AbstractSystem)
Expand Down Expand Up @@ -971,6 +999,7 @@ for prop in [:eqs
:solved_unknowns
:split_idxs
:parent
:is_dde
:index_cache
:is_scalar_noise
:isscheduled]
Expand Down Expand Up @@ -2349,8 +2378,8 @@ function linearization_function(sys::AbstractSystem, inputs,
u_getter = u_getter

function (u, p, t)
p_setter!(oldps, p_getter(u, p..., t))
newu = u_getter(u, p..., t)
p_setter!(oldps, p_getter(u, p, t))
newu = u_getter(u, p, t)
return newu, oldps
end
end
Expand All @@ -2361,20 +2390,15 @@ function linearization_function(sys::AbstractSystem, inputs,

function (u, p, t)
state = ProblemState(; u, p, t)
return u_getter(state), p_getter(state)
return u_getter(
state_values(state), parameter_values(state), current_time(state)),
p_getter(state)
end
end
end
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
initprobmap = let inner = initprobmap
fn(u, p::MTKParameters) = inner(u, p...)
fn(u, p) = inner(u, p)
fn
end
end
ps = parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
Expand Down Expand Up @@ -2421,7 +2445,7 @@ function linearization_function(sys::AbstractSystem, inputs,
fg_xz = ForwardDiff.jacobian(uf, u)
h_xz = ForwardDiff.jacobian(
let p = p, t = t
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
xz -> h(xz, p, t)
end, u)
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
Expand All @@ -2433,7 +2457,6 @@ function linearization_function(sys::AbstractSystem, inputs,
end
hp = let u = u, t = t
_hp(p) = h(u, p, t)
_hp(p::MTKParameters) = h(u, p..., t)
_hp
end
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
Expand Down Expand Up @@ -2486,7 +2509,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
dx = fun(sts, p..., t)

h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
y = h(sts, p..., t)
y = h(sts, p, t)

fg_xz = Symbolics.jacobian(dx, sts)
fg_u = Symbolics.jacobian(dx, inputs)
Expand Down Expand Up @@ -2955,6 +2978,9 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
nsys == 0 && return sys
@set! sys.name = name
@set! sys.systems = [get_systems(sys); systems]
if has_is_dde(sys)
@set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys))
end
return sys
end
function compose(syss...; name = nameof(first(syss)))
Expand Down
44 changes: 35 additions & 9 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,29 @@ struct Schedule
dummy_sub::Any
end

"""
is_dde(sys::AbstractSystem)

Return a boolean indicating whether a system represents a set of delay
differential equations.
"""
is_dde(sys::AbstractSystem) = has_is_dde(sys) && get_is_dde(sys)

function _check_if_dde(eqs, iv, subsystems)
is_dde = any(ModelingToolkit.is_dde, subsystems)
if !is_dde
vs = Set()
for eq in eqs
vars!(vs, eq)
is_dde = any(vs) do sym
isdelay(unwrap(sym), iv)
end
is_dde && break
end
end
return is_dde
end

function filter_kwargs(kwargs)
kwargs = Dict(kwargs)
for key in keys(kwargs)
Expand Down Expand Up @@ -219,29 +242,32 @@ function isdelay(var, iv)
return false
end
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
function delay_to_function(sys::AbstractODESystem, eqs = full_equations(sys))
const DEFAULT_PARAMS_ARG = Sym{Any}(:ˍ₋arg3)
function delay_to_function(
sys::AbstractODESystem, eqs = full_equations(sys); history_arg = DEFAULT_PARAMS_ARG)
delay_to_function(eqs,
get_iv(sys),
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(unknowns(sys))),
parameters(sys),
DDE_HISTORY_FUN)
DDE_HISTORY_FUN; history_arg)
end
function delay_to_function(eqs::Vector, iv, sts, ps, h)
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
function delay_to_function(eqs::Vector, iv, sts, ps, h; history_arg = DEFAULT_PARAMS_ARG)
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,); history_arg)
end
function delay_to_function(eq::Equation, iv, sts, ps, h)
delay_to_function(eq.lhs, iv, sts, ps, h) ~ delay_to_function(eq.rhs, iv, sts, ps, h)
function delay_to_function(eq::Equation, iv, sts, ps, h; history_arg = DEFAULT_PARAMS_ARG)
delay_to_function(eq.lhs, iv, sts, ps, h; history_arg) ~ delay_to_function(
eq.rhs, iv, sts, ps, h; history_arg)
end
function delay_to_function(expr, iv, sts, ps, h)
function delay_to_function(expr, iv, sts, ps, h; history_arg = DEFAULT_PARAMS_ARG)
if isdelay(expr, iv)
v = operation(expr)
time = arguments(expr)[1]
idx = sts[v]
return term(getindex, h(Sym{Any}(:ˍ₋arg3), time), idx, type = Real) # BIG BIG HACK
return term(getindex, h(history_arg, time), idx, type = Real) # BIG BIG HACK
elseif iscall(expr)
return maketerm(typeof(expr),
operation(expr),
map(x -> delay_to_function(x, iv, sts, ps, h), arguments(expr)),
map(x -> delay_to_function(x, iv, sts, ps, h; history_arg), arguments(expr)),
metadata(expr))
else
return expr
Expand Down
Loading
Loading