Skip to content

Commit 1ea0f3f

Browse files
authored
Prep for new release (#124)
* Prep for new release Mostly moving static.jl to Static * Fall back to Base.strides when define_strides(x) == false * Took out reference to StaticInt Documenting this is Static.jl's problem now * Use Static comparison operators * Use `defines_strides` to gate keep pointer assumptions * DenseArray gets CPUPointer * Add type stability check for axes_types(T, ::StaticInt) * No recursion on device Essentially equivalent to ``` function device(::Type{T}) where {T<:AbstractArray} if parent_type(T) <: T return T <: DenseArray ? CPUPointer() : CPUIndex() else if defines_strides(T) return device(parent_type(T)) else out = device(parent_type(T)) return out isa CPUPointer ? CPUIndex() : out end end end ``` * Refix fall back to Base.strides * Ensure that strides(::AbstractArray2) doesn't recurse
1 parent bc560cf commit 1ea0f3f

File tree

12 files changed

+70
-658
lines changed

12 files changed

+70
-658
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.1"
3+
version = "3.1.2"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
99
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
10+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1011

1112
[compat]
1213
IfElse = "0.1"
1314
Requires = "0.5, 1.0"
15+
Static = "0.1"
1416
julia = "1.2"
1517

1618
[extras]

README.md

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -208,18 +208,6 @@ For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1)
208208

209209
Is the function `f` whitelisted for `LoopVectorization.@avx`?
210210

211-
## static(x)
212-
Returns a static form of `x`. If `x` is already in a static form then `x` is returned. If
213-
there is no static alternative for `x` then an error is thrown.
214-
215-
## StaticInt(N::Int)
216-
217-
Creates a static integer with value known at compile time. It is a number,
218-
supporting basic arithmetic. Many operations with two `StaticInt` integers
219-
will produce another `StaticInt` integer. If one of the arguments to a
220-
function call isn't static (e.g., `StaticInt(4) + 3`), then the `StaticInt`
221-
number will promote to a dynamic value.
222-
223211
# List of things to add
224212

225213
- https://github.com/JuliaLang/julia/issues/22216

src/ArrayInterface.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ using IfElse
44
using Requires
55
using LinearAlgebra
66
using SparseArrays
7+
using Static
8+
using Static: Zero, One, nstatic, _get_tuple, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
9+
find_first_eq, permute, invariant_permutation
710
using Base.Cartesian
811

912
using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray,
@@ -32,7 +35,11 @@ parameterless_type(x::Type) = __parameterless_type(x)
3235
const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}}
3336
const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
3437

35-
include("static.jl")
38+
@inline static_length(a::UnitRange{T}) where {T} = last(a) - first(a) + oneunit(T)
39+
@inline static_length(x) = Static.maybe_static(known_length, length, x)
40+
@inline static_first(x) = Static.maybe_static(known_first, first, x)
41+
@inline static_last(x) = Static.maybe_static(known_last, last, x)
42+
@inline static_step(x) = Static.maybe_static(known_step, step, x)
3643

3744
"""
3845
parent_type(::Type{T})
@@ -51,6 +58,15 @@ parent_type(::Type{Slice{T}}) where {T} = T
5158
parent_type(::Type{T}) where {T} = T
5259
parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A
5360

61+
"""
62+
has_parent(::Type{T}) -> StaticBool
63+
64+
Returns `True` if `parent_type(T)` a type unique to `T`.
65+
"""
66+
has_parent(::Type{T}) where {T} = _has_parent(parent_type(T), T)
67+
_has_parent(::Type{T}, ::Type{T}) where {T} = False()
68+
_has_parent(::Type{T1}, ::Type{T2}) where {T1,T2} = True()
69+
5470
"""
5571
known_length(::Type{T})
5672
@@ -616,13 +632,8 @@ device(A) = device(typeof(A))
616632
device(::Type) = nothing
617633
device(::Type{<:Tuple}) = CPUIndex()
618634
device(::Type{T}) where {T<:Array} = CPUPointer()
619-
device(::Type{T}) where {T<:AbstractArray} = CPUIndex()
620-
device(::Type{T}) where {T<:PermutedDimsArray} = device(parent_type(T))
621-
device(::Type{T}) where {T<:Transpose} = device(parent_type(T))
622-
device(::Type{T}) where {T<:Adjoint} = device(parent_type(T))
623-
device(::Type{T}) where {T<:ReinterpretArray} = device(parent_type(T))
624-
device(::Type{T}) where {T<:ReshapedArray} = device(parent_type(T))
625-
function device(::Type{T}) where {T<:SubArray}
635+
device(::Type{T}) where {T<:AbstractArray} = _device(has_parent(T), T)
636+
function _device(::True, ::Type{T}) where {T}
626637
if defines_strides(T)
627638
return device(parent_type(T))
628639
else
@@ -631,11 +642,14 @@ function device(::Type{T}) where {T<:SubArray}
631642
end
632643
_not_pointer(::CPUPointer) = CPUIndex()
633644
_not_pointer(x) = x
645+
_device(::False, ::Type{T}) where {T<:DenseArray} = CPUPointer()
646+
_device(::False, ::Type{T}) where {T} = CPUIndex()
634647

635648
"""
636649
defines_strides(::Type{T}) -> Bool
637650
638-
Is strides(::T) defined?
651+
Is strides(::T) defined? It is assumed that types returning `true` also return a valid
652+
pointer on `pointer(::T)`.
639653
"""
640654
defines_strides(x) = defines_strides(typeof(x))
641655
function defines_strides(::Type{T}) where {T}
@@ -787,9 +801,20 @@ Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.size(A, dim))
787801
Base.axes(A::AbstractArray2) = ArrayInterface.axes(A)
788802
Base.axes(A::AbstractArray2, dim) = ArrayInterface.axes(A, dim)
789803

790-
Base.strides(A::AbstractArray2) = map(Int, ArrayInterface.strides(A))
804+
function Base.strides(A::AbstractArray2)
805+
defines_strides(A) || throw(MethodError(Base.strides, (A,)))
806+
return map(Int, ArrayInterface.strides(A))
807+
end
791808
Base.strides(A::AbstractArray2, dim) = Int(ArrayInterface.strides(A, dim))
792809

810+
function Base.IndexStyle(::Type{T}) where {T<:AbstractArray2}
811+
if parent_type(T) <: T
812+
return IndexCartesian()
813+
else
814+
return IndexStyle(parent_type(T))
815+
end
816+
end
817+
793818
function Base.length(A::AbstractArray2)
794819
len = known_length(A)
795820
if len === nothing

src/axes.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,16 @@ function axes_types(::Type{T}) where {T}
3535
return axes_types(parent_type(T))
3636
end
3737
end
38+
axes_types(::Type{LinearIndices{N,R}}) where {N,R} = R
39+
axes_types(::Type{CartesianIndices{N,R}}) where {N,R} = R
3840
function axes_types(::Type{T}) where {T<:VecAdjTrans}
3941
return Tuple{OptionallyStaticUnitRange{One,One},axes_types(parent_type(T), One())}
4042
end
4143
function axes_types(::Type{T}) where {T<:MatAdjTrans}
42-
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
44+
return eachop_tuple(_get_tuple, axes_types(parent_type(T)); iterator=to_parent_dims(T))
4345
end
4446
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
45-
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
47+
return eachop_tuple(_get_tuple, axes_types(parent_type(T)); iterator=to_parent_dims(T))
4648
end
4749
function axes_types(::Type{T}) where {T<:AbstractRange}
4850
if known_length(T) === nothing
@@ -59,7 +61,7 @@ _int_or_static_int(::Nothing) = Int
5961
_int_or_static_int(x::Int) = StaticInt{x}
6062

6163
@inline function axes_types(::Type{T}) where {N,P,I,T<:SubArray{<:Any,N,P,I}}
62-
return eachop_tuple(_sub_axis_type, T, to_parent_dims(T))
64+
return eachop_tuple(_sub_axis_type, T; iterator=to_parent_dims(T))
6365
end
6466
@inline function _sub_axis_type(::Type{A}, dim::StaticInt) where {T,N,P,I,A<:SubArray{T,N,P,I}}
6567
return OptionallyStaticUnitRange{
@@ -73,12 +75,12 @@ function axes_types(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
7375
if sizeof(S) === sizeof(T)
7476
return axes_types(A)
7577
elseif sizeof(S) > sizeof(T)
76-
return eachop_tuple(_reshaped_axis_type, R, to_parent_dims(R))
78+
return eachop_tuple(_reshaped_axis_type, R; iterator=to_parent_dims(R))
7779
else
78-
return eachop_tuple(axes_types, A, to_parent_dims(R))
80+
return eachop_tuple(axes_types, A; iterator=to_parent_dims(R))
7981
end
8082
else
81-
return eachop_tuple(_non_reshaped_axis_type, R, to_parent_dims(R))
83+
return eachop_tuple(_non_reshaped_axis_type, R; iterator=to_parent_dims(R))
8284
end
8385
end
8486

src/dimensions.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,12 @@ end
182182
if invariant_permutation(perm, perm) isa True
183183
return dimnames(parent_type(T))
184184
else
185-
return eachop(dimnames, parent_type(T), perm)
185+
return eachop(dimnames, parent_type(T); iterator=perm)
186186
end
187187
end
188188
end
189189
function dimnames(::Type{T}) where {T<:SubArray}
190-
return eachop(dimnames, parent_type(T), to_parent_dims(T))
190+
return eachop(dimnames, parent_type(T); iterator=to_parent_dims(T))
191191
end
192192

193193
_to_int(x::Integer) = Int(x)
@@ -241,7 +241,7 @@ end
241241
inds::Tuple
242242
) where {N}
243243

244-
out = eachop(((x, nd, inds), i) -> order_named_inds(x, nd, inds, i), (x, nd, inds), nstatic(Val(N)))
244+
out = eachop(order_named_inds, x, nd, inds; iterator=nstatic(Val(N)))
245245
_order_named_inds_check(out, length(nd))
246246
return out
247247
end

src/indexing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N
3131
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N
3232
_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i))
3333
function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
34-
return eachop(_argdims, s, T, nstatic(Val(N)))
34+
return eachop(_argdims, s, T; iterator=nstatic(Val(N)))
3535
end
3636

3737
"""
@@ -183,7 +183,7 @@ can_flatten(::Type{A}, ::Type{T}) where {A,T<:CartesianIndices} = true
183183
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:AbstractArray{Bool,N}} = N > 1
184184
can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:CartesianIndex{N}} = true
185185
function can_flatten(::Type{A}, ::Type{T}) where {A,N,T<:Tuple{Vararg{Any,N}}}
186-
return any(eachop(_can_flat, A, T, nstatic(Val(N))))
186+
return any(eachop(_can_flat, A, T; iterator=nstatic(Val(N))))
187187
end
188188
function _can_flat(::Type{A}, ::Type{T}, i::StaticInt) where {A,T}
189189
if can_flatten(A, _get_tuple(T, i)) === true

src/size.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function size(a::A) where {A}
2222
end
2323
#size(a::AbstractVector) = (size(a, One()),)
2424

25-
size(x::SubArray) = eachop(_sub_size, x.indices, to_parent_dims(x))
25+
size(x::SubArray) = eachop(_sub_size, x.indices; iterator=to_parent_dims(x))
2626
_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim))
2727

2828
@inline size(B::VecAdjTrans) = (One(), length(parent(B)))
@@ -81,7 +81,7 @@ Returns the size of each dimension for `T` known at compile time. If a dimension
8181
have a known size along a dimension then `nothing` is returned in its position.
8282
"""
8383
known_size(x) = known_size(typeof(x))
84-
known_size(::Type{T}) where {T} = eachop(known_size, T, nstatic(Val(ndims(T))))
84+
known_size(::Type{T}) where {T} = eachop(known_size, T; iterator=nstatic(Val(ndims(T))))
8585

8686
"""
8787
known_size(::Type{T}, dim)

0 commit comments

Comments
 (0)