1
1
2
2
"""
3
- argdims(::IndexStyle, ::Type{T})
3
+ ArrayStyle(::Type{A})
4
+
5
+ Used to customize the meaning of indexing arguments in the context of a given array `A`.
6
+
7
+ See also: [`argdims`](@ref), [`UnsafeIndex`](@ref)
8
+ """
9
+ abstract type ArrayStyle end
10
+
11
+ struct DefaultArrayStyle <: ArrayStyle end
12
+
13
+ ArrayStyle (A) = ArrayStyle (typeof (A))
14
+ ArrayStyle (:: Type{A} ) where {A} = DefaultArrayStyle ()
15
+
16
+ """
17
+ argdims(::ArrayStyle, ::Type{T})
4
18
5
19
Whats the dimensionality of the indexing argument of type `T`?
6
20
"""
7
- argdims (A, x) = argdims (IndexStyle (A), typeof (x))
8
- argdims (s:: IndexStyle , x) = argdims (s, typeof (x))
21
+ argdims (x, arg) = argdims (x, typeof (arg))
22
+ argdims (x, :: Type{T} ) where {T} = argdims (ArrayStyle (x), T)
23
+ argdims (s:: ArrayStyle , arg) = argdims (s, typeof (arg))
9
24
# single elements initially map to 1 dimension but that dimension is subsequently dropped.
10
- argdims (:: IndexStyle , :: Type{T} ) where {T} = 0
11
- argdims (:: IndexStyle , :: Type{T} ) where {T<: Colon } = 1
12
- argdims (:: IndexStyle , :: Type{T} ) where {T<: AbstractArray } = ndims (T)
13
- argdims (:: IndexStyle , :: Type{T} ) where {N,T<: CartesianIndex{N} } = N
14
- argdims (:: IndexStyle , :: Type{T} ) where {N,T<: AbstractArray{CartesianIndex{N}} } = N
15
- argdims (:: IndexStyle , :: Type{T} ) where {N,T<: AbstractArray{<:Any,N} } = N
16
- argdims (:: IndexStyle , :: Type{T} ) where {N,T<: LogicalIndex{<:Any,<:AbstractArray{Bool,N}} } = N
17
- @generated function argdims (s:: IndexStyle , :: Type{T} ) where {N,T<: Tuple{Vararg{<:Any,N}} }
25
+ argdims (:: ArrayStyle , :: Type{T} ) where {T} = 0
26
+ argdims (:: ArrayStyle , :: Type{T} ) where {T<: Colon } = 1
27
+ argdims (:: ArrayStyle , :: Type{T} ) where {T<: AbstractArray } = ndims (T)
28
+ argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: CartesianIndex{N} } = N
29
+ argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: AbstractArray{CartesianIndex{N}} } = N
30
+ argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: AbstractArray{<:Any,N} } = N
31
+ argdims (:: ArrayStyle , :: Type{T} ) where {N,T<: LogicalIndex{<:Any,<:AbstractArray{Bool,N}} } = N
32
+ @generated function argdims (s:: ArrayStyle , :: Type{T} ) where {N,T<: Tuple{Vararg{<:Any,N}} }
18
33
e = Expr (:tuple )
19
34
for p in T. parameters
20
35
push! (e. args, :(ArrayInterface. argdims (s, $ p)))
21
36
end
22
37
Expr (:block , Expr (:meta , :inline ), e)
23
38
end
24
39
40
+ """
41
+ UnsafeIndex(::ArrayStyle, ::Type{I})
42
+
43
+ `UnsafeIndex` controls how indices that have been bounds checked and converted to
44
+ native axes' indices are used to return the stored values of an array. For example,
45
+ if the indices at each dimension are single integers then `UnsafeIndex(array, inds)` returns
46
+ `UnsafeGetElement()`. Conversely, if any of the indices are vectors then `UnsafeGetCollection()`
47
+ is returned, indicating that a new array needs to be reconstructed. This method permits
48
+ customizing the terminal behavior of the indexing pipeline based on arguments passed
49
+ to `ArrayInterface.getindex`. New subtypes of `UnsafeIndex` should define `promote_rule`.
50
+ """
51
+ abstract type UnsafeIndex end
52
+
53
+ struct UnsafeGetElement <: UnsafeIndex end
54
+
55
+ struct UnsafeGetCollection <: UnsafeIndex end
56
+
57
+ UnsafeIndex (x, i) = UnsafeIndex (x, typeof (i))
58
+ UnsafeIndex (x, :: Type{I} ) where {I} = UnsafeIndex (ArrayStyle (x), I)
59
+ UnsafeIndex (s:: ArrayStyle , i) = UnsafeIndex (s, typeof (i))
60
+ UnsafeIndex (:: ArrayStyle , :: Type{I} ) where {I} = UnsafeGetElement ()
61
+ UnsafeIndex (:: ArrayStyle , :: Type{I} ) where {I<: AbstractArray } = UnsafeGetCollection ()
62
+
63
+ Base. promote_rule (:: Type{X} , :: Type{Y} ) where {X<: UnsafeIndex ,Y<: UnsafeGetElement } = X
64
+
65
+ @generated function UnsafeIndex (s:: ArrayStyle , :: Type{T} ) where {N,T<: Tuple{Vararg{<:Any,N}} }
66
+ if N === 0
67
+ return UnsafeGetElement ()
68
+ else
69
+ e = Expr (:call , promote_type)
70
+ for p in T. parameters
71
+ push! (e. args, :(typeof (ArrayInterface. UnsafeIndex (s, $ p))))
72
+ end
73
+ return Expr (:block , Expr (:meta , :inline ), Expr (:call , e))
74
+ end
75
+ end
76
+
77
+ # are the indexing arguments provided a linear collection into a multidim collection
78
+ is_linear_indexing (A, args:: Tuple{Arg} ) where {Arg} = argdims (A, Arg) < 2
79
+ is_linear_indexing (A, args:: Tuple{Arg,Vararg{Any}} ) where {Arg} = false
80
+
25
81
"""
26
82
flatten_args(A, args::Tuple{Arg,Vararg{Any}}) -> Tuple
27
83
@@ -133,27 +189,15 @@ be accomplished using `to_index(axis, arg)`.
133
189
@propagate_inbounds function to_indices (A, args:: Tuple )
134
190
if can_flatten (A, args)
135
191
return to_indices (A, flatten_args (A, args))
192
+ elseif is_linear_indexing (A, args)
193
+ return (to_index (eachindex (IndexLinear (), A), first (args)),)
136
194
else
137
195
return to_indices (A, axes (A), args)
138
196
end
139
197
end
140
- @propagate_inbounds function to_indices (A, args:: Tuple{Arg} ) where {Arg}
141
- if can_flatten (A, args)
142
- return to_indices (A, flatten_args (A, args))
143
- else
144
- if argdims (IndexStyle (A), Arg) > 1
145
- return to_indices (A, axes (A), args)
146
- else
147
- if ndims (A) === 1
148
- return (to_index (axes (A, 1 ), first (args)),)
149
- else
150
- return to_indices (A, (eachindex (A),), args)
151
- end
152
- end
153
- end
154
- end
198
+ @propagate_inbounds to_indices (A, args:: Tuple{} ) = to_indices (A, axes (A), ())
155
199
@propagate_inbounds function to_indices (A, axs:: Tuple , args:: Tuple{Arg,Vararg{Any}} ) where {Arg}
156
- N = argdims (IndexStyle (A) , Arg)
200
+ N = argdims (A , Arg)
157
201
if N > 1
158
202
axes_front, axes_tail = Base. IteratorsMD. split (axs, Val (N))
159
203
return (to_multi_index (axes_front, first (args)), to_indices (A, axes_tail, tail (args))... )
172
216
end
173
217
to_indices (A, axs:: Tuple{} , args:: Tuple{} ) = ()
174
218
219
+
220
+ _multi_check_index (axs:: Tuple , arg) = _multi_check_index (axs, axes (arg))
221
+ function _multi_check_index (axs:: Tuple , arg:: AbstractArray{T} ) where {T<: CartesianIndex }
222
+ return checkindex (Bool, axs, arg)
223
+ end
224
+ _multi_check_index (:: Tuple{} , :: Tuple{} ) = true
225
+ function _multi_check_index (axs:: Tuple , args:: Tuple )
226
+ if checkindex (Bool, first (axs), first (args))
227
+ return _multi_check_index (tail (axs), tail (args))
228
+ else
229
+ return false
230
+ end
231
+ end
175
232
@propagate_inbounds function to_multi_index (axs:: Tuple , arg)
176
- @boundscheck if ! Base . checkbounds_indices (Bool, axs, ( arg,) )
233
+ @boundscheck if ! _multi_check_index ( axs, arg)
177
234
throw (BoundsError (axs, arg))
178
235
end
179
236
return arg
@@ -236,7 +293,6 @@ function unsafe_reconstruct(A::OneTo, data; kwargs...)
236
293
end
237
294
end
238
295
end
239
-
240
296
function unsafe_reconstruct (A:: UnitRange , data; kwargs... )
241
297
if can_change_size (A)
242
298
return typeof (A)(data)
@@ -248,7 +304,6 @@ function unsafe_reconstruct(A::UnitRange, data; kwargs...)
248
304
end
249
305
end
250
306
end
251
-
252
307
function unsafe_reconstruct (A:: OptionallyStaticUnitRange , data; kwargs... )
253
308
if can_change_size (A)
254
309
return typeof (A)(data)
@@ -260,7 +315,6 @@ function unsafe_reconstruct(A::OptionallyStaticUnitRange, data; kwargs...)
260
315
end
261
316
end
262
317
end
263
-
264
318
function unsafe_reconstruct (A:: AbstractUnitRange , data; kwargs... )
265
319
return static_first (data): static_last (data)
266
320
end
284
338
to_axes (A, :: Tuple{Ax,Vararg{Any}} , :: Tuple{} ) where {Ax} = ()
285
339
to_axes (A, :: Tuple{} , :: Tuple{} ) = ()
286
340
@propagate_inbounds function to_axes (A, axs:: Tuple{Ax,Vararg{Any}} , inds:: Tuple{I,Vararg{Any}} ) where {Ax,I}
287
- N = argdims (IndexStyle (A) , I)
341
+ N = argdims (A , I)
288
342
if N === 0
289
343
# drop this dimension
290
344
return to_axes (A, tail (axs), tail (inds))
@@ -330,53 +384,15 @@ Changing indexing based on a given argument from `args` should be done through
330
384
"""
331
385
@propagate_inbounds getindex (A, args... ) = unsafe_getindex (A, to_indices (A, args))
332
386
333
- """
334
- UnsafeIndex <: Function
335
-
336
- `UnsafeIndex` controls how indices that have been bounds checked and converted to
337
- native axes' indices are used to return the stored values of an array. For example,
338
- if the indices at each dimension are single integers than `UnsafeIndex(inds)` returns
339
- `UnsafeElement()`. Conversely, if any of the indices are vectors then `UnsafeCollection()`
340
- is returned, indicating that a new array needs to be reconstructed. This method permits
341
- customizing the terimnal behavior of the indexing pipeline based on arguments passed
342
- to `ArrayInterface.getindex`
343
- """
344
- abstract type UnsafeIndex <: Function end
345
-
346
- struct UnsafeElement <: UnsafeIndex end
347
- const unsafe_element = UnsafeElement ()
348
-
349
- struct UnsafeCollection <: UnsafeIndex end
350
- const unsafe_collection = UnsafeCollection ()
351
-
352
- # 1-arg
353
- UnsafeIndex (x) = UnsafeIndex (typeof (x))
354
- UnsafeIndex (x:: UnsafeIndex ) = x
355
- UnsafeIndex (:: Type{T} ) where {T<: Integer } = unsafe_element
356
- UnsafeIndex (:: Type{T} ) where {T<: AbstractArray } = unsafe_collection
357
-
358
- # 2-arg
359
- UnsafeIndex (x:: UnsafeIndex , y:: UnsafeElement ) = x
360
- UnsafeIndex (x:: UnsafeElement , y:: UnsafeIndex ) = y
361
- UnsafeIndex (x:: UnsafeElement , y:: UnsafeElement ) = x
362
- UnsafeIndex (x:: UnsafeCollection , y:: UnsafeCollection ) = x
363
-
364
-
365
- # tuple
366
- UnsafeIndex (x:: Tuple{I} ) where {I} = UnsafeIndex (I)
367
- @inline function UnsafeIndex (x:: Tuple{I,Vararg{Any}} ) where {I}
368
- return UnsafeIndex (UnsafeIndex (I), UnsafeIndex (tail (x)))
369
- end
370
-
371
387
"""
372
388
unsafe_getindex(A, inds)
373
389
374
390
Indexes into `A` given `inds`. This method assumes that `inds` have already been
375
391
bounds checked.
376
392
"""
377
- unsafe_getindex (A, inds) = unsafe_getindex (UnsafeIndex (inds), A, inds)
378
- unsafe_getindex (:: UnsafeElement , A, inds) = unsafe_get_element (A, inds)
379
- unsafe_getindex (:: UnsafeCollection , A, inds) = unsafe_get_collection (A, inds)
393
+ unsafe_getindex (A, inds) = unsafe_getindex (UnsafeIndex (A, inds), A, inds)
394
+ unsafe_getindex (:: UnsafeGetElement , A, inds) = unsafe_get_element (A, inds)
395
+ unsafe_getindex (:: UnsafeGetCollection , A, inds) = unsafe_get_collection (A, inds)
380
396
381
397
"""
382
398
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
@@ -389,7 +405,9 @@ function unsafe_get_element(A, inds)
389
405
throw (MethodError (unsafe_getindex, (A, inds)))
390
406
end
391
407
function unsafe_get_element (A:: Array , inds)
392
- if inds isa Tuple{Vararg{Int}}
408
+ if length (inds) === 0
409
+ return Base. arrayref (false , A, 1 )
410
+ elseif inds isa Tuple{Vararg{Int}}
393
411
return Base. arrayref (false , A, inds... )
394
412
else
395
413
throw (MethodError (unsafe_get_element, (A, inds)))
@@ -443,14 +461,12 @@ end
443
461
end
444
462
end
445
463
@inline function unsafe_get_collection (A:: LinearIndices{N} , inds) where {N}
446
- if can_preserve_indices (typeof (inds))
464
+ if is_linear_indexing (A, inds)
465
+ return @inbounds (eachindex (A)[first (inds)])
466
+ elseif can_preserve_indices (typeof (inds))
447
467
return LinearIndices (to_axes (A, _ints2range .(inds)))
448
468
else
449
- if length (inds) === 1
450
- return @inbounds (eachindex (A)[first (inds)])
451
- else
452
- return Base. _getindex (IndexStyle (A), A, inds... )
453
- end
469
+ return Base. _getindex (IndexStyle (A), A, inds... )
454
470
end
455
471
end
456
472
474
490
Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been
475
491
bounds checked. This step of the processing pipeline can be customized by
476
492
"""
477
- unsafe_setindex! (A, val, inds:: Tuple ) = unsafe_setindex! (UnsafeIndex (inds), A, val, inds)
478
- unsafe_setindex! (:: UnsafeElement , A, val, inds:: Tuple ) = unsafe_set_element! (A, val, inds)
479
- unsafe_setindex! (:: UnsafeCollection , A, val, inds:: Tuple ) = unsafe_set_collection! (A, val, inds)
493
+ unsafe_setindex! (A, val, inds:: Tuple ) = unsafe_setindex! (UnsafeIndex (A, inds), A, val, inds)
494
+ unsafe_setindex! (:: UnsafeGetElement , A, val, inds:: Tuple ) = unsafe_set_element! (A, val, inds)
495
+ unsafe_setindex! (:: UnsafeGetCollection , A, val, inds:: Tuple ) = unsafe_set_collection! (A, val, inds)
480
496
481
497
"""
482
498
unsafe_set_element!(A, val, inds::Tuple)
@@ -489,7 +505,9 @@ function unsafe_set_element!(A, val, inds)
489
505
throw (MethodError (unsafe_set_element!, (A, val, inds)))
490
506
end
491
507
function unsafe_set_element! (A:: Array{T} , val, inds:: Tuple ) where {T}
492
- if inds isa Tuple{Vararg{Int}}
508
+ if length (inds) === 0
509
+ return Base. arrayset (false , A, convert (T, val):: T , 1 )
510
+ elseif inds isa Tuple{Vararg{Int}}
493
511
return Base. arrayset (false , A, convert (T, val):: T , inds... )
494
512
else
495
513
throw (MethodError (unsafe_set_element!, (A, inds)))
0 commit comments