74
74
75
75
76
76
"""
77
- contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
77
+ contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
78
78
79
79
Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
80
80
"""
@@ -84,14 +84,14 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
84
84
Base. @pure contiguous_axis_indicator (:: Contiguous{N} , :: Val{D} ) where {N,D} = ntuple (d -> Val {d == N} (), Val {D} ())
85
85
86
86
"""
87
- If the contiguous dimension is not the dimension with `Stride_rank {1}`:
87
+ If the contiguous dimension is not the dimension with `StrideRank {1}`:
88
88
"""
89
89
struct ContiguousBatch{N} end
90
90
Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
91
91
_get (:: ContiguousBatch{N} ) where {N} = N
92
92
93
93
"""
94
- contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
94
+ contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
95
95
96
96
Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
97
97
If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
@@ -126,7 +126,7 @@ Base.collect(::StrideRank{R}) where {R} = collect(R)
126
126
@inline Base. getindex (:: StrideRank{R} , :: Val{I} ) where {R,I} = StrideRank {permute(R, I)} ()
127
127
128
128
"""
129
- rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
129
+ rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
130
130
131
131
Returns the `sortperm` of the stride ranks.
132
132
"""
@@ -197,7 +197,7 @@ Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}()
197
197
@inline Base. getindex (:: DenseDims{D} , i:: Integer ) where {D} = D[i]
198
198
@inline Base. getindex (:: DenseDims{D} , :: Val{I} ) where {D,I} = DenseDims {permute(D, I)} ()
199
199
"""
200
- dense_dims(::Type{T}) -> NTuple{N,Bool}
200
+ dense_dims(::Type{T}) -> NTuple{N,Bool}
201
201
202
202
Returns a tuple of indicators for whether each axis is dense.
203
203
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
@@ -274,8 +274,217 @@ while still producing correct behavior when using valid cartesian indices, such
274
274
strides (A) = Base. strides (A)
275
275
strides (A, d) = strides (A)[to_dims (A, d)]
276
276
277
+ @generated function _perm_tuple (:: Type{T} , :: Val{P} ) where {T,P}
278
+ out = Expr (:curly , :Tuple )
279
+ for p in P
280
+ push! (out. args, :(T. parameters[$ p]))
281
+ end
282
+ Expr (:block , Expr (:meta , :inline ), out)
283
+ end
284
+
285
+ """
286
+ axes_types(::Type{T}[, d]) -> Type
287
+
288
+ Returns the type of the axes for `T`
289
+ """
290
+ axes_types (x) = axes_types (typeof (x))
291
+ axes_types (x, d) = axes_types (typeof (x), d)
292
+ @inline axes_types (:: Type{T} , d) where {T} = axes_types (T). parameters[to_dims (T, d)]
293
+ function axes_types (:: Type{T} ) where {T}
294
+ if parent_type (T) <: T
295
+ return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims (T)}}
296
+ else
297
+ return axes_types (parent_type (T))
298
+ end
299
+ end
300
+ axes_types (:: Type{T} ) where {T<: Adjoint } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
301
+ axes_types (:: Type{T} ) where {T<: Transpose } = _perm_tuple (axes_types (parent_type (T)), Val ((2 , 1 )))
302
+ function axes_types (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
303
+ return _perm_tuple (axes_types (parent_type (T)), Val (I1))
304
+ end
305
+ function axes_types (:: Type{T} ) where {T<: OptionallyStaticRange }
306
+ if known_length (T) === nothing
307
+ return Tuple{OptionallyStaticUnitRange{One,Int}}
308
+ else
309
+ return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length (T) - 1 }}}
310
+ end
311
+ end
312
+
313
+ @inline function axes_types (:: Type{T} ) where {P,I,T<: SubArray{<:Any,<:Any,P,I} }
314
+ return _sub_axes_types (Val (ArrayStyle (T)), I, axes_types (P))
315
+ end
316
+ @generated function _sub_axes_types (:: Val{S} , :: Type{I} , :: Type{PI} ) where {S,I<: Tuple ,PI<: Tuple }
317
+ out = Expr (:curly , :Tuple )
318
+ d = 1
319
+ for i in I. parameters
320
+ ad = argdims (S, i)
321
+ if ad > 0
322
+ push! (out. args, :(sub_axis_type ($ (PI. parameters[d]), $ i)))
323
+ d += ad
324
+ else
325
+ d += 1
326
+ end
327
+ end
328
+ Expr (:block , Expr (:meta , :inline ), out)
329
+ end
330
+
331
+ @inline function axes_types (:: Type{T} ) where {T<: Base.ReinterpretArray }
332
+ return _reinterpret_axes_types (axes_type (parent_type (T)), eltype (T), eltype (parent_type (T)))
333
+ end
334
+ @generated function _reinterpret_axes_types (:: Type{I} , :: Type{T} , :: Type{S} ) where {I<: Tuple ,T,S}
335
+ out = Expr (:curly , :Tuple )
336
+ for i in 1 : length (T. parameters)
337
+ if i === 1
338
+ push! (out. args, :(reinterpret_axis_type ($ (I. parameters[1 ]), $ T, $ S)))
339
+ else
340
+ # FIXME double check this once I've slept
341
+ push! (out. args, :($ (I. parameters[i])))
342
+ end
343
+ end
344
+ Expr (:block , Expr (:meta , :inline ), out)
345
+ end
346
+
347
+
348
+ # These methods help handle identifying axes that dont' directly propagate from the
349
+ # parent array axes. They may be worth making a formal part of the API, as they provide
350
+ # a low traffic spot to change what axes_types produces.
351
+ @inline function sub_axis_type (:: Type{A} , :: Type{I} ) where {A,I}
352
+ if known_length (I) === nothing
353
+ return OptionallyStaticUnitRange{One,Int}
354
+ else
355
+ return OptionallyStaticUnitRange{One,StaticInt{known_length (I)}}
356
+ end
357
+ end
358
+
359
+ @inline function reinterpret_axis_type (:: Type{A} , :: Type{T} , :: Type{S} ) where {A,T,S}
360
+ if known_length (A) === nothing
361
+ return OptionallyStaticUnitRange{One,Int}
362
+ else
363
+ return OptionallyStaticUnitRange{One,StaticInt{Int (known_length (A) / (sizeof (T) / sizeof (S))) - 1 }}
364
+ end
365
+ end
366
+
367
+ """
368
+ known_offsets(::Type{T}[, d]) -> Tuple
369
+
370
+ Returns a tuple of offset values known at compile time. If the offset of a given axis is
371
+ not known at compile time `nothing` is returned its position.
372
+ """
373
+ @inline known_offsets (x, d) = known_offsets (x)[to_dims (x, d)]
374
+ known_offsets (x) = known_offsets (typeof (x))
375
+ @generated function known_offsets (:: Type{T} ) where {T}
376
+ out = Expr (:tuple )
377
+ for p in axes_types (T). parameters
378
+ push! (out. args, known_first (p))
379
+ end
380
+ return out
381
+ end
382
+
383
+ """
384
+ known_size(::Type{T}[, d]) -> Tuple
385
+
386
+ Returns the size of each dimension for `T` known at compile time. If a dimension does not
387
+ have a known size along a dimension then `nothing` is returned in its position.
388
+ """
389
+ @inline known_size (x, d) = known_size (x)[to_dims (x, d)]
390
+ known_size (x) = known_size (typeof (x))
391
+ known_size (:: Type{T} ) where {T} = _known_size (axes_types (T))
392
+ @generated function _known_size (:: Type{Axs} ) where {Axs<: Tuple }
393
+ out = Expr (:tuple )
394
+ for p in Axs. parameters
395
+ push! (out. args, :(known_length ($ p)))
396
+ end
397
+ return Expr (:block , Expr (:meta , :inline ), out)
398
+ end
399
+
277
400
"""
278
- offsets(A)
401
+ known_strides(::Type{T}[, d]) -> Tuple
402
+ """
403
+ known_strides (x) = known_strides (typeof (x))
404
+ known_strides (x, d) = known_strides (x)[to_dims (x, d)]
405
+ known_strides (:: Type{T} ) where {T<: Vector } = (1 ,)
406
+ @inline function known_strides (:: Type{T} ) where {T<: Adjoint{<:Any,<:AbstractVector} }
407
+ strd = first (known_strides (parent_type (T)))
408
+ return (strd, strd)
409
+ end
410
+ function known_strides (:: Type{T} ) where {T<: Adjoint }
411
+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
412
+ end
413
+ function known_strides (:: Type{T} ) where {T<: Transpose }
414
+ return permute (known_strides (parent_type (T)), Val {(2,1)} ())
415
+ end
416
+ @inline function known_strides (:: Type{T} ) where {T<: Transpose{<:Any,<:AbstractVector} }
417
+ strd = first (known_strides (parent_type (T)))
418
+ return (strd, strd)
419
+ end
420
+ @inline function known_strides (:: Type{T} ) where {I1,T<: PermutedDimsArray{<:Any,<:Any,I1} }
421
+ return permute (known_strides (parent_type (T)), Val {I1} ())
422
+ end
423
+ @inline function known_strides (:: Type{T} ) where {I1,T<: SubArray{<:Any,<:Any,<:Any,I1} }
424
+ return _sub_strides (Val (ArrayStyle (T)), I1, Val (known_strides (parent_type (T))))
425
+ end
426
+
427
+ @generated function _sub_strides (:: Val{S} , :: Type{I} , :: Val{P} ) where {S,I<: Tuple ,P}
428
+ out = Expr (:tuple )
429
+ d = 1
430
+ for i in I. parameters
431
+ ad = argdims (S, i)
432
+ if ad > 0
433
+ push! (out. args, P[d])
434
+ d += ad
435
+ else
436
+ d += 1
437
+ end
438
+ end
439
+ Expr (:block , Expr (:meta , :inline ), out)
440
+ end
441
+
442
+ function known_strides (:: Type{T} ) where {T}
443
+ if ndims (T) === 1
444
+ return (1 ,)
445
+ else
446
+ return _known_strides (Val (Base. front (known_size (T))))
447
+ end
448
+ end
449
+ @generated function _known_strides (:: Val{S} ) where {S}
450
+ out = Expr (:tuple )
451
+ N = length (S)
452
+ push! (out. args, 1 )
453
+ for s in S
454
+ if s === nothing || out. args[end ] === nothing
455
+ push! (out. args, nothing )
456
+ else
457
+ push! (out. args, out. args[end ] * s)
458
+ end
459
+ end
460
+ return Expr (:block , Expr (:meta , :inline ), out)
461
+ end
462
+
463
+ #=
464
+
465
+ function strides(a::ReinterpretArray)
466
+ a.parent isa StridedArray || ArgumentError("Parent must be strided.") |> throw
467
+ size_to_strides(1, size(a)...)
468
+ end
469
+
470
+ strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
471
+ @generated function _strides(_::Base.ReinterpretArray{T, N, S, A, true}, s::NTuple{N}, ::Contiguous{1}) where {T, N, S, D, A <: Array{S,D}}
472
+ stup = Expr(:tuple, :(One()))
473
+ if D < N
474
+ push!(stup.args, Expr(:call, Expr(:curly, :StaticInt, sizeof(S) ÷ sizeof(T))))
475
+ end
476
+ for n ∈ 2+(D < N):N
477
+ push!(stup.args, Expr(:ref, :s, n))
478
+ end
479
+ quote
480
+ $(Expr(:meta,:inline))
481
+ @inbounds $stup
482
+ end
483
+ end
484
+ =#
485
+
486
+ """
487
+ offsets(A) -> Tuple
279
488
280
489
Returns offsets of indices with respect to 0. If values are known at compile time,
281
490
it should return them as `Static` numbers.
294
503
strd = stride (parent (x), One ())
295
504
(strd, strd)
296
505
end
297
-
506
+
298
507
@generated function _strides (A:: AbstractArray{T,N} , s:: NTuple{N} , :: Contiguous{C} ) where {T,N,C}
299
508
if C ≤ 0 || C > N
300
509
return Expr (:block , Expr (:meta ,:inline ), :s )
@@ -325,15 +534,11 @@ if VERSION ≥ v"1.6.0-DEV.1581"
325
534
quote
326
535
$ (Expr (:meta ,:inline ))
327
536
@inbounds $ stup
328
- end
537
+ end
329
538
end
330
539
end
331
540
332
- @inline function offsets (x, i)
333
- inds = indices (x, i)
334
- start = known_first (inds)
335
- isnothing (start) ? first (inds) : StaticInt (start)
336
- end
541
+ @inline offsets (x, i) = static_first (indices (x, i))
337
542
# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}())
338
543
# Explicit tuple needed for inference.
339
544
@generated function offsets (A:: AbstractArray{<:Any,N} ) where {N}
0 commit comments