@@ -496,7 +496,119 @@ end
496496function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals)
497497 _remake_buffer (indp, oldbuf, idxs, vals)
498498end
499+
500+ # For type-inference when using `SII.setp_oop`
501+ @generated function _remake_buffer (
502+ indp, oldbuf:: MTKParameters{T, I, D, C, N, H} ,
503+ idxs:: Union{Tuple{Vararg{ParameterIndex}}, AbstractArray{<:ParameterIndex{P}}} ,
504+ vals:: Union{AbstractArray, Tuple} ; validate = true ) where {T, I, D, C, N, H, P}
505+
506+ # fallback to non-generated method if values aren't type-stable
507+ if vals <: AbstractArray && ! isconcretetype (eltype (vals))
508+ return quote
509+ $ __remake_buffer (indp, oldbuf, collect (idxs), vals; validate)
510+ end
511+ end
512+
513+ # given an index in idxs/vals and the current `eltype` of the buffer,
514+ # return the promoted eltype of the buffer
515+ function promote_valtype (i, valT)
516+ # tuples have distinct types, arrays have a common eltype
517+ valT′ = vals <: AbstractArray ? eltype (vals) : fieldtype (vals, i)
518+ # if the buffer is a scalarized buffer but the variable is an array
519+ # e.g. an array tunable, take the eltype
520+ if valT′ <: AbstractArray && ! (valT <: AbstractArray )
521+ valT′ = eltype (valT′)
522+ end
523+ return promote_type (valT, valT′)
524+ end
525+
526+ # types of the idxs
527+ idxtypes = if idxs <: AbstractArray
528+ # if both are arrays, there is only one possible type to check
529+ if vals <: AbstractArray
530+ (eltype (idxs),)
531+ else
532+ # if `vals` is a tuple, we repeat `eltype(idxs)` to check against
533+ # every possible type of the buffer
534+ ntuple (Returns (eltype (idxs)), Val (fieldcount (vals)))
535+ end
536+ else
537+ # `idxs` is a tuple, so we check against all buffers
538+ fieldtypes (idxs)
539+ end
540+ # promote types
541+ tunablesT = eltype (T)
542+ for (i, idxT) in enumerate (idxtypes)
543+ idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
544+ tunablesT = promote_valtype (i, tunablesT)
545+ end
546+ initialsT = eltype (I)
547+ for (i, idxT) in enumerate (idxtypes)
548+ idxT <: ParameterIndex{SciMLStructures.Initials} || continue
549+ initialsT = promote_valtype (i, initialsT)
550+ end
551+ discretesT = ntuple (Val (fieldcount (D))) do i
552+ bufT = eltype (fieldtype (D, i))
553+ for (j, idxT) in enumerate (idxtypes)
554+ idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
555+ bufT = promote_valtype (i, bufT)
556+ end
557+ bufT
558+ end
559+ constantsT = ntuple (Val (fieldcount (C))) do i
560+ bufT = eltype (fieldtype (C, i))
561+ for (j, idxT) in enumerate (idxtypes)
562+ idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
563+ bufT = promote_valtype (i, bufT)
564+ end
565+ bufT
566+ end
567+ nonnumericT = ntuple (Val (fieldcount (N))) do i
568+ bufT = eltype (fieldtype (N, i))
569+ for (j, idxT) in enumerate (idxtypes)
570+ idxT <: ParameterIndex{Nonnumeric, i} || continue
571+ bufT = promote_valtype (i, bufT)
572+ end
573+ bufT
574+ end
575+
576+ expr = quote
577+ tunables = $ similar (oldbuf. tunable, $ tunablesT)
578+ copyto! (tunables, oldbuf. tunable)
579+ initials = $ similar (oldbuf. initials, $ initialsT)
580+ copyto! (initials, oldbuf. initials)
581+ discretes = $ (Expr (:tuple ,
582+ (:($ similar (oldbuf. discrete[$ i], $ (discretesT[i]))) for i in 1 : length (discretesT)). .. ))
583+ $ ((:($ copyto! (discretes[$ i], oldbuf. discrete[$ i])) for i in 1 : length (discretesT)). .. )
584+ constants = $ (Expr (:tuple ,
585+ (:($ similar (oldbuf. constant[$ i], $ (constantsT[i]))) for i in 1 : length (constantsT)). .. ))
586+ $ ((:($ copyto! (constants[$ i], oldbuf. constant[$ i])) for i in 1 : length (constantsT)). .. )
587+ nonnumerics = $ (Expr (:tuple ,
588+ (:($ similar (oldbuf. nonnumeric[$ i], $ (nonnumericT[i]))) for i in 1 : length (nonnumericT)). .. ))
589+ $ ((:($ copyto! (nonnumerics[$ i], oldbuf. nonnumeric[$ i])) for i in 1 : length (nonnumericT)). .. )
590+ newbuf = MTKParameters (
591+ tunables, initials, discretes, constants, nonnumerics, copy .(oldbuf. caches))
592+ end
593+ if idxs <: AbstractArray
594+ push! (expr. args, :(for (idx, val) in zip (idxs, vals)
595+ $ setindex! (newbuf, val, idx)
596+ end ))
597+ else
598+ for i in 1 : fieldcount (idxs)
599+ push! (expr. args, :($ setindex! (newbuf, vals[$ i], idxs[$ i])))
600+ end
601+ end
602+ push! (expr. args, :(return newbuf))
603+
604+ return expr
605+ end
606+
499607function _remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals; validate = true )
608+ return __remake_buffer (indp, oldbuf, idxs, vals; validate)
609+ end
610+
611+ function __remake_buffer (indp, oldbuf:: MTKParameters , idxs, vals; validate = true )
500612 newbuf = @set oldbuf. tunable = similar (oldbuf. tunable, Any)
501613 @set! newbuf. initials = similar (oldbuf. initials, Any)
502614 @set! newbuf. discrete = Tuple (similar (buf, Any) for buf in newbuf. discrete)
0 commit comments