@@ -159,7 +159,7 @@ function MTKParameters(
159
159
end
160
160
tunable_buffer = narrow_buffer_type (tunable_buffer)
161
161
if isempty (tunable_buffer)
162
- tunable_buffer = Float64[]
162
+ tunable_buffer = SizedVector {0, Float64} ()
163
163
end
164
164
disc_buffer = broadcast .(narrow_buffer_type, disc_buffer)
165
165
const_buffer = narrow_buffer_type .(const_buffer)
@@ -668,54 +668,56 @@ function DiffEqBase.anyeltypedual(p::Type{<:MTKParameters{T}},
668
668
DiffEqBase. __anyeltypedual (T)
669
669
end
670
670
671
- _subarrays (v:: AbstractVector ) = isempty (v) ? () : (v,)
672
- _subarrays (v:: ArrayPartition ) = v. x
673
- _subarrays (v:: Tuple ) = v
674
- _num_subarrays (v:: AbstractVector ) = 1
675
- _num_subarrays (v:: ArrayPartition ) = length (v. x)
676
- _num_subarrays (v:: Tuple ) = length (v)
677
671
# for compiling callbacks
678
672
# getindex indexes the vectors, setindex! linearly indexes values
679
673
# it's inconsistent, but we need it to be this way
680
- function Base. getindex (buf:: MTKParameters , i)
681
- i_orig = i
682
- if ! isempty (buf. tunable)
683
- i == 1 && return buf. tunable
684
- i -= 1
685
- end
686
- if ! isempty (buf. discrete)
687
- for clockbuf in buf. discrete
688
- i <= _num_subarrays (clockbuf) && return _subarrays (clockbuf)[i]
689
- i -= _num_subarrays (clockbuf)
674
+ @generated function Base. getindex (
675
+ ps:: MTKParameters{T, D, C, E, N} , idx:: Int ) where {T, D, C, E, N}
676
+ paths = []
677
+ if ! (T <: SizedVector{0, Float64} )
678
+ push! (paths, :(ps. tunable))
679
+ end
680
+ for i in 1 : length (D)
681
+ for j in 1 : fieldcount (eltype (D))
682
+ push! (paths, :(ps. discrete[$ i][$ j]))
690
683
end
691
684
end
692
- if ! isempty (buf. constant)
693
- i <= _num_subarrays (buf. constant) && return _subarrays (buf. constant)[i]
694
- i -= _num_subarrays (buf. constant)
685
+ for i in 1 : fieldcount (C)
686
+ push! (paths, :(ps. constant[$ i]))
695
687
end
696
- if ! isempty (buf. nonnumeric)
697
- i <= _num_subarrays (buf. nonnumeric) && return _subarrays (buf. nonnumeric)[i]
698
- i -= _num_subarrays (buf. nonnumeric)
688
+ for i in 1 : fieldcount (E)
689
+ push! (paths, :(ps. dependent[$ i]))
699
690
end
700
- if ! isempty (buf. dependent)
701
- i <= _num_subarrays (buf. dependent) && return _subarrays (buf. dependent)[i]
702
- i -= _num_subarrays (buf. dependent)
691
+ for i in 1 : fieldcount (N)
692
+ push! (paths, :(ps. nonnumeric[$ i]))
703
693
end
704
- throw (BoundsError (buf, i_orig))
694
+ expr = Expr (:if , :(idx == 1 ), :(return $ (paths[1 ])))
695
+ curexpr = expr
696
+ for i in 2 : length (paths)
697
+ push! (curexpr. args, Expr (:elseif , :(idx == $ i), :(return $ (paths[i]))))
698
+ curexpr = curexpr. args[end ]
699
+ end
700
+ return Expr (:block , expr, :(throw (BoundsError (ps, idx))))
701
+ end
702
+
703
+ @generated function Base. length (ps:: MTKParameters{T, D, C, E, N} ) where {T, D, C, E, N}
704
+ len = 0
705
+ if ! (T <: SizedVector{0, Float64} )
706
+ len += 1
707
+ end
708
+ if length (D) > 0
709
+ len += length (D) * fieldcount (eltype (D))
710
+ end
711
+ len += fieldcount (C) + fieldcount (E) + fieldcount (N)
712
+ return len
705
713
end
706
714
707
715
Base. getindex (p:: MTKParameters , pind:: ParameterIndex ) = parameter_values (p, pind)
708
716
709
717
Base. setindex! (p:: MTKParameters , val, pind:: ParameterIndex ) = set_parameter! (p, val, pind)
710
718
711
719
function Base. iterate (buf:: MTKParameters , state = 1 )
712
- total_len = Int (! isempty (buf. tunable)) # for tunables
713
- for clockbuf in buf. discrete
714
- total_len += _num_subarrays (clockbuf)
715
- end
716
- total_len += _num_subarrays (buf. constant)
717
- total_len += _num_subarrays (buf. nonnumeric)
718
- total_len += _num_subarrays (buf. dependent)
720
+ total_len = length (buf)
719
721
if state <= total_len
720
722
return (buf[state], state + 1 )
721
723
else
0 commit comments