Skip to content

Commit 0e50dcf

Browse files
fix: fix and document wrap_mtkparameters
1 parent f8030f8 commit 0e50dcf

File tree

1 file changed

+46
-20
lines changed

1 file changed

+46
-20
lines changed

src/systems/abstractsystem.jl

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -398,50 +398,76 @@ function wrap_array_vars(
398398
end
399399
end
400400

401-
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool)
401+
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)
402+
403+
"""
404+
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
405+
406+
Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
407+
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
408+
instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
409+
wrapped is for a scalar value. `p_start` is the index of the argument containing
410+
the first parameter vector in the out-of-place version of the function. For example,
411+
if a history function (DDEs) was passed before `p`, then the function before wrapping
412+
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
413+
414+
The returned function is `identity` if the system does not have an `IndexCache`.
415+
"""
416+
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
402417
if has_index_cache(sys) && get_index_cache(sys) !== nothing
403418
offset = Int(is_time_dependent(sys))
404419

405420
if isscalar
406421
function (expr)
407-
p = gensym(:p)
422+
param_args = expr.args[p_start:(end - offset)]
423+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
424+
param_buffer_args = param_args[param_buffer_idxs]
425+
destructured_mtkparams = DestructuredArgs(
426+
[x.name for x in param_buffer_args],
427+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
408428
Func(
409429
[
410-
expr.args[1],
411-
DestructuredArgs(
412-
[arg.name for arg in expr.args[2:(end - offset)]], p),
413-
(isone(offset) ? (expr.args[end],) : ())...
430+
expr.args[begin:(p_start - 1)]...,
431+
destructured_mtkparams,
432+
expr.args[(end - offset + 1):end]...
414433
],
415434
[],
416-
Let(expr.args[2:(end - offset)], expr.body, false)
435+
Let(param_buffer_args, expr.body, false)
417436
)
418437
end
419438
else
420439
function (expr)
421-
p = gensym(:p)
440+
param_args = expr.args[p_start:(end - offset)]
441+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
442+
param_buffer_args = param_args[param_buffer_idxs]
443+
destructured_mtkparams = DestructuredArgs(
444+
[x.name for x in param_buffer_args],
445+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
422446
Func(
423447
[
424-
expr.args[1],
425-
DestructuredArgs(
426-
[arg.name for arg in expr.args[2:(end - offset)]], p),
427-
(isone(offset) ? (expr.args[end],) : ())...
448+
expr.args[begin:(p_start - 1)]...,
449+
destructured_mtkparams,
450+
expr.args[(end - offset + 1):end]...
428451
],
429452
[],
430-
Let(expr.args[2:(end - offset)], expr.body, false)
453+
Let(param_buffer_args, expr.body, false)
431454
)
432455
end,
433456
function (expr)
434-
p = gensym(:p)
457+
param_args = expr.args[(p_start + 1):(end - offset)]
458+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
459+
param_buffer_args = param_args[param_buffer_idxs]
460+
destructured_mtkparams = DestructuredArgs(
461+
[x.name for x in param_buffer_args],
462+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
435463
Func(
436464
[
437-
expr.args[1],
438-
expr.args[2],
439-
DestructuredArgs(
440-
[arg.name for arg in expr.args[3:(end - offset)]], p),
441-
(isone(offset) ? (expr.args[end],) : ())...
465+
expr.args[begin:p_start]...,
466+
destructured_mtkparams,
467+
expr.args[(end - offset + 1):end]...
442468
],
443469
[],
444-
Let(expr.args[3:(end - offset)], expr.body, false)
470+
Let(param_buffer_args, expr.body, false)
445471
)
446472
end
447473
end

0 commit comments

Comments
 (0)