Skip to content

Commit 7dabbf9

Browse files
feat: add @generated method for _remake_buffer for type-inference
1 parent bce9beb commit 7dabbf9

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

src/systems/parameter_buffer.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,66 @@ end
496496
function SymbolicIndexingInterface.remake_buffer(indp, oldbuf::MTKParameters, idxs, vals)
497497
_remake_buffer(indp, oldbuf, idxs, vals)
498498
end
499+
500+
# For type-inference when using `SII.setp_oop`
501+
@generated function _remake_buffer(indp, oldbuf::MTKParameters{T, I, D, C, N, H}, idxs::Tuple{Vararg{ParameterIndex}}, vals::Union{AbstractArray, Tuple}; validate = true) where {T, I, D, C, N, H}
502+
valtype(i) = vals <: AbstractArray ? eltype(vals) : fieldtype(vals, i)
503+
tunablesT = eltype(T)
504+
for (i, idxT) in enumerate(fieldtypes(idxs))
505+
idxT <: ParameterIndex{SciMLStructures.Tunable} || continue
506+
tunablesT = promote_type(tunablesT, valtype(i))
507+
end
508+
initialsT = eltype(I)
509+
for (i, idxT) in enumerate(fieldtypes(idxs))
510+
idxT <: ParameterIndex{SciMLStructures.Initials} || continue
511+
initialsT = promote_type(initialsT, valtype(i))
512+
end
513+
discretesT = ntuple(Val(fieldcount(D))) do i
514+
bufT = eltype(fieldtype(D, i))
515+
for (j, idxT) in enumerate(fieldtypes(idxs))
516+
idxT <: ParameterIndex{SciMLStructures.Discrete, i} || continue
517+
bufT = promote_type(bufT, valtype(i))
518+
end
519+
bufT
520+
end
521+
constantsT = ntuple(Val(fieldcount(C))) do i
522+
bufT = eltype(fieldtype(C, i))
523+
for (j, idxT) in enumerate(fieldtypes(idxs))
524+
idxT <: ParameterIndex{SciMLStructures.Constants, i} || continue
525+
bufT = promote_type(bufT, valtype(i))
526+
end
527+
bufT
528+
end
529+
nonnumericT = ntuple(Val(fieldcount(N))) do i
530+
bufT = eltype(fieldtype(N, i))
531+
for (j, idxT) in enumerate(fieldtypes(idxs))
532+
idxT <: ParameterIndex{Nonnumeric, i} || continue
533+
bufT = promote_type(bufT, valtype(i))
534+
end
535+
bufT
536+
end
537+
538+
expr = quote
539+
tunables = $similar(oldbuf.tunable, $tunablesT)
540+
copyto!(tunables, oldbuf.tunable)
541+
initials = $similar(oldbuf.initials, $initialsT)
542+
copyto!(initials, oldbuf.initials)
543+
discretes = $(Expr(:tuple, (:($similar(oldbuf.discrete[$i], $(discretesT[i]))) for i in 1:length(discretesT))...))
544+
$((:($copyto!(discretes[$i], oldbuf.discrete[$i])) for i in 1:length(discretesT))...)
545+
constants = $(Expr(:tuple, (:($similar(oldbuf.constant[$i], $(constantsT[i]))) for i in 1:length(constantsT))...))
546+
$((:($copyto!(constants[$i], oldbuf.constant[$i])) for i in 1:length(constantsT))...)
547+
nonnumerics = $(Expr(:tuple, (:($similar(oldbuf.nonnumeric[$i], $(nonnumericT[i]))) for i in 1:length(nonnumericT))...))
548+
$((:($copyto!(nonnumerics[$i], oldbuf.nonnumeric[$i])) for i in 1:length(nonnumericT))...)
549+
newbuf = MTKParameters(tunables, initials, discretes, constants, nonnumerics, copy.(oldbuf.caches))
550+
end
551+
for i in 1:fieldcount(idxs)
552+
push!(expr.args, :($setindex!(newbuf, vals[$i], idxs[$i])))
553+
end
554+
push!(expr.args, :(return newbuf))
555+
556+
return expr
557+
end
558+
499559
function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true)
500560
newbuf = @set oldbuf.tunable = similar(oldbuf.tunable, Any)
501561
@set! newbuf.initials = similar(oldbuf.initials, Any)

0 commit comments

Comments
 (0)