@@ -398,50 +398,76 @@ function wrap_array_vars(
398398 end
399399end
400400
401- function wrap_mtkparameters (sys:: AbstractSystem , isscalar:: Bool )
401+ const MTKPARAMETERS_ARG = Sym {Vector{Vector}} (:___mtkparameters___ )
402+
403+ """
404+ wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
405+
406+ Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
407+ allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
408+ instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
409+ wrapped is for a scalar value. `p_start` is the index of the argument containing
410+ the first parameter vector in the out-of-place version of the function. For example,
411+ if a history function (DDEs) was passed before `p`, then the function before wrapping
412+ would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
413+
414+ The returned function is `identity` if the system does not have an `IndexCache`.
415+ """
416+ function wrap_mtkparameters (sys:: AbstractSystem , isscalar:: Bool , p_start = 2 )
402417 if has_index_cache (sys) && get_index_cache (sys) != = nothing
403418 offset = Int (is_time_dependent (sys))
404419
405420 if isscalar
406421 function (expr)
407- p = gensym (:p )
422+ param_args = expr. args[p_start: (end - offset)]
423+ param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
424+ param_buffer_args = param_args[param_buffer_idxs]
425+ destructured_mtkparams = DestructuredArgs (
426+ [x. name for x in param_buffer_args],
427+ MTKPARAMETERS_ARG; inds = param_buffer_idxs)
408428 Func (
409429 [
410- expr. args[1 ],
411- DestructuredArgs (
412- [arg. name for arg in expr. args[2 : (end - offset)]], p),
413- (isone (offset) ? (expr. args[end ],) : ()). ..
430+ expr. args[begin : (p_start - 1 )]. .. ,
431+ destructured_mtkparams,
432+ expr. args[(end - offset + 1 ): end ]. ..
414433 ],
415434 [],
416- Let (expr . args[ 2 : ( end - offset)] , expr. body, false )
435+ Let (param_buffer_args , expr. body, false )
417436 )
418437 end
419438 else
420439 function (expr)
421- p = gensym (:p )
440+ param_args = expr. args[p_start: (end - offset)]
441+ param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
442+ param_buffer_args = param_args[param_buffer_idxs]
443+ destructured_mtkparams = DestructuredArgs (
444+ [x. name for x in param_buffer_args],
445+ MTKPARAMETERS_ARG; inds = param_buffer_idxs)
422446 Func (
423447 [
424- expr. args[1 ],
425- DestructuredArgs (
426- [arg. name for arg in expr. args[2 : (end - offset)]], p),
427- (isone (offset) ? (expr. args[end ],) : ()). ..
448+ expr. args[begin : (p_start - 1 )]. .. ,
449+ destructured_mtkparams,
450+ expr. args[(end - offset + 1 ): end ]. ..
428451 ],
429452 [],
430- Let (expr . args[ 2 : ( end - offset)] , expr. body, false )
453+ Let (param_buffer_args , expr. body, false )
431454 )
432455 end ,
433456 function (expr)
434- p = gensym (:p )
457+ param_args = expr. args[(p_start + 1 ): (end - offset)]
458+ param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
459+ param_buffer_args = param_args[param_buffer_idxs]
460+ destructured_mtkparams = DestructuredArgs (
461+ [x. name for x in param_buffer_args],
462+ MTKPARAMETERS_ARG; inds = param_buffer_idxs)
435463 Func (
436464 [
437- expr. args[1 ],
438- expr. args[2 ],
439- DestructuredArgs (
440- [arg. name for arg in expr. args[3 : (end - offset)]], p),
441- (isone (offset) ? (expr. args[end ],) : ()). ..
465+ expr. args[begin : p_start]. .. ,
466+ destructured_mtkparams,
467+ expr. args[(end - offset + 1 ): end ]. ..
442468 ],
443469 [],
444- Let (expr . args[ 3 : ( end - offset)] , expr. body, false )
470+ Let (param_buffer_args , expr. body, false )
445471 )
446472 end
447473 end
0 commit comments