@@ -183,6 +183,8 @@ to_index(::MyIndexStyle, axis, arg) = ...
183
183
"""
184
184
to_index (x, i:: Slice ) = i
185
185
to_index (x, :: Colon ) = indices (x)
186
+ to_index (:: LinearIndices{0,Tuple{}} , :: Colon ) = Slice (static (1 ): static (1 ))
187
+ to_index (:: CartesianIndices{0,Tuple{}} , :: Colon ) = Slice (static (1 ): static (1 ))
186
188
# logical indexing
187
189
to_index (x, i:: AbstractArray{Bool} ) = LogicalIndex (i)
188
190
to_index (x:: LinearIndices , i:: AbstractArray{Bool} ) = LogicalIndex {Int} (i)
@@ -251,10 +253,10 @@ indices calling [`to_axis`](@ref).
251
253
end
252
254
end
253
255
# drop this dimension
254
- to_axes (A, a:: Tuple , i:: Tuple{<:CanonicalInt,Vararg{Any}} ) = to_axes (A, tail (a), tail (i))
256
+ to_axes (A, a:: Tuple , i:: Tuple{<:CanonicalInt,Vararg{Any}} ) = to_axes (A, _maybe_tail (a), tail (i))
255
257
to_axes (A, a:: Tuple , i:: Tuple{I,Vararg{Any}} ) where {I} = _to_axes (StaticInt (ndims_index (I)), A, a, i)
256
258
function _to_axes (:: StaticInt{1} , A, axs:: Tuple , inds:: Tuple )
257
- return (to_axis (first (axs), first (inds)), to_axes (A, tail (axs), tail (inds))... )
259
+ return (to_axis (_maybe_first (axs), first (inds)), to_axes (A, _maybe_tail (axs), tail (inds))... )
258
260
end
259
261
@propagate_inbounds function _to_axes (:: StaticInt{N} , A, axs:: Tuple , inds:: Tuple ) where {N}
260
262
axes_front, axes_tail = Base. IteratorsMD. split (axs, Val (N))
268
270
to_axes (A, :: Tuple{Ax,Vararg{Any}} , :: Tuple{} ) where {Ax} = ()
269
271
to_axes (A, :: Tuple{} , :: Tuple{} ) = ()
270
272
273
+ _maybe_first (:: Tuple{} ) = static (1 ): static (1 )
274
+ _maybe_first (t:: Tuple ) = first (t)
275
+ _maybe_tail (:: Tuple{} ) = ()
276
+ _maybe_tail (t:: Tuple ) = tail (t)
277
+
271
278
"""
272
279
to_axis(old_axis, index) -> new_axis
273
280
@@ -349,7 +356,9 @@ unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))
349
356
end
350
357
351
358
unsafe_getindex (A:: LinearIndices , i:: CanonicalInt ) = Int (i)
352
- unsafe_getindex (A:: CartesianIndices , i:: CanonicalInt , ii:: Vararg{CanonicalInt} ) = CartesianIndex (i, ii... )
359
+ unsafe_getindex (A:: CartesianIndices{N} , ii:: Vararg{CanonicalInt,N} ) where {N} = CartesianIndex (ii... )
360
+ unsafe_getindex (A:: CartesianIndices , ii:: Vararg{CanonicalInt} ) =
361
+ unsafe_getindex (A, Base. front (ii)... )
353
362
unsafe_getindex (A:: CartesianIndices , i:: CanonicalInt ) = @inbounds (A[i])
354
363
355
364
unsafe_getindex (A:: ReshapedArray , i:: CanonicalInt ) = @inbounds (parent (A)[i])
@@ -378,18 +387,37 @@ function unsafe_get_collection(A, inds)
378
387
end
379
388
_ints2range (x:: CanonicalInt ) = x: x
380
389
_ints2range (x:: AbstractRange ) = x
390
+ # apply _ints2range to front N elements
391
+ _ints2range_front (:: Val{N} , ind, inds... ) where {N} =
392
+ (_ints2range (ind), _ints2range_front (Val (N - 1 ), inds... )... )
393
+ _ints2range_front (:: Val{0} , ind, inds... ) = ()
394
+ _ints2range_front (:: Val{0} ) = ()
395
+ # get output shape with given indices
396
+ _output_shape (:: CanonicalInt , inds... ) = _output_shape (inds... )
397
+ _output_shape (ind:: AbstractRange , inds... ) = (length (ind), _output_shape (inds... )... )
398
+ _output_shape (:: CanonicalInt ) = ()
399
+ _output_shape (x:: AbstractRange ) = (length (x),)
381
400
@inline function unsafe_get_collection (A:: CartesianIndices{N} , inds) where {N}
382
401
if (Base. length (inds) === 1 && N > 1 ) || stride_preserving_index (typeof (inds)) === False ()
383
402
return Base. _getindex (IndexStyle (A), A, inds... )
384
403
else
385
- return CartesianIndices (to_axes (A, _ints2range .(inds)))
404
+ return reshape (
405
+ CartesianIndices (_ints2range_front (Val (N), inds... )),
406
+ _output_shape (inds... )
407
+ )
386
408
end
387
409
end
410
+ _known_first_isone (ind) = known_first (ind) != = nothing && isone (known_first (ind))
388
411
@inline function unsafe_get_collection (A:: LinearIndices{N} , inds) where {N}
389
412
if Base. length (inds) === 1 && isone (_ndims_index (typeof (inds), static (1 )))
390
413
return @inbounds (eachindex (A)[first (inds)])
391
- elseif stride_preserving_index (typeof (inds)) === True ()
392
- return LinearIndices (to_axes (A, _ints2range .(inds)))
414
+ elseif stride_preserving_index (typeof (inds)) === True () &&
415
+ reduce_tup (& , map (_known_first_isone, inds))
416
+ # create a LinearIndices when first(ind) != 1 is imposable
417
+ return reshape (
418
+ LinearIndices (_ints2range_front (Val (N), inds... )),
419
+ _output_shape (inds... )
420
+ )
393
421
else
394
422
return Base. _getindex (IndexStyle (A), A, inds... )
395
423
end
0 commit comments