@@ -499,40 +499,76 @@ end
499499
500500# For type-inference when using `SII.setp_oop`
501501@generated function _remake_buffer (
502- indp, oldbuf:: MTKParameters{T, I, D, C, N, H} , idxs:: Tuple{Vararg{ParameterIndex}} ,
503- vals:: Union{AbstractArray, Tuple} ; validate = true ) where {T, I, D, C, N, H}
504- valtype (i) = vals <: AbstractArray ? eltype (vals) : fieldtype (vals, i)
502+ indp, oldbuf:: MTKParameters{T, I, D, C, N, H} ,
503+ idxs:: Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}} ,
504+ vals:: Union{AbstractArray, Tuple} ; validate = true ) where {T, I, D, C, N, H, P}
505+
506+ # fallback to non-generated method if values aren't type-stable
507+ if vals <: AbstractArray && ! isconcretetype (eltype (vals))
508+ return quote
509+ $ _remake_buffer (indp, oldbuf, collect (idxs), vals; validate)
510+ end
511+ end
512+
513+ # given an index in idxs/vals and the current `eltype` of the buffer,
514+ # return the promoted eltype of the buffer
515+ function promote_valtype (i, valT)
516+ # tuples have distinct types, arrays have a common eltype
517+ valT′ = vals <: AbstractArray ? eltype (vals) : fieldtype (vals, i)
518+ # if the buffer is a scalarized buffer but the variable is an array
519+ # e.g. an array tunable, take the eltype
520+ if valT′ <: AbstractArray && ! (valT <: AbstractArray )
521+ valT′ = eltype (valT′)
522+ end
523+ return promote_type (valT, valT′)
524+ end
525+
526+ # types of the idxs
527+ idxtypes = if idxs <: AbstractArray
528+ # if both are arrays, there is only one possible type to check
529+ if vals <: AbstractArray
530+ (eltype (idxs),)
531+ else
532+ # if `vals` is a tuple, we repeat `eltype(idxs)` to check against
533+ # every possible type of the buffer
534+ ntuple (Returns (eltype (idxs)), Val (fieldcount (vals)))
535+ end
536+ else
537+ # `idxs` is a tuple, so we check against all buffers
538+ fieldtypes (idxs)
539+ end
540+ # promote types
505541 tunablesT = eltype (T)
506- for (i, idxT) in enumerate (fieldtypes (idxs) )
542+ for (i, idxT) in enumerate (idxtypes )
507543 idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
508- tunablesT = promote_type (tunablesT, valtype (i) )
544+ tunablesT = promote_valtype (i, tunablesT )
509545 end
510546 initialsT = eltype (I)
511- for (i, idxT) in enumerate (fieldtypes (idxs) )
547+ for (i, idxT) in enumerate (idxtypes )
512548 idxT <: ParameterIndex{SciMLStructures.Initials} || continue
513- initialsT = promote_type (initialsT, valtype (i) )
549+ initialsT = promote_valtype (i, initialsT )
514550 end
515551 discretesT = ntuple (Val (fieldcount (D))) do i
516552 bufT = eltype (fieldtype (D, i))
517- for (j, idxT) in enumerate (fieldtypes (idxs) )
553+ for (j, idxT) in enumerate (idxtypes )
518554 idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
519- bufT = promote_type (bufT, valtype (i) )
555+ bufT = promote_valtype (i, bufT )
520556 end
521557 bufT
522558 end
523559 constantsT = ntuple (Val (fieldcount (C))) do i
524560 bufT = eltype (fieldtype (C, i))
525- for (j, idxT) in enumerate (fieldtypes (idxs) )
561+ for (j, idxT) in enumerate (idxtypes )
526562 idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
527- bufT = promote_type (bufT, valtype (i) )
563+ bufT = promote_valtype (i, bufT )
528564 end
529565 bufT
530566 end
531567 nonnumericT = ntuple (Val (fieldcount (N))) do i
532568 bufT = eltype (fieldtype (N, i))
533- for (j, idxT) in enumerate (fieldtypes (idxs) )
569+ for (j, idxT) in enumerate (idxtypes )
534570 idxT <: ParameterIndex{Nonnumeric, i} || continue
535- bufT = promote_type (bufT, valtype (i) )
571+ bufT = promote_valtype (i, bufT )
536572 end
537573 bufT
538574 end
554590 newbuf = MTKParameters (
555591 tunables, initials, discretes, constants, nonnumerics, copy .(oldbuf. caches))
556592 end
557- for i in 1 : fieldcount (idxs)
558- push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
593+ if idxs <: AbstractArray
594+ push! (expr. args, :(for (idx, val) in zip (idxs, vals)
595+ $ setindex! (newbuf, val, idx)
596+ end ))
597+ else
598+ for i in 1 : fieldcount (idxs)
599+ push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
600+ end
559601 end
560602 push! (expr. args, :(return newbuf))
561603
0 commit comments