@@ -22,16 +22,6 @@ function Base.last(x::AbstractVector, n::StaticInt)
22
22
@inbounds x[max (offset1 (x), (stop + one (stop)) - n): stop]
23
23
end
24
24
25
- function _is_splat (:: Type{I} , i:: StaticInt ) where {I}
26
- if dynamic (is_splat_index (field_type (I, i)))
27
- True ()
28
- else
29
- False ()
30
- end
31
- end
32
-
33
- _ndims_index (:: Type{I} , i:: StaticInt ) where {I} = StaticInt (ndims_index (field_type (I, i)))
34
-
35
25
"""
36
26
to_indices(A, I::Tuple) -> Tuple
37
27
@@ -91,69 +81,92 @@ This implementation differs from that of `Base.to_indices` in the following ways
91
81
"""
92
82
to_indices (A, :: Tuple{} ) = ()
93
83
@inline function to_indices (a:: A , inds:: I ) where {A,I}
94
- _to_indices (
95
- a,
96
- inds,
97
- IndexStyle (A),
98
- static (ndims (A)),
99
- eachop (_ndims_index, ntuple (static, StaticInt (known_length (I))), I),
100
- eachop (_is_splat, ntuple (static, StaticInt (known_length (I))), I)
101
- )
102
- end
103
- @generated function _to_indices (A, inds:: I , :: S , :: StaticInt{N} , :: NDI , :: IS ) where {I,S,N,NDI,IS}
104
- cnt = zeros (Int, known_length (NDI))
105
- splat_position = 0
106
- remaining = N
107
- for i in 1 : known_length (NDI)
108
- ndi = known (NDI. parameters[i])
109
- splat = known (IS. parameters[i])
110
- if splat && splat_position === 0
111
- splat_position = i
112
- else
113
- remaining -= ndi
114
- cnt[i] = ndi
115
- end
116
- end
117
- if splat_position != = 0
118
- cnt[splat_position] = max (0 , remaining)
84
+ _to_indices (a, inds, IndexStyle (A), static (ndims (A)), IndicesInfo (I))
85
+ end
86
+ @generated function _to_indices (a, inds, :: S , :: StaticInt{N} , :: IndicesInfo{NI,NS,IS} ) where {S,N,NI,NS,IS}
87
+ _to_indices_expr (S, N, NI, NS, IS)
88
+ end
89
+ function _to_indices_expr (S:: DataType , N:: Int , ni, ns, is)
90
+ blk = Expr (:block , Expr (:meta , :inline ))
91
+ # check to see if we are dealing with linear indexing over a multidimensional array
92
+ if length (ni) == 1 && ni[1 ] === 1
93
+ push! (blk. args, :((to_index (LazyAxis {:} (a), getfield (inds, 1 )),)))
119
94
else
120
- # if there are additional trailing dimensions not consumed by the index then we have
121
- # to assume it's linear indexing or that these are trailing dimensions.
122
- cnt[end ] += max (0 , remaining)
123
- end
95
+ indsexpr = Expr (:tuple )
96
+ ndi = Int[]
97
+ nds = Int[]
98
+ isi = Bool[]
99
+ # 1. unwrap AbstractCartesianIndex, CartesianIndices, Indices
100
+ for i in 1 : length (ns)
101
+ ns_i = ns[i]
102
+ if ns_i isa Tuple
103
+ for j in 1 : length (ns_i)
104
+ push! (ndi, 1 )
105
+ push! (nds, ns_i[j])
106
+ push! (isi, false )
107
+ push! (indsexpr. args, :(getfield (getfield (getfield (inds, $ i), 1 ), $ j)))
108
+ end
109
+ else
110
+ push! (indsexpr. args, :(getfield (inds, $ i)))
111
+ push! (ndi, ni[i])
112
+ push! (nds, ns_i)
113
+ push! (isi, is[i])
114
+ end
115
+ end
124
116
125
- t = Expr (:tuple )
126
- dim = 0
127
- for i in 1 : known_length (NDI)
128
- if i === known_length (NDI) && S <: IndexLinear
129
- ICall = :LinearIndices
130
- else
131
- ICall = :CartesianIndices
117
+ # 2. find splat indices
118
+ splat_position = 0
119
+ remaining = N
120
+ for i in eachindex (ndi, nds, isi)
121
+ if isi[i] && splat_position == 0
122
+ splat_position = i
123
+ else
124
+ remaining -= ndi[i]
125
+ end
132
126
end
133
- c = cnt[i]
134
- iexpr = :(@inbounds (getfield (inds, $ i)):: $ (I. parameters[i]))
135
- if dim === N
136
- push! (t. args, :(to_index ($ (ICall)(()), $ iexpr)))
137
- elseif c === 1
138
- dim += 1
139
- push! (t. args, :(to_index (@inbounds (getfield (axs, $ dim)), $ iexpr)))
140
- else
141
- subaxs = Expr (:tuple )
142
- for _ in 1 : c
143
- if dim < N
127
+ if splat_position != = 0
128
+ for _ in 2 : remaining
129
+ insert! (ndi, splat_position, 1 )
130
+ insert! (nds, splat_position, 1 )
131
+ insert! (indsexpr. args, splat_position, indsexpr. args[splat_position])
132
+ end
133
+ end
134
+
135
+ # 3. insert `to_index` calls
136
+ dim = 0
137
+ nndi = length (ndi)
138
+ for i in 1 : nndi
139
+ ndi_i = ndi[i]
140
+ if ndi_i == 1
141
+ dim += 1
142
+ indsexpr. args[i] = :(to_index ($ (_axis_expr (N, dim)), $ (indsexpr. args[i])))
143
+ else
144
+ subaxs = Expr (:tuple )
145
+ for _ in 1 : ndi_i
144
146
dim += 1
145
- push! (subaxs. args, :(@inbounds (getfield (axs, $ dim))))
147
+ push! (subaxs. args, _axis_expr (N, dim))
148
+ end
149
+ if i == nndi && S <: IndexLinear
150
+ indsexpr. args[i] = :(to_index (LinearIndices ($ (subaxs)), $ (indsexpr. args[i])))
151
+ else
152
+ indsexpr. args[i] = :(to_index (CartesianIndices ($ (subaxs)), $ (indsexpr. args[i])))
146
153
end
147
154
end
148
- push! (t. args, :(to_index ($ (ICall)($ subaxs), $ iexpr)))
149
155
end
156
+ push! (blk. args, Expr (:(= ), :axs , :(lazy_axes (a))))
157
+ push! (blk. args, :(_flatten_tuples ($ (indsexpr))))
158
+ end
159
+ return blk
160
+ end
161
+
162
+ function _axis_expr (N:: Int , d:: Int )
163
+ if d <= N
164
+ :(getfield (axs, $ d))
165
+ else # ndims(a)+ can only have indices 1:1
166
+ :($ (SOneTo (1 )))
150
167
end
151
- Expr (:block ,
152
- Expr (:meta , :inline ),
153
- Expr (:(= ), :axs , :(lazy_axes (A))),
154
- :(_flatten_tuples ($ t))
155
- )
156
168
end
169
+
157
170
@generated function _flatten_tuples (inds:: I ) where {I}
158
171
t = Expr (:tuple )
159
172
for i in 1 : known_length (I)
@@ -409,7 +422,7 @@ _output_shape(x::AbstractRange) = (Base.length(x),)
409
422
end
410
423
_known_first_isone (ind) = known_first (ind) != = nothing && isone (known_first (ind))
411
424
@inline function unsafe_get_collection (A:: LinearIndices{N} , inds) where {N}
412
- if Base. length (inds) === 1 && isone ( _ndims_index ( typeof (inds), static ( 1 )))
425
+ if Base. length (inds) === 1 && ndims_index ( typeof (first ( inds))) === 1
413
426
return @inbounds (eachindex (A)[first (inds)])
414
427
elseif stride_preserving_index (typeof (inds)) === True () &&
415
428
reduce_tup (& , map (_known_first_isone, inds))
@@ -464,7 +477,6 @@ function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) wh
464
477
end
465
478
end
466
479
467
-
468
480
function unsafe_setindex! (A:: Array{T} , v) where {T}
469
481
Base. arrayset (false , A, convert (T, v):: T , 1 )
470
482
end
0 commit comments