Skip to content

Commit 2b23652

Browse files
refactor: better getindex and length for MTKParameters
1 parent f833282 commit 2b23652

File tree

1 file changed

+36
-34
lines changed

1 file changed

+36
-34
lines changed

src/systems/parameter_buffer.jl

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ function MTKParameters(
159159
end
160160
tunable_buffer = narrow_buffer_type(tunable_buffer)
161161
if isempty(tunable_buffer)
162-
tunable_buffer = Float64[]
162+
tunable_buffer = SizedVector{0, Float64}()
163163
end
164164
disc_buffer = broadcast.(narrow_buffer_type, disc_buffer)
165165
const_buffer = narrow_buffer_type.(const_buffer)
@@ -668,54 +668,56 @@ function DiffEqBase.anyeltypedual(p::Type{<:MTKParameters{T}},
668668
DiffEqBase.__anyeltypedual(T)
669669
end
670670

671-
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
672-
_subarrays(v::ArrayPartition) = v.x
673-
_subarrays(v::Tuple) = v
674-
_num_subarrays(v::AbstractVector) = 1
675-
_num_subarrays(v::ArrayPartition) = length(v.x)
676-
_num_subarrays(v::Tuple) = length(v)
677671
# for compiling callbacks
678672
# getindex indexes the vectors, setindex! linearly indexes values
679673
# it's inconsistent, but we need it to be this way
680-
function Base.getindex(buf::MTKParameters, i)
681-
i_orig = i
682-
if !isempty(buf.tunable)
683-
i == 1 && return buf.tunable
684-
i -= 1
685-
end
686-
if !isempty(buf.discrete)
687-
for clockbuf in buf.discrete
688-
i <= _num_subarrays(clockbuf) && return _subarrays(clockbuf)[i]
689-
i -= _num_subarrays(clockbuf)
674+
@generated function Base.getindex(
675+
ps::MTKParameters{T, D, C, E, N}, idx::Int) where {T, D, C, E, N}
676+
paths = []
677+
if !(T <: SizedVector{0, Float64})
678+
push!(paths, :(ps.tunable))
679+
end
680+
for i in 1:length(D)
681+
for j in 1:fieldcount(eltype(D))
682+
push!(paths, :(ps.discrete[$i][$j]))
690683
end
691684
end
692-
if !isempty(buf.constant)
693-
i <= _num_subarrays(buf.constant) && return _subarrays(buf.constant)[i]
694-
i -= _num_subarrays(buf.constant)
685+
for i in 1:fieldcount(C)
686+
push!(paths, :(ps.constant[$i]))
695687
end
696-
if !isempty(buf.nonnumeric)
697-
i <= _num_subarrays(buf.nonnumeric) && return _subarrays(buf.nonnumeric)[i]
698-
i -= _num_subarrays(buf.nonnumeric)
688+
for i in 1:fieldcount(E)
689+
push!(paths, :(ps.dependent[$i]))
699690
end
700-
if !isempty(buf.dependent)
701-
i <= _num_subarrays(buf.dependent) && return _subarrays(buf.dependent)[i]
702-
i -= _num_subarrays(buf.dependent)
691+
for i in 1:fieldcount(N)
692+
push!(paths, :(ps.nonnumeric[$i]))
703693
end
704-
throw(BoundsError(buf, i_orig))
694+
expr = Expr(:if, :(idx == 1), :(return $(paths[1])))
695+
curexpr = expr
696+
for i in 2:length(paths)
697+
push!(curexpr.args, Expr(:elseif, :(idx == $i), :(return $(paths[i]))))
698+
curexpr = curexpr.args[end]
699+
end
700+
return Expr(:block, expr, :(throw(BoundsError(ps, idx))))
701+
end
702+
703+
@generated function Base.length(ps::MTKParameters{T, D, C, E, N}) where {T, D, C, E, N}
704+
len = 0
705+
if !(T <: SizedVector{0, Float64})
706+
len += 1
707+
end
708+
if length(D) > 0
709+
len += length(D) * fieldcount(eltype(D))
710+
end
711+
len += fieldcount(C) + fieldcount(E) + fieldcount(N)
712+
return len
705713
end
706714

707715
Base.getindex(p::MTKParameters, pind::ParameterIndex) = parameter_values(p, pind)
708716

709717
Base.setindex!(p::MTKParameters, val, pind::ParameterIndex) = set_parameter!(p, val, pind)
710718

711719
function Base.iterate(buf::MTKParameters, state = 1)
712-
total_len = Int(!isempty(buf.tunable)) # for tunables
713-
for clockbuf in buf.discrete
714-
total_len += _num_subarrays(clockbuf)
715-
end
716-
total_len += _num_subarrays(buf.constant)
717-
total_len += _num_subarrays(buf.nonnumeric)
718-
total_len += _num_subarrays(buf.dependent)
720+
total_len = length(buf)
719721
if state <= total_len
720722
return (buf[state], state + 1)
721723
else

0 commit comments

Comments
 (0)