@@ -230,3 +230,59 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
230230 end
231231 return build_function (expr, args... ; wrap_code, similarto, kwargs... )
232232end
233+
234+ """
235+ $(TYPEDEF)
236+
237+ A wrapper around a generated in-place and out-of-place function. The type-parameter `P`
238+ must be a 3-tuple where the first element is the index of the parameter object in the
239+ arguments, the second is the expected number of arguments in the out-of-place variant
240+ of the function, and the third is a boolean indicating whether the generated functions
241+ are for a split system. For scalar functions, the inplace variant can be `nothing`.
242+ """
243+ struct GeneratedFunctionWrapper{P, O, I} <: Function
244+ f_oop:: O
245+ f_iip:: I
246+ end
247+
248+ function GeneratedFunctionWrapper {P} (foop:: O , fiip:: I ) where {P, O, I}
249+ GeneratedFunctionWrapper {P, O, I} (foop, fiip)
250+ end
251+
252+ function (gfw:: GeneratedFunctionWrapper )(args... )
253+ _generated_call (gfw, args... )
254+ end
255+
256+ @generated function _generated_call (gfw:: GeneratedFunctionWrapper{P} , args... ) where {P}
257+ paramidx, nargs, issplit = P
258+ iip = false
259+ # IIP case has one more argument
260+ if length (args) == nargs + 1
261+ nargs += 1
262+ paramidx += 1
263+ iip = true
264+ end
265+ if length (args) != nargs
266+ throw (ArgumentError (" Expected $nargs arguments, got $(length (args)) ." ))
267+ end
268+
269+ # the function to use
270+ f = iip ? :(gfw. f_iip) : :(gfw. f_oop)
271+ # non-split systems just call it as-is
272+ if ! issplit
273+ return :($ f (args... ))
274+ end
275+ if args[paramidx] <: Union{Tuple, MTKParameters} &&
276+ ! (args[paramidx] <: Tuple{Vararg{Number}} )
277+ # for split systems, call it as-is if the parameter object is a tuple or MTKParameters
278+ # but not if it is a tuple of numbers
279+ return :($ f (args... ))
280+ else
281+ # The user provided a single buffer/tuple for the parameter object, so wrap that
282+ # one in a tuple
283+ fargs = ntuple (Val (length (args))) do i
284+ i == paramidx ? :((args[$ i],)) : :(args[$ i])
285+ end
286+ return :($ f ($ (fargs... )))
287+ end
288+ end
0 commit comments