@@ -499,40 +499,76 @@ end
499499
500500#  For type-inference when using `SII.setp_oop`
501501@generated  function  _remake_buffer (
502-         indp, oldbuf:: MTKParameters{T, I, D, C, N, H} , idxs:: Tuple{Vararg{ParameterIndex}} ,
503-         vals:: Union{AbstractArray, Tuple} ; validate =  true ) where  {T, I, D, C, N, H}
504-     valtype (i) =  vals <:  AbstractArray  ?  eltype (vals) :  fieldtype (vals, i)
502+         indp, oldbuf:: MTKParameters{T, I, D, C, N, H} ,
503+         idxs:: Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}} ,
504+         vals:: Union{AbstractArray, Tuple} ; validate =  true ) where  {T, I, D, C, N, H, P}
505+ 
506+     #  fallback to non-generated method if values aren't type-stable
507+     if  vals <:  AbstractArray  &&  ! isconcretetype (eltype (vals))
508+         return  quote 
509+             $ _remake_buffer (indp, oldbuf, collect (idxs), vals; validate)
510+         end 
511+     end 
512+ 
513+     #  given an index in idxs/vals and the current `eltype` of the buffer,
514+     #  return the promoted eltype of the buffer
515+     function  promote_valtype (i, valT)
516+         #  tuples have distinct types, arrays have a common eltype
517+         valT′ =  vals <:  AbstractArray  ?  eltype (vals) :  fieldtype (vals, i)
518+         #  if the buffer is a scalarized buffer but the variable is an array
519+         #  e.g. an array tunable, take the eltype
520+         if  valT′ <:  AbstractArray  &&  ! (valT <:  AbstractArray )
521+             valT′ =  eltype (valT′)
522+         end 
523+         return  promote_type (valT, valT′)
524+     end 
525+ 
526+     #  types of the idxs
527+     idxtypes =  if  idxs <:  AbstractArray 
528+         #  if both are arrays, there is only one possible type to check
529+         if  vals <:  AbstractArray 
530+             (eltype (idxs),)
531+         else 
532+             #  if `vals` is a tuple, we repeat `eltype(idxs)` to check against
533+             #  every possible type of the buffer
534+             ntuple (Returns (eltype (idxs)), Val (fieldcount (vals)))
535+         end 
536+     else 
537+         #  `idxs` is a tuple, so we check against all buffers
538+         fieldtypes (idxs)
539+     end 
540+     #  promote types
505541    tunablesT =  eltype (T)
506-     for  (i, idxT) in  enumerate (fieldtypes (idxs) )
542+     for  (i, idxT) in  enumerate (idxtypes )
507543        idxT <:  ParameterIndex{SciMLStructures.Tunable}  ||  continue 
508-         tunablesT =  promote_type (tunablesT,  valtype (i) )
544+         tunablesT =  promote_valtype (i, tunablesT )
509545    end 
510546    initialsT =  eltype (I)
511-     for  (i, idxT) in  enumerate (fieldtypes (idxs) )
547+     for  (i, idxT) in  enumerate (idxtypes )
512548        idxT <:  ParameterIndex{SciMLStructures.Initials}  ||  continue 
513-         initialsT =  promote_type (initialsT,  valtype (i) )
549+         initialsT =  promote_valtype (i, initialsT )
514550    end 
515551    discretesT =  ntuple (Val (fieldcount (D))) do  i
516552        bufT =  eltype (fieldtype (D, i))
517-         for  (j, idxT) in  enumerate (fieldtypes (idxs) )
553+         for  (j, idxT) in  enumerate (idxtypes )
518554            idxT <:  ParameterIndex{SciMLStructures.Discrete, i}  ||  continue 
519-             bufT =  promote_type (bufT,  valtype (i) )
555+             bufT =  promote_valtype (i, bufT )
520556        end 
521557        bufT
522558    end 
523559    constantsT =  ntuple (Val (fieldcount (C))) do  i
524560        bufT =  eltype (fieldtype (C, i))
525-         for  (j, idxT) in  enumerate (fieldtypes (idxs) )
561+         for  (j, idxT) in  enumerate (idxtypes )
526562            idxT <:  ParameterIndex{SciMLStructures.Constants, i}  ||  continue 
527-             bufT =  promote_type (bufT,  valtype (i) )
563+             bufT =  promote_valtype (i, bufT )
528564        end 
529565        bufT
530566    end 
531567    nonnumericT =  ntuple (Val (fieldcount (N))) do  i
532568        bufT =  eltype (fieldtype (N, i))
533-         for  (j, idxT) in  enumerate (fieldtypes (idxs) )
569+         for  (j, idxT) in  enumerate (idxtypes )
534570            idxT <:  ParameterIndex{Nonnumeric, i}  ||  continue 
535-             bufT =  promote_type (bufT,  valtype (i) )
571+             bufT =  promote_valtype (i, bufT )
536572        end 
537573        bufT
538574    end 
554590        newbuf =  MTKParameters (
555591            tunables, initials, discretes, constants, nonnumerics, copy .(oldbuf. caches))
556592    end 
557-     for  i in  1 : fieldcount (idxs)
558-         push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
593+     if  idxs <:  AbstractArray 
594+         push! (expr. args, :(for  (idx, val) in  zip (idxs, vals)
595+             $ setindex! (newbuf, val, idx)
596+         end ))
597+     else 
598+         for  i in  1 : fieldcount (idxs)
599+             push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
600+         end 
559601    end 
560602    push! (expr. args, :(return  newbuf))
561603
0 commit comments