Skip to content

Commit db1c9a2

Browse files
committed
Make working with axes easy
These methods are aimed predominantly at making compile time info about axes easily available. This means that information about `strides` and `offsets` can be accessed directly from the array type instead of from construction of a tuple of elements. Also cleaned up some of the docs a bit
1 parent 788f44e commit db1c9a2

File tree

3 files changed

+270
-22
lines changed

3 files changed

+270
-22
lines changed

src/ArrayInterface.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ parameterless_type(x::Type) = __parameterless_type(x)
1313
"""
1414
parent_type(::Type{T})
1515
16-
Returns the parent array that `x` wraps.
16+
Returns the parent array that type `T` wraps.
1717
"""
1818
parent_type(x) = parent_type(typeof(x))
1919
parent_type(::Type{<:SubArray{T,N,P}}) where {T,N,P} = P
@@ -704,6 +704,12 @@ end
704704
end
705705
end
706706

707+
include("static.jl")
708+
include("ranges.jl")
709+
include("dimensions.jl")
710+
include("indexing.jl")
711+
include("stridelayout.jl")
712+
707713
function __init__()
708714

709715
@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
@@ -746,6 +752,10 @@ function __init__()
746752
stride_rank(::Type{T}) where {N, T <: StaticArrays.StaticArray{<:Any,<:Any,N}} = StrideRank{ntuple(identity, Val{N}())}()
747753
dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = DenseDims{ntuple(_ -> true, Val(N))}()
748754
defines_strides(::Type{<:StaticArrays.MArray}) = true
755+
756+
@generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
757+
return Tuple{[StaticArrays.SOneTo{s} for s in S.parameters]...}
758+
end
749759
@generated function size(A::StaticArrays.StaticArray{S}) where {S}
750760
t = Expr(:tuple); Sp = S.parameters
751761
for n in 1:length(Sp)
@@ -896,10 +906,4 @@ function __init__()
896906
end
897907
end
898908

899-
include("static.jl")
900-
include("ranges.jl")
901-
include("dimensions.jl")
902-
include("indexing.jl")
903-
include("stridelayout.jl")
904-
905909
end

src/stridelayout.jl

Lines changed: 218 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end
7474

7575

7676
"""
77-
contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
77+
contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
7878
7979
Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
8080
"""
@@ -84,14 +84,14 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
8484
Base.@pure contiguous_axis_indicator(::Contiguous{N}, ::Val{D}) where {N,D} = ntuple(d -> Val{d == N}(), Val{D}())
8585

8686
"""
87-
If the contiguous dimension is not the dimension with `Stride_rank{1}`:
87+
If the contiguous dimension is not the dimension with `StrideRank{1}`:
8888
"""
8989
struct ContiguousBatch{N} end
9090
Base.@pure ContiguousBatch(N::Int) = ContiguousBatch{N}()
9191
_get(::ContiguousBatch{N}) where {N} = N
9292

9393
"""
94-
contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
94+
contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
9595
9696
Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
9797
If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
@@ -126,7 +126,7 @@ Base.collect(::StrideRank{R}) where {R} = collect(R)
126126
@inline Base.getindex(::StrideRank{R}, ::Val{I}) where {R,I} = StrideRank{permute(R, I)}()
127127

128128
"""
129-
rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
129+
rank_to_sortperm(::StrideRank) -> NTuple{N,Int}
130130
131131
Returns the `sortperm` of the stride ranks.
132132
"""
@@ -197,7 +197,7 @@ Base.@pure DenseDims(D::NTuple{<:Any,Bool}) = DenseDims{D}()
197197
@inline Base.getindex(::DenseDims{D}, i::Integer) where {D} = D[i]
198198
@inline Base.getindex(::DenseDims{D}, ::Val{I}) where {D,I} = DenseDims{permute(D, I)}()
199199
"""
200-
dense_dims(::Type{T}) -> NTuple{N,Bool}
200+
dense_dims(::Type{T}) -> NTuple{N,Bool}
201201
202202
Returns a tuple of indicators for whether each axis is dense.
203203
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
@@ -274,8 +274,217 @@ while still producing correct behavior when using valid cartesian indices, such
274274
strides(A) = Base.strides(A)
275275
strides(A, d) = strides(A)[to_dims(A, d)]
276276

277+
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
278+
out = Expr(:curly, :Tuple)
279+
for p in P
280+
push!(out.args, :(T.parameters[$p]))
281+
end
282+
Expr(:block, Expr(:meta, :inline), out)
283+
end
284+
285+
"""
286+
axes_types(::Type{T}[, d]) -> Type
287+
288+
Returns the type of the axes for `T`
289+
"""
290+
axes_types(x) = axes_types(typeof(x))
291+
axes_types(x, d) = axes_types(typeof(x), d)
292+
@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)]
293+
function axes_types(::Type{T}) where {T}
294+
if parent_type(T) <: T
295+
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
296+
else
297+
return axes_types(parent_type(T))
298+
end
299+
end
300+
axes_types(::Type{T}) where {T<:Adjoint} = _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
301+
axes_types(::Type{T}) where {T<:Transpose} = _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
302+
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
303+
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
304+
end
305+
function axes_types(::Type{T}) where {T<:OptionallyStaticRange}
306+
if known_length(T) === nothing
307+
return Tuple{OptionallyStaticUnitRange{One,Int}}
308+
else
309+
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T) - 1}}}
310+
end
311+
end
312+
313+
@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
314+
return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P))
315+
end
316+
@generated function _sub_axes_types(::Val{S}, ::Type{I}, ::Type{PI}) where {S,I<:Tuple,PI<:Tuple}
317+
out = Expr(:curly, :Tuple)
318+
d = 1
319+
for i in I.parameters
320+
ad = argdims(S, i)
321+
if ad > 0
322+
push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i)))
323+
d += ad
324+
else
325+
d += 1
326+
end
327+
end
328+
Expr(:block, Expr(:meta, :inline), out)
329+
end
330+
331+
@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray}
332+
return _reinterpret_axes_types(axes_type(parent_type(T)), eltype(T), eltype(parent_type(T)))
333+
end
334+
@generated function _reinterpret_axes_types(::Type{I}, ::Type{T}, ::Type{S}) where {I<:Tuple,T,S}
335+
out = Expr(:curly, :Tuple)
336+
for i in 1:length(T.parameters)
337+
if i === 1
338+
push!(out.args, :(reinterpret_axis_type($(I.parameters[1]), $T, $S)))
339+
else
340+
# FIXME double check this once I've slept
341+
push!(out.args, :($(I.parameters[i])))
342+
end
343+
end
344+
Expr(:block, Expr(:meta, :inline), out)
345+
end
346+
347+
348+
# These methods help handle identifying axes that dont' directly propagate from the
349+
# parent array axes. They may be worth making a formal part of the API, as they provide
350+
# a low traffic spot to change what axes_types produces.
351+
@inline function sub_axis_type(::Type{A}, ::Type{I}) where {A,I}
352+
if known_length(I) === nothing
353+
return OptionallyStaticUnitRange{One,Int}
354+
else
355+
return OptionallyStaticUnitRange{One,StaticInt{known_length(I)}}
356+
end
357+
end
358+
359+
@inline function reinterpret_axis_type(::Type{A}, ::Type{T}, ::Type{S}) where {A,T,S}
360+
if known_length(A) === nothing
361+
return OptionallyStaticUnitRange{One,Int}
362+
else
363+
return OptionallyStaticUnitRange{One,StaticInt{Int(known_length(A) / (sizeof(T) / sizeof(S))) - 1}}
364+
end
365+
end
366+
367+
"""
368+
known_offsets(::Type{T}[, d]) -> Tuple
369+
370+
Returns a tuple of offset values known at compile time. If the offset of a given axis is
371+
not known at compile time `nothing` is returned its position.
372+
"""
373+
@inline known_offsets(x, d) = known_offsets(x)[to_dims(x, d)]
374+
known_offsets(x) = known_offsets(typeof(x))
375+
@generated function known_offsets(::Type{T}) where {T}
376+
out = Expr(:tuple)
377+
for p in axes_types(T).parameters
378+
push!(out.args, known_first(p))
379+
end
380+
return out
381+
end
382+
383+
"""
384+
known_size(::Type{T}[, d]) -> Tuple
385+
386+
Returns the size of each dimension for `T` known at compile time. If a dimension does not
387+
have a known size along a dimension then `nothing` is returned in its position.
388+
"""
389+
@inline known_size(x, d) = known_size(x)[to_dims(x, d)]
390+
known_size(x) = known_size(typeof(x))
391+
known_size(::Type{T}) where {T} = _known_size(axes_types(T))
392+
@generated function _known_size(::Type{Axs}) where {Axs<:Tuple}
393+
out = Expr(:tuple)
394+
for p in Axs.parameters
395+
push!(out.args, :(known_length($p)))
396+
end
397+
return Expr(:block, Expr(:meta, :inline), out)
398+
end
399+
277400
"""
278-
offsets(A)
401+
known_strides(::Type{T}[, d]) -> Tuple
402+
"""
403+
known_strides(x) = known_strides(typeof(x))
404+
known_strides(x, d) = known_strides(x)[to_dims(x, d)]
405+
known_strides(::Type{T}) where {T<:Vector} = (1,)
406+
@inline function known_strides(::Type{T}) where {T<:Adjoint{<:Any,<:AbstractVector}}
407+
strd = first(known_strides(parent_type(T)))
408+
return (strd, strd)
409+
end
410+
function known_strides(::Type{T}) where {T<:Adjoint}
411+
return permute(known_strides(parent_type(T)), Val{(2,1)}())
412+
end
413+
function known_strides(::Type{T}) where {T<:Transpose}
414+
return permute(known_strides(parent_type(T)), Val{(2,1)}())
415+
end
416+
@inline function known_strides(::Type{T}) where {T<:Transpose{<:Any,<:AbstractVector}}
417+
strd = first(known_strides(parent_type(T)))
418+
return (strd, strd)
419+
end
420+
@inline function known_strides(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
421+
return permute(known_strides(parent_type(T)), Val{I1}())
422+
end
423+
@inline function known_strides(::Type{T}) where {I1,T<:SubArray{<:Any,<:Any,<:Any,I1}}
424+
return _sub_strides(Val(ArrayStyle(T)), I1, Val(known_strides(parent_type(T))))
425+
end
426+
427+
@generated function _sub_strides(::Val{S}, ::Type{I}, ::Val{P}) where {S,I<:Tuple,P}
428+
out = Expr(:tuple)
429+
d = 1
430+
for i in I.parameters
431+
ad = argdims(S, i)
432+
if ad > 0
433+
push!(out.args, P[d])
434+
d += ad
435+
else
436+
d += 1
437+
end
438+
end
439+
Expr(:block, Expr(:meta, :inline), out)
440+
end
441+
442+
function known_strides(::Type{T}) where {T}
443+
if ndims(T) === 1
444+
return (1,)
445+
else
446+
return _known_strides(Val(Base.front(known_size(T))))
447+
end
448+
end
449+
@generated function _known_strides(::Val{S}) where {S}
450+
out = Expr(:tuple)
451+
N = length(S)
452+
push!(out.args, 1)
453+
for s in S
454+
if s === nothing || out.args[end] === nothing
455+
push!(out.args, nothing)
456+
else
457+
push!(out.args, out.args[end] * s)
458+
end
459+
end
460+
return Expr(:block, Expr(:meta, :inline), out)
461+
end
462+
463+
#=
464+
465+
function strides(a::ReinterpretArray)
466+
a.parent isa StridedArray || ArgumentError("Parent must be strided.") |> throw
467+
size_to_strides(1, size(a)...)
468+
end
469+
470+
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
471+
@generated function _strides(_::Base.ReinterpretArray{T, N, S, A, true}, s::NTuple{N}, ::Contiguous{1}) where {T, N, S, D, A <: Array{S,D}}
472+
stup = Expr(:tuple, :(One()))
473+
if D < N
474+
push!(stup.args, Expr(:call, Expr(:curly, :StaticInt, sizeof(S) ÷ sizeof(T))))
475+
end
476+
for n ∈ 2+(D < N):N
477+
push!(stup.args, Expr(:ref, :s, n))
478+
end
479+
quote
480+
$(Expr(:meta,:inline))
481+
@inbounds $stup
482+
end
483+
end
484+
=#
485+
486+
"""
487+
offsets(A) -> Tuple
279488
280489
Returns offsets of indices with respect to 0. If values are known at compile time,
281490
it should return them as `Static` numbers.
@@ -294,7 +503,7 @@ end
294503
strd = stride(parent(x), One())
295504
(strd, strd)
296505
end
297-
506+
298507
@generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::Contiguous{C}) where {T,N,C}
299508
if C 0 || C > N
300509
return Expr(:block, Expr(:meta,:inline), :s)
@@ -325,15 +534,11 @@ if VERSION ≥ v"1.6.0-DEV.1581"
325534
quote
326535
$(Expr(:meta,:inline))
327536
@inbounds $stup
328-
end
537+
end
329538
end
330539
end
331540

332-
@inline function offsets(x, i)
333-
inds = indices(x, i)
334-
start = known_first(inds)
335-
isnothing(start) ? first(inds) : StaticInt(start)
336-
end
541+
@inline offsets(x, i) = static_first(indices(x, i))
337542
# @inline offsets(A::AbstractArray{<:Any,N}) where {N} = ntuple(n -> offsets(A, n), Val{N}())
338543
# Explicit tuple needed for inference.
339544
@generated function offsets(A::AbstractArray{<:Any,N}) where {N}

0 commit comments

Comments
 (0)