74
74
75
75
76
76
"""
77
- contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
77
+ contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
78
78
79
79
Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
80
80
"""
@@ -84,14 +84,14 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
84
84
Base. @pure contiguous_axis_indicator (:: Contiguous{N} , :: Val{D} ) where {N,D} = ntuple (d -> Val {d == N} (), Val {D} ())
85
85
86
86
"""
87
- If the contiguous dimension is not the dimension with `Stride_rank {1}`:
87
+ If the contiguous dimension is not the dimension with `StrideRank {1}`:
88
88
"""
89
89
struct ContiguousBatch{N} end
90
90
Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
91
91
_get (:: ContiguousBatch{N} ) where {N} = N
92
92
93
93
"""
94
- contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
94
+ contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
95
95
96
96
Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
97
97
If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
@@ -126,7 +126,7 @@ Base.collect(::StrideRank{R}) where {R} = collect(R)
126
126
@inline Base. getindex (:: StrideRank{R} , :: Val{I} ) where {R,I} = StrideRank {permute(R, I)} ()
127
127
128
128
"""
129
- rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
129
+ rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
130
130
131
131
Returns the `sortperm` of the stride ranks.
132
132
"""
@@ -177,7 +177,9 @@ stride_rank(x, i) = stride_rank(x)[i]
177
177
stride_rank (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} } = StrideRank {ntuple(identity, Val{N}())} ()
178
178
179
179
"""
180
- is_column_major(A) -> Val{true/false}()
180
+ is_column_major(A) -> Val{true/false}()
181
+
182
+ Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`.
181
183
"""
182
184
is_column_major (A) = is_column_major (stride_rank (A), contiguous_batch_size (A))
183
185
is_column_major (:: Nothing , :: Any ) = Val {false} ()
@@ -197,7 +199,7 @@ Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}()
197
199
@inline Base. getindex (:: DenseDims{D} , i:: Integer ) where {D} = D[i]
198
200
@inline Base. getindex (:: DenseDims{D} , :: Val{I} ) where {D,I} = DenseDims {permute(D, I)} ()
199
201
"""
200
- dense_dims(::Type{T}) -> NTuple{N,Bool}
202
+ dense_dims(::Type{T}) -> NTuple{N,Bool}
201
203
202
204
Returns a tuple of indicators for whether each axis is dense.
203
205
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
@@ -250,7 +252,7 @@ permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}(
250
252
end
251
253
252
254
"""
253
- strides(A)
255
+ strides(A) -> Tuple
254
256
255
257
Returns the strides of array `A`. If any strides are known at compile time,
256
258
these should be returned as `Static` numbers. For example:
@@ -274,8 +276,196 @@ while still producing correct behavior when using valid cartesian indices, such
274
276
strides (A) = Base. strides (A)
275
277
strides (A, d) = strides (A)[to_dims (A, d)]
276
278
279
+ @generated function _perm_tuple (:: Type{T} , :: Val{P} ) where {T,P}
280
+ out = Expr (:curly , :Tuple )
281
+ for p in P
282
+ push! (out. args, :(T. parameters[$ p]))
283
+ end
284
+ Expr (:block , Expr (:meta , :inline ), out)
285
+ end
286
+
287
+ """
288
+ axes_types(::Type{T}[, d]) -> Type
289
+
290
+ Returns the type of the axes for `T`
291
+ """
292
+ axes_types (x) = axes_types (typeof (x))
293
+ axes_types (x, d) = axes_types (typeof (x), d)
294
+ @inline axes_types (:: Type{T} , d) where {T} = axes_types (T). parameters[to_dims (T, d)]
295
+ function axes_types (:: Type{T} ) where {T}
296
+ if parent_type (T) <: T
297
+ return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims (T)}}
298
+ else
299
+ return axes_types (parent_type (T))
300
+ end
301
+ end
302
+ axes_types (:: Type{T} ) where {T<: Adjoint } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
303
+ axes_types (:: Type{T} ) where {T<: Transpose } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
304
+ function axes_types (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
305
+ return _perm_tuple (axes_types (parent_type (T)), Val (I1))
306
+ end
307
+ function axes_types (:: Type{T} ) where {T<: AbstractRange }
308
+ if known_length (T) === nothing
309
+ return Tuple{OptionallyStaticUnitRange{One,Int}}
310
+ else
311
+ return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length (T)}}}
312
+ end
313
+ end
314
+
315
+ @inline function axes_types (:: Type{T} ) where {P,I,T<: SubArray{<:Any,<:Any,P,I} }
316
+ return _sub_axes_types (Val (ArrayStyle (T)), I, axes_types (P))
317
+ end
318
+ @generated function _sub_axes_types (:: Val{S} , :: Type{I} , :: Type{PI} ) where {S,I<: Tuple ,PI<: Tuple }
319
+ out = Expr (:curly , :Tuple )
320
+ d = 1
321
+ for i in I. parameters
322
+ ad = argdims (S, i)
323
+ if ad > 0
324
+ push! (out. args, :(sub_axis_type ($ (PI. parameters[d]), $ i)))
325
+ d += ad
326
+ else
327
+ d += 1
328
+ end
329
+ end
330
+ Expr (:block , Expr (:meta , :inline ), out)
331
+ end
332
+
333
+ @inline function axes_types (:: Type{T} ) where {T<: Base.ReinterpretArray }
334
+ return _reinterpret_axes_types (axes_types (parent_type (T)), eltype (T), eltype (parent_type (T)))
335
+ end
336
+ @generated function _reinterpret_axes_types (:: Type{I} , :: Type{T} , :: Type{S} ) where {I<: Tuple ,T,S}
337
+ out = Expr (:curly , :Tuple )
338
+ for i in 1 : length (I. parameters)
339
+ if i === 1
340
+ push! (out. args, reinterpret_axis_type (I. parameters[1 ], T, S))
341
+ else
342
+ push! (out. args, I. parameters[i])
343
+ end
344
+ end
345
+ Expr (:block , Expr (:meta , :inline ), out)
346
+ end
347
+
348
+
349
+ # These methods help handle identifying axes that dont' directly propagate from the
350
+ # parent array axes. They may be worth making a formal part of the API, as they provide
351
+ # a low traffic spot to change what axes_types produces.
352
+ @inline function sub_axis_type (:: Type{A} , :: Type{I} ) where {A,I}
353
+ if known_length (I) === nothing
354
+ return OptionallyStaticUnitRange{One,Int}
355
+ else
356
+ return OptionallyStaticUnitRange{One,StaticInt{known_length (I)}}
357
+ end
358
+ end
359
+
360
+ @inline function reinterpret_axis_type (:: Type{A} , :: Type{T} , :: Type{S} ) where {A,T,S}
361
+ if known_length (A) === nothing
362
+ return OptionallyStaticUnitRange{One,Int}
363
+ else
364
+ return OptionallyStaticUnitRange{One,StaticInt{Int (known_length (A) / (sizeof (T) / sizeof (S)))}}
365
+ end
366
+ end
367
+
277
368
"""
278
- offsets(A)
369
+ known_offsets(::Type{T}[, d]) -> Tuple
370
+
371
+ Returns a tuple of offset values known at compile time. If the offset of a given axis is
372
+ not known at compile time `nothing` is returned its position.
373
+ """
374
+ @inline known_offsets (x, d) = known_offsets (x)[to_dims (x, d)]
375
+ known_offsets (x) = known_offsets (typeof (x))
376
+ @generated function known_offsets (:: Type{T} ) where {T}
377
+ out = Expr (:tuple )
378
+ for p in axes_types (T). parameters
379
+ push! (out. args, known_first (p))
380
+ end
381
+ return out
382
+ end
383
+
384
+ """
385
+ known_size(::Type{T}[, d]) -> Tuple
386
+
387
+ Returns the size of each dimension for `T` known at compile time. If a dimension does not
388
+ have a known size along a dimension then `nothing` is returned in its position.
389
+ """
390
+ @inline known_size (x, d) = known_size (x)[to_dims (x, d)]
391
+ known_size (x) = known_size (typeof (x))
392
+ known_size (:: Type{T} ) where {T} = _known_size (axes_types (T))
393
+ @generated function _known_size (:: Type{Axs} ) where {Axs<: Tuple }
394
+ out = Expr (:tuple )
395
+ for p in Axs. parameters
396
+ push! (out. args, :(known_length ($ p)))
397
+ end
398
+ return Expr (:block , Expr (:meta , :inline ), out)
399
+ end
400
+
401
+ """
402
+ known_strides(::Type{T}[, d]) -> Tuple
403
+
404
+ Returns the strides of array `A` known at compile time. Any strides that are not known at
405
+ compile time are represented by `nothing`.
406
+ """
407
+ known_strides (x) = known_strides (typeof (x))
408
+ known_strides (x, d) = known_strides (x)[to_dims (x, d)]
409
+ known_strides (:: Type{T} ) where {T<: Vector } = (1 ,)
410
+ @inline function known_strides (:: Type{T} ) where {T<: Adjoint{<:Any,<:AbstractVector} }
411
+ strd = first (known_strides (parent_type (T)))
412
+ return (strd, strd)
413
+ end
414
+ function known_strides (:: Type{T} ) where {T<: Adjoint }
415
+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
416
+ end
417
+ function known_strides (:: Type{T} ) where {T<: Transpose }
418
+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
419
+ end
420
+ @inline function known_strides (:: Type{T} ) where {T<: Transpose{<:Any,<:AbstractVector} }
421
+ strd = first (known_strides (parent_type (T)))
422
+ return (strd, strd)
423
+ end
424
+ @inline function known_strides (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
425
+ return permute (known_strides (parent_type (T)), Val {I1} ())
426
+ end
427
+ @inline function known_strides (:: Type{T} ) where {I1,T<: SubArray{<:Any,<:Any,<:Any,I1} }
428
+ return _sub_strides (Val (ArrayStyle (T)), I1, Val (known_strides (parent_type (T))))
429
+ end
430
+
431
+ @generated function _sub_strides (:: Val{S} , :: Type{I} , :: Val{P} ) where {S,I<: Tuple ,P}
432
+ out = Expr (:tuple )
433
+ d = 1
434
+ for i in I. parameters
435
+ ad = argdims (S, i)
436
+ if ad > 0
437
+ push! (out. args, P[d])
438
+ d += ad
439
+ else
440
+ d += 1
441
+ end
442
+ end
443
+ Expr (:block , Expr (:meta , :inline ), out)
444
+ end
445
+
446
+ function known_strides (:: Type{T} ) where {T}
447
+ if ndims (T) === 1
448
+ return (1 ,)
449
+ else
450
+ return _known_strides (Val (Base. front (known_size (T))))
451
+ end
452
+ end
453
+ @generated function _known_strides (:: Val{S} ) where {S}
454
+ out = Expr (:tuple )
455
+ N = length (S)
456
+ push! (out. args, 1 )
457
+ for s in S
458
+ if s === nothing || out. args[end ] === nothing
459
+ push! (out. args, nothing )
460
+ else
461
+ push! (out. args, out. args[end ] * s)
462
+ end
463
+ end
464
+ return Expr (:block , Expr (:meta , :inline ), out)
465
+ end
466
+
467
+ """
468
+ offsets(A) -> Tuple
279
469
280
470
Returns offsets of indices with respect to 0. If values are known at compile time,
281
471
it should return them as `Static` numbers.
294
484
strd = stride (parent (x), One ())
295
485
(strd, strd)
296
486
end
297
-
487
+
298
488
@generated function _strides (A:: AbstractArray{T,N} , s:: NTuple{N} , :: Contiguous{C} ) where {T,N,C}
299
489
if C ≤ 0 || C > N
300
490
return Expr (:block , Expr (:meta ,:inline ), :s )
@@ -325,15 +515,11 @@ if VERSION ≥ v"1.6.0-DEV.1581"
325
515
quote
326
516
$ (Expr (:meta ,:inline ))
327
517
@inbounds $ stup
328
- end
518
+ end
329
519
end
330
520
end
331
521
332
- @inline function offsets (x, i)
333
- inds = indices (x, i)
334
- start = known_first (inds)
335
- isnothing (start) ? first (inds) : StaticInt (start)
336
- end
522
+ @inline offsets (x, i) = static_first (indices (x, i))
337
523
# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}())
338
524
# Explicit tuple needed for inference.
339
525
@generated function offsets (A:: AbstractArray{<:Any,N} ) where {N}
344
530
Expr (:block , Expr (:meta , :inline ), t)
345
531
end
346
532
533
+ @inline size (v:: AbstractVector ) = (static_length (axes_types (v, 1 )),)
347
534
@inline size (B:: Union{Transpose{T,A},Adjoint{T,A}} ) where {T,A<: AbstractMatrix{T} } = permute (size (parent (B)), Val {(2,1)} ())
348
535
@inline size (B:: PermutedDimsArray{T,N,I1,I2,A} ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = permute (size (parent (B)), Val {I1} ())
349
536
@inline size (A:: AbstractArray , :: StaticInt{N} ) where {N} = size (A)[N]
0 commit comments