Skip to content

Commit 57b03b3

Browse files
authored
Merge pull request #98 from Tokazama/axes_methods
Make working with axes (and their types) easy
2 parents d799085 + 712ba28 commit 57b03b3

File tree

3 files changed

+274
-26
lines changed

3 files changed

+274
-26
lines changed

src/ArrayInterface.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ parameterless_type(x::Type) = __parameterless_type(x)
1313
"""
1414
parent_type(::Type{T})
1515
16-
Returns the parent array that `x` wraps.
16+
Returns the parent array that type `T` wraps.
1717
"""
1818
parent_type(x) = parent_type(typeof(x))
1919
parent_type(::Type{<:SubArray{T,N,P}}) where {T,N,P} = P
@@ -704,6 +704,12 @@ end
704704
end
705705
end
706706

707+
include("static.jl")
708+
include("ranges.jl")
709+
include("dimensions.jl")
710+
include("indexing.jl")
711+
include("stridelayout.jl")
712+
707713
function __init__()
708714

709715
@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
@@ -746,6 +752,10 @@ function __init__()
746752
stride_rank(::Type{T}) where {N, T <: StaticArrays.StaticArray{<:Any,<:Any,N}} = StrideRank{ntuple(identity, Val{N}())}()
747753
dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = DenseDims{ntuple(_ -> true, Val(N))}()
748754
defines_strides(::Type{<:StaticArrays.MArray}) = true
755+
756+
@generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
757+
return Tuple{[StaticArrays.SOneTo{s} for s in S.parameters]...}
758+
end
749759
@generated function size(A::StaticArrays.StaticArray{S}) where {S}
750760
t = Expr(:tuple); Sp = S.parameters
751761
for n in 1:length(Sp)
@@ -896,10 +906,4 @@ function __init__()
896906
end
897907
end
898908

899-
include("static.jl")
900-
include("ranges.jl")
901-
include("dimensions.jl")
902-
include("indexing.jl")
903-
include("stridelayout.jl")
904-
905909
end

src/stridelayout.jl

Lines changed: 202 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end
7474

7575

7676
"""
77-
contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
77+
contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
7878
7979
Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
8080
"""
@@ -84,14 +84,14 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
8484
Base.@pure contiguous_axis_indicator(::Contiguous{N}, ::Val{D}) where {N,D} = ntuple(d -> Val{d == N}(), Val{D}())
8585

8686
"""
87-
If the contiguous dimension is not the dimension with `Stride_rank{1}`:
87+
If the contiguous dimension is not the dimension with `StrideRank{1}`:
8888
"""
8989
struct ContiguousBatch{N} end
9090
Base.@pure ContiguousBatch(N::Int) = ContiguousBatch{N}()
9191
_get(::ContiguousBatch{N}) where {N} = N
9292

9393
"""
94-
contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
94+
contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
9595
9696
Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
9797
If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
@@ -126,7 +126,7 @@ Base.collect(::StrideRank{R}) where {R} = collect(R)
126126
@inline Base.getindex(::StrideRank{R}, ::Val{I}) where {R,I} = StrideRank{permute(R, I)}()
127127

128128
"""
129-
rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
129+
rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
130130
131131
Returns the `sortperm` of the stride ranks.
132132
"""
@@ -177,7 +177,9 @@ stride_rank(x, i) = stride_rank(x)[i]
177177
stride_rank(::Type{R}) where {T, N, S, A <: Array{S}, R <: Base.ReinterpretArray{T, N, S, A}} = StrideRank{ntuple(identity, Val{N}())}()
178178

179179
"""
180-
is_column_major(A) -> Val{true/false}()
180+
is_column_major(A) -> Val{true/false}()
181+
182+
Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`.
181183
"""
182184
is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A))
183185
is_column_major(::Nothing, ::Any) = Val{false}()
@@ -197,7 +199,7 @@ Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}()
197199
@inline Base.getindex(::DenseDims{D}, i::Integer) where {D} = D[i]
198200
@inline Base.getindex(::DenseDims{D}, ::Val{I}) where {D,I} = DenseDims{permute(D, I)}()
199201
"""
200-
dense_dims(::Type{T}) -> NTuple{N,Bool}
202+
dense_dims(::Type{T}) -> NTuple{N,Bool}
201203
202204
Returns a tuple of indicators for whether each axis is dense.
203205
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
@@ -250,7 +252,7 @@ permute(t::NTuple{N}, I::NTuple{N,Int}) where {N} = ntuple(n -> t[I[n]], Val{N}(
250252
end
251253

252254
"""
253-
strides(A)
255+
strides(A) -> Tuple
254256
255257
Returns the strides of array `A`. If any strides are known at compile time,
256258
these should be returned as `Static` numbers. For example:
@@ -274,8 +276,196 @@ while still producing correct behavior when using valid cartesian indices, such
274276
strides(A) = Base.strides(A)
275277
strides(A, d) = strides(A)[to_dims(A, d)]
276278

279+
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
280+
out = Expr(:curly, :Tuple)
281+
for p in P
282+
push!(out.args, :(T.parameters[$p]))
283+
end
284+
Expr(:block, Expr(:meta, :inline), out)
285+
end
286+
287+
"""
288+
axes_types(::Type{T}[, d]) -> Type
289+
290+
Returns the type of the axes for `T`
291+
"""
292+
axes_types(x) = axes_types(typeof(x))
293+
axes_types(x, d) = axes_types(typeof(x), d)
294+
@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)]
295+
function axes_types(::Type{T}) where {T}
296+
if parent_type(T) <: T
297+
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
298+
else
299+
return axes_types(parent_type(T))
300+
end
301+
end
302+
axes_types(::Type{T}) where {T<:Adjoint} = _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
303+
axes_types(::Type{T}) where {T<:Transpose} = _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
304+
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
305+
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
306+
end
307+
function axes_types(::Type{T}) where {T<:AbstractRange}
308+
if known_length(T) === nothing
309+
return Tuple{OptionallyStaticUnitRange{One,Int}}
310+
else
311+
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}}
312+
end
313+
end
314+
315+
@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
316+
return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P))
317+
end
318+
@generated function _sub_axes_types(::Val{S}, ::Type{I}, ::Type{PI}) where {S,I<:Tuple,PI<:Tuple}
319+
out = Expr(:curly, :Tuple)
320+
d = 1
321+
for i in I.parameters
322+
ad = argdims(S, i)
323+
if ad > 0
324+
push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i)))
325+
d += ad
326+
else
327+
d += 1
328+
end
329+
end
330+
Expr(:block, Expr(:meta, :inline), out)
331+
end
332+
333+
@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray}
334+
return _reinterpret_axes_types(axes_types(parent_type(T)), eltype(T), eltype(parent_type(T)))
335+
end
336+
@generated function _reinterpret_axes_types(::Type{I}, ::Type{T}, ::Type{S}) where {I<:Tuple,T,S}
337+
out = Expr(:curly, :Tuple)
338+
for i in 1:length(I.parameters)
339+
if i === 1
340+
push!(out.args, reinterpret_axis_type(I.parameters[1], T, S))
341+
else
342+
push!(out.args, I.parameters[i])
343+
end
344+
end
345+
Expr(:block, Expr(:meta, :inline), out)
346+
end
347+
348+
349+
# These methods help handle identifying axes that dont' directly propagate from the
350+
# parent array axes. They may be worth making a formal part of the API, as they provide
351+
# a low traffic spot to change what axes_types produces.
352+
@inline function sub_axis_type(::Type{A}, ::Type{I}) where {A,I}
353+
if known_length(I) === nothing
354+
return OptionallyStaticUnitRange{One,Int}
355+
else
356+
return OptionallyStaticUnitRange{One,StaticInt{known_length(I)}}
357+
end
358+
end
359+
360+
@inline function reinterpret_axis_type(::Type{A}, ::Type{T}, ::Type{S}) where {A,T,S}
361+
if known_length(A) === nothing
362+
return OptionallyStaticUnitRange{One,Int}
363+
else
364+
return OptionallyStaticUnitRange{One,StaticInt{Int(known_length(A) / (sizeof(T) / sizeof(S)))}}
365+
end
366+
end
367+
277368
"""
278-
offsets(A)
369+
known_offsets(::Type{T}[, d]) -> Tuple
370+
371+
Returns a tuple of offset values known at compile time. If the offset of a given axis is
372+
not known at compile time `nothing` is returned its position.
373+
"""
374+
@inline known_offsets(x, d) = known_offsets(x)[to_dims(x, d)]
375+
known_offsets(x) = known_offsets(typeof(x))
376+
@generated function known_offsets(::Type{T}) where {T}
377+
out = Expr(:tuple)
378+
for p in axes_types(T).parameters
379+
push!(out.args, known_first(p))
380+
end
381+
return out
382+
end
383+
384+
"""
385+
known_size(::Type{T}[, d]) -> Tuple
386+
387+
Returns the size of each dimension for `T` known at compile time. If a dimension does not
388+
have a known size along a dimension then `nothing` is returned in its position.
389+
"""
390+
@inline known_size(x, d) = known_size(x)[to_dims(x, d)]
391+
known_size(x) = known_size(typeof(x))
392+
known_size(::Type{T}) where {T} = _known_size(axes_types(T))
393+
@generated function _known_size(::Type{Axs}) where {Axs<:Tuple}
394+
out = Expr(:tuple)
395+
for p in Axs.parameters
396+
push!(out.args, :(known_length($p)))
397+
end
398+
return Expr(:block, Expr(:meta, :inline), out)
399+
end
400+
401+
"""
402+
known_strides(::Type{T}[, d]) -> Tuple
403+
404+
Returns the strides of array `A` known at compile time. Any strides that are not known at
405+
compile time are represented by `nothing`.
406+
"""
407+
known_strides(x) = known_strides(typeof(x))
408+
known_strides(x, d) = known_strides(x)[to_dims(x, d)]
409+
known_strides(::Type{T}) where {T<:Vector} = (1,)
410+
@inline function known_strides(::Type{T}) where {T<:Adjoint{<:Any,<:AbstractVector}}
411+
strd = first(known_strides(parent_type(T)))
412+
return (strd, strd)
413+
end
414+
function known_strides(::Type{T}) where {T<:Adjoint}
415+
return permute(known_strides(parent_type(T)), Val{(2,1)}())
416+
end
417+
function known_strides(::Type{T}) where {T<:Transpose}
418+
return permute(known_strides(parent_type(T)), Val{(2,1)}())
419+
end
420+
@inline function known_strides(::Type{T}) where {T<:Transpose{<:Any,<:AbstractVector}}
421+
strd = first(known_strides(parent_type(T)))
422+
return (strd, strd)
423+
end
424+
@inline function known_strides(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
425+
return permute(known_strides(parent_type(T)), Val{I1}())
426+
end
427+
@inline function known_strides(::Type{T}) where {I1,T<:SubArray{<:Any,<:Any,<:Any,I1}}
428+
return _sub_strides(Val(ArrayStyle(T)), I1, Val(known_strides(parent_type(T))))
429+
end
430+
431+
@generated function _sub_strides(::Val{S}, ::Type{I}, ::Val{P}) where {S,I<:Tuple,P}
432+
out = Expr(:tuple)
433+
d = 1
434+
for i in I.parameters
435+
ad = argdims(S, i)
436+
if ad > 0
437+
push!(out.args, P[d])
438+
d += ad
439+
else
440+
d += 1
441+
end
442+
end
443+
Expr(:block, Expr(:meta, :inline), out)
444+
end
445+
446+
function known_strides(::Type{T}) where {T}
447+
if ndims(T) === 1
448+
return (1,)
449+
else
450+
return _known_strides(Val(Base.front(known_size(T))))
451+
end
452+
end
453+
@generated function _known_strides(::Val{S}) where {S}
454+
out = Expr(:tuple)
455+
N = length(S)
456+
push!(out.args, 1)
457+
for s in S
458+
if s === nothing || out.args[end] === nothing
459+
push!(out.args, nothing)
460+
else
461+
push!(out.args, out.args[end] * s)
462+
end
463+
end
464+
return Expr(:block, Expr(:meta, :inline), out)
465+
end
466+
467+
"""
468+
offsets(A) -> Tuple
279469
280470
Returns offsets of indices with respect to 0. If values are known at compile time,
281471
it should return them as `Static` numbers.
@@ -294,7 +484,7 @@ end
294484
strd = stride(parent(x), One())
295485
(strd, strd)
296486
end
297-
487+
298488
@generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::Contiguous{C}) where {T,N,C}
299489
if C 0 || C > N
300490
return Expr(:block, Expr(:meta,:inline), :s)
@@ -325,15 +515,11 @@ if VERSION ≥ v"1.6.0-DEV.1581"
325515
quote
326516
$(Expr(:meta,:inline))
327517
@inbounds $stup
328-
end
518+
end
329519
end
330520
end
331521

332-
@inline function offsets(x, i)
333-
inds = indices(x, i)
334-
start = known_first(inds)
335-
isnothing(start) ? first(inds) : StaticInt(start)
336-
end
522+
@inline offsets(x, i) = static_first(indices(x, i))
337523
# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}())
338524
# Explicit tuple needed for inference.
339525
@generated function offsets(A::AbstractArray{<:Any,N}) where {N}
@@ -344,6 +530,7 @@ end
344530
Expr(:block, Expr(:meta, :inline), t)
345531
end
346532

533+
@inline size(v::AbstractVector) = (static_length(axes_types(v, 1)),)
347534
@inline size(B::Union{Transpose{T,A},Adjoint{T,A}}) where {T,A<:AbstractMatrix{T}} = permute(size(parent(B)), Val{(2,1)}())
348535
@inline size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A<:AbstractArray{T,N}} = permute(size(parent(B)), Val{I1}())
349536
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N]

0 commit comments

Comments
 (0)