Skip to content

Commit 1145ce3

Browse files
authored
More efficient offsets (#127)
* More efficient offsets * Small doc fix to trigger CI
1 parent 1ea0f3f commit 1145ce3

File tree

8 files changed

+38
-24
lines changed

8 files changed

+38
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1212
[compat]
1313
IfElse = "0.1"
1414
Requires = "0.5, 1.0"
15-
Static = "0.1"
15+
Static = "0.2"
1616
julia = "1.2"
1717

1818
[extras]

src/ArrayInterface.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ end
666666
defines_strides(::Type{<:BitArray}) = true
667667

668668
"""
669-
can_avx(f)
669+
can_avx(f) -> Bool
670670
671671
Returns `true` if the function `f` is guaranteed to be compatible with
672672
`LoopVectorization.@avx` for supported element and array types. While a return
@@ -1071,7 +1071,6 @@ function __init__()
10711071
@require OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" begin
10721072
size(A::OffsetArrays.OffsetArray) = size(parent(A))
10731073
strides(A::OffsetArrays.OffsetArray) = strides(parent(A))
1074-
# offsets(A::OffsetArrays.OffsetArray) = map(+, A.offsets, offsets(parent(A)))
10751074
function parent_type(
10761075
::Type{O},
10771076
) where {T,N,A<:AbstractArray{T,N},O<:OffsetArrays.OffsetArray{T,N,A}}
@@ -1084,8 +1083,16 @@ function __init__()
10841083
function contiguous_batch_size(::Type{A}) where {A<:OffsetArrays.OffsetArray}
10851084
return contiguous_batch_size(parent_type(A))
10861085
end
1087-
stride_rank(::Type{A}) where {A<:OffsetArrays.OffsetArray} =
1088-
stride_rank(parent_type(A))
1086+
1087+
function _offset_axis_type(::Type{T}, dim::StaticInt{D}) where {T,D}
1088+
return OffsetArrays.IdOffsetRange{Int,ArrayInterface.axes_types(T, dim)}
1089+
end
1090+
function ArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
1091+
return Static.eachop_tuple(_offset_axis_type, Static.nstatic(Val(ndims(T))), ArrayInterface.parent_type(T))
1092+
end
1093+
function stride_rank(::Type{A}) where {A<:OffsetArrays.OffsetArray}
1094+
return stride_rank(parent_type(A))
1095+
end
10891096
ArrayInterface.axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
10901097
ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
10911098
function ArrayInterface.device(::Type{T}) where {T<:OffsetArrays.OffsetArray}

src/axes.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ function axes_types(::Type{T}) where {T<:VecAdjTrans}
4141
return Tuple{OptionallyStaticUnitRange{One,One},axes_types(parent_type(T), One())}
4242
end
4343
function axes_types(::Type{T}) where {T<:MatAdjTrans}
44-
return eachop_tuple(_get_tuple, axes_types(parent_type(T)); iterator=to_parent_dims(T))
44+
return eachop_tuple(_get_tuple, to_parent_dims(T), axes_types(parent_type(T)))
4545
end
4646
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
47-
return eachop_tuple(_get_tuple, axes_types(parent_type(T)); iterator=to_parent_dims(T))
47+
return eachop_tuple(_get_tuple, to_parent_dims(T), axes_types(parent_type(T)))
4848
end
4949
function axes_types(::Type{T}) where {T<:AbstractRange}
5050
if known_length(T) === nothing
@@ -61,7 +61,7 @@ _int_or_static_int(::Nothing) = Int
6161
_int_or_static_int(x::Int) = StaticInt{x}
6262

6363
@inline function axes_types(::Type{T}) where {N,P,I,T<:SubArray{<:Any,N,P,I}}
64-
return eachop_tuple(_sub_axis_type, T; iterator=to_parent_dims(T))
64+
return eachop_tuple(_sub_axis_type, to_parent_dims(T), T)
6565
end
6666
@inline function _sub_axis_type(::Type{A}, dim::StaticInt) where {T,N,P,I,A<:SubArray{T,N,P,I}}
6767
return OptionallyStaticUnitRange{
@@ -75,12 +75,12 @@ function axes_types(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
7575
if sizeof(S) === sizeof(T)
7676
return axes_types(A)
7777
elseif sizeof(S) > sizeof(T)
78-
return eachop_tuple(_reshaped_axis_type, R; iterator=to_parent_dims(R))
78+
return eachop_tuple(_reshaped_axis_type, to_parent_dims(R), R)
7979
else
80-
return eachop_tuple(axes_types, A; iterator=to_parent_dims(R))
80+
return eachop_tuple(axes_types, to_parent_dims(R), A)
8181
end
8282
else
83-
return eachop_tuple(_non_reshaped_axis_type, R; iterator=to_parent_dims(R))
83+
return eachop_tuple(_non_reshaped_axis_type, to_parent_dims(R), R)
8484
end
8585
end
8686

src/dimensions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,12 @@ end
182182
if invariant_permutation(perm, perm) isa True
183183
return dimnames(parent_type(T))
184184
else
185-
return eachop(dimnames, parent_type(T); iterator=perm)
185+
return eachop(dimnames, perm, parent_type(T))
186186
end
187187
end
188188
end
189189
function dimnames(::Type{T}) where {T<:SubArray}
190-
return eachop(dimnames, parent_type(T); iterator=to_parent_dims(T))
190+
return eachop(dimnames, to_parent_dims(T), parent_type(T))
191191
end
192192

193193
_to_int(x::Integer) = Int(x)
@@ -241,7 +241,7 @@ end
241241
inds::Tuple
242242
) where {N}
243243

244-
out = eachop(order_named_inds, x, nd, inds; iterator=nstatic(Val(N)))
244+
out = eachop(order_named_inds, nstatic(Val(N)), x, nd, inds)
245245
_order_named_inds_check(out, length(nd))
246246
return out
247247
end

src/indexing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N
3131
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N
3232
_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i))
3333
function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
34-
return eachop(_argdims, s, T; iterator=nstatic(Val(N)))
34+
return eachop(_argdims, nstatic(Val(N)), s, T)
3535
end
3636

3737
"""
@@ -183,7 +183,7 @@ can_flatten(::Type{A}, ::Type{T}) where {A,T<:CartesianIndices} = true
183183
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:AbstractArray{Bool,N}} = N > 1
184184
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:CartesianIndex{N}} = true
185185
function can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:Tuple{Vararg{Any,N}}}
186-
return any(eachop(_can_flat, A, T; iterator=nstatic(Val(N))))
186+
return any(eachop(_can_flat, nstatic(Val(N)), A, T))
187187
end
188188
function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T}
189189
if can_flatten(A, _get_tuple(T, i)) === true

src/ranges.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,6 @@ const OptionallyStaticRange = Union{<:OptionallyStaticUnitRange,<:OptionallyStat
401401
Base.eachindex(r::OptionallyStaticRange) = r
402402
@inline Base.iterate(r::OptionallyStaticRange) = (fi = Int(first(r)); (fi, fi))
403403

404-
405404
Base.to_shape(x::OptionallyStaticRange) = length(x)
406405
Base.to_shape(x::Slice{T}) where {T<:OptionallyStaticRange} = length(x)
407406

src/size.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function size(a::A) where {A}
2222
end
2323
#size(a::AbstractVector) = (size(a, One()),)
2424

25-
size(x::SubArray) = eachop(_sub_size, x.indices; iterator=to_parent_dims(x))
25+
size(x::SubArray) = eachop(_sub_size, to_parent_dims(x), x.indices)
2626
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim))
2727

2828
@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
@@ -81,7 +81,7 @@ Returns the size of each dimension for `T` known at compile time. If a dimension
8181
have a known size along a dimension then `nothing` is returned in its position.
8282
"""
8383
known_size(x) = known_size(typeof(x))
84-
known_size(::Type{T}) where {T} = eachop(known_size, T; iterator=nstatic(Val(ndims(T))))
84+
known_size(::Type{T}) where {T} = eachop(known_size, nstatic(Val(ndims(T))), T)
8585

8686
"""
8787
known_size(::Type{T}, dim)

src/stridelayout.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ stride_preserving_index(::Type{T}) where {T<:AbstractRange} = True()
99
stride_preserving_index(::Type{T}) where {T<:Int} = True()
1010
stride_preserving_index(::Type{T}) where {T} = False()
1111
function stride_preserving_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
12-
if all(eachop(_stride_preserving_index, T; iterator=nstatic(Val(N))))
12+
if all(eachop(_stride_preserving_index, nstatic(Val(N)), T))
1313
return True()
1414
else
1515
return False()
@@ -27,9 +27,17 @@ it should return them as `Static` numbers.
2727
For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`.
2828
"""
2929
@inline offsets(x, i) = static_first(indices(x, i))
30-
# Explicit tuple needed for inference.
31-
offsets(x) = eachop(offsets, x; iterator=nstatic(Val(ndims(x))))
3230
offsets(::Tuple) = (One(),)
31+
offsets(x) = eachop(_offsets, nstatic(Val(ndims(x))), x)
32+
function _offsets(x::X, dim::StaticInt{D}) where {X,D}
33+
start = known_first(axes_types(X, dim))
34+
if start === nothing
35+
return first(axes(x, dim))
36+
else
37+
return static(start)
38+
end
39+
end
40+
3341

3442
"""
3543
contiguous_axis(::Type{T}) -> StaticInt{N}
@@ -346,7 +354,7 @@ end
346354

347355
known_offsets(x) = known_offsets(typeof(x))
348356
function known_offsets(::Type{T}) where {T}
349-
return eachop(_known_offsets, axes_types(T); iterator=nstatic(Val(ndims(T))))
357+
return eachop(_known_offsets, nstatic(Val(ndims(T))), axes_types(T))
350358
end
351359
_known_offsets(::Type{T}, dim::StaticInt) where {T} = known_first(_get_tuple(T, dim))
352360

@@ -433,7 +441,7 @@ end
433441

434442
getmul(x::Tuple, y::Tuple, ::StaticInt{i}) where {i} = getfield(x, i) * getfield(y, i)
435443
function strides(A::SubArray)
436-
return eachop(getmul, map(maybe_static_step, A.indices), strides(parent(A)); iterator=to_parent_dims(A))
444+
return eachop(getmul, to_parent_dims(A), map(maybe_static_step, A.indices), strides(parent(A)))
437445
end
438446

439447
maybe_static_step(x::AbstractRange) = static_step(x)

0 commit comments

Comments
 (0)