Skip to content

Commit 9328836

Browse files
fix: fix and document wrap_mtkparameters
1 parent a457a02 commit 9328836

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

src/systems/abstractsystem.jl

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -398,50 +398,64 @@ function wrap_array_vars(
398398
end
399399
end
400400

401-
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool)
401+
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)
402+
403+
"""
404+
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
405+
406+
Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
407+
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
408+
instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
409+
wrapped is for a scalar value. `p_start` is the index of the argument containing
410+
the first parameter vector in the out-of-place version of the function. For example,
411+
if a history function (DDEs) was passed before `p`, then the function before wrapping
412+
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
413+
414+
The returned function is `identity` if the system does not have an `IndexCache`.
415+
"""
416+
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
402417
if has_index_cache(sys) && get_index_cache(sys) !== nothing
403418
offset = Int(is_time_dependent(sys))
404419

405420
if isscalar
406421
function (expr)
407-
p = gensym(:p)
422+
destructured_args = [x for x in expr.args[p_start:end-offset] if x isa DestructuredArgs]
408423
Func(
409424
[
410-
expr.args[1],
425+
expr.args[begin:p_start-1]...,
411426
DestructuredArgs(
412-
[arg.name for arg in expr.args[2:(end - offset)]], p),
413-
(isone(offset) ? (expr.args[end],) : ())...
427+
[arg.name for arg in expr.args[p_start:(end - offset)]], MTKPARAMETERS_ARG),
428+
expr.args[end-offset+1:end]...
414429
],
415430
[],
416-
Let(expr.args[2:(end - offset)], expr.body, false)
431+
Let(destructured_args, expr.body, false)
417432
)
418433
end
419434
else
420435
function (expr)
421-
p = gensym(:p)
436+
destructured_args = [x for x in expr.args[p_start:end-offset] if x isa DestructuredArgs]
422437
Func(
423438
[
424-
expr.args[1],
439+
expr.args[begin:p_start-1]...,
425440
DestructuredArgs(
426-
[arg.name for arg in expr.args[2:(end - offset)]], p),
427-
(isone(offset) ? (expr.args[end],) : ())...
441+
[arg.name for arg in expr.args[p_start:(end - offset)]], MTKPARAMETERS_ARG),
442+
expr.args[end-offset+1:end]...
428443
],
429444
[],
430-
Let(expr.args[2:(end - offset)], expr.body, false)
445+
Let(destructured_args, expr.body, false)
431446
)
432447
end,
433448
function (expr)
434-
p = gensym(:p)
449+
destructured_args = [x for x in expr.args[p_start:end-offset] if x isa DestructuredArgs]
435450
Func(
436451
[
437-
expr.args[1],
438-
expr.args[2],
452+
expr.args[begin:p_start]...,
439453
DestructuredArgs(
440-
[arg.name for arg in expr.args[3:(end - offset)]], p),
441-
(isone(offset) ? (expr.args[end],) : ())...
454+
[arg.name for arg in expr.args[p_start+1:(end - offset)]], MTKPARAMETERS_ARG),
455+
expr.args[end-offset+1:end]...
442456
],
443457
[],
444-
Let(expr.args[3:(end - offset)], expr.body, false)
458+
Let(destructured_args, expr.body, false)
445459
)
446460
end
447461
end

0 commit comments

Comments
 (0)