Skip to content

Commit 768548b

Browse files
authored
Move NDIndex to Static (#172)
1 parent 22bb669 commit 768548b

File tree

5 files changed

+6
-184
lines changed

5 files changed

+6
-184
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.17"
3+
version = "3.1.18"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
@@ -12,7 +12,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1212
[compat]
1313
IfElse = "0.1"
1414
Requires = "0.5, 1.0"
15-
Static = "0.2"
15+
Static = "0.3"
1616
julia = "1.2"
1717

1818
[extras]

src/ArrayInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
4949
@inline static_last(x) = Static.maybe_static(known_last, last, x)
5050
@inline static_step(x) = Static.maybe_static(known_step, step, x)
5151

52-
include("ndindex.jl")
5352
include("array_index.jl")
5453

5554
"""
@@ -94,6 +93,7 @@ known_length(::Type{T}) where {T<:Slice} = known_length(parent_type(T))
9493
known_length(::Type{<:Tuple{Vararg{Any,N}}}) where {N} = N
9594
known_length(::Type{T}) where {Itr,T<:Base.Generator{Itr}} = known_length(Itr)
9695
known_length(::Type{<:Number}) = 1
96+
known_length(::Type{NDIndex{N,I}}) where {N,I} = N
9797
function known_length(::Type{T}) where {T}
9898
if parent_type(T) <: T
9999
return nothing

src/ndindex.jl

Lines changed: 0 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -1,178 +1 @@
11

2-
"""
3-
NDIndex(i, j, k...) -> I
4-
NDIndex((i, j, k...)) -> I
5-
6-
A multidimensional index that refers to a single element. Each dimension is represented by
7-
a single `Int` or `StaticInt`.
8-
9-
```julia
10-
julia> using ArrayInterface: NDIndex
11-
12-
julia> using Static
13-
14-
julia> i = NDIndex(static(1), 2, static(3))
15-
NDIndex(static(1), 2, static(3))
16-
17-
julia> i[static(1)]
18-
static(1)
19-
20-
julia> i[1]
21-
1
22-
23-
```
24-
"""
25-
struct NDIndex{N,I<:Tuple{Vararg{Any,N}}} <: AbstractCartesianIndex{N}
26-
index::I
27-
28-
global _NDIndex(index::Tuple{Vararg{Any,N}}) where {N} = new{N,typeof(index)}(index)
29-
30-
function NDIndex{N,I}(index::I) where {N,I<:Tuple{Vararg{Integer,N}}}
31-
for i in index
32-
(i isa Int) || i isa StaticInt || throw(ArgumentError("NDIndex does not support values of type $(typeof(i))"))
33-
end
34-
return new{N,I}(index)
35-
end
36-
37-
NDIndex{N}(index::Tuple) where {N} = _ndindex(static(N), _flatten(index...))
38-
NDIndex{N}(index...) where {N} = _ndindex(static(N), _flatten(index...))
39-
40-
NDIndex{0}(::Tuple{}) = new{0,Tuple{}}(())
41-
NDIndex{0}() = NDIndex{0}(())
42-
43-
NDIndex(index::Tuple) = _NDIndex(_flatten(index...))
44-
NDIndex(index...) = _NDIndex(_flatten(index...))
45-
end
46-
47-
_ndindex(n::StaticInt{N}, i::Tuple{Vararg{Union{Int,StaticInt},N}}) where {N} = _NDIndex(i)
48-
function _ndindex(n::StaticInt{N}, i::Tuple{Vararg{Any,M}}) where {N,M}
49-
M > N && throw(ArgumentError("input tuple of length $M, requested $N"))
50-
return _NDIndex(_fill_to_length(i, n))
51-
end
52-
_fill_to_length(x::Tuple{Vararg{Any,N}}, n::StaticInt{N}) where {N} = x
53-
@inline function _fill_to_length(x::Tuple{Vararg{Any,M}}, n::StaticInt{N}) where {M,N}
54-
return _fill_to_length((x..., static(1)), n)
55-
end
56-
57-
_flatten(i::StaticInt{N}) where {N} = (i,)
58-
_flatten(i::Integer) = (Int(i),)
59-
_flatten(i::Base.AbstractCartesianIndex) = _flatten(Tuple(i)...)
60-
@inline _flatten(i::Integer, I...) = (canonicalize(i), _flatten(I...)...)
61-
@inline function _flatten(i::Base.AbstractCartesianIndex, I...)
62-
return (_flatten(Tuple(i)...)..., _flatten(I...)...)
63-
end
64-
Base.Tuple(index::NDIndex) = index.index
65-
66-
Static.dynamic(x::NDIndex) = CartesianIndex(dynamic(Tuple(x)))
67-
Static.static(x::CartesianIndex) = _NDIndex(static(Tuple(x)))
68-
Static.known(::Type{NDIndex{N,I}}) where {N,I} = known(I)
69-
70-
Base.show(io::IO, i::NDIndex) = (print(io, "NDIndex"); show(io, Tuple(i)))
71-
72-
# length
73-
Base.length(::NDIndex{N}) where {N} = N
74-
Base.length(::Type{NDIndex{N,I}}) where {N,I} = N
75-
known_length(::Type{NDIndex{N,I}}) where {N,I} = N
76-
77-
# indexing
78-
@propagate_inbounds function getindex(x::NDIndex{N,T}, i::Int)::Int where {N,T}
79-
return Int(getfield(Tuple(x), i))
80-
end
81-
@propagate_inbounds function getindex(x::NDIndex{N,T}, i::StaticInt{I}) where {N,T,I}
82-
return getfield(Tuple(x), I)
83-
end
84-
@propagate_inbounds Base.getindex(x::NDIndex, i::Integer) = ArrayInterface.getindex(x, i)
85-
86-
# Base.get(A::AbstractArray, I::CartesianIndex, default) = get(A, I.I, default)
87-
# eltype(::Type{T}) where {T<:CartesianIndex} = eltype(fieldtype(T, :I))
88-
89-
Base.setindex(x::NDIndex, i, j) = NDIndex(Base.setindex(Tuple(x), i, j))
90-
91-
# equality
92-
Base.:(==)(x::NDIndex{N}, y::NDIndex{N}) where N = Tuple(x) == Tuple(y)
93-
94-
# zeros and ones
95-
Base.zero(::NDIndex{N}) where {N} = zero(NDIndex{N})
96-
Base.zero(::Type{NDIndex{N}}) where {N} = _NDIndex(ntuple(_ -> static(0), Val(N)))
97-
Base.oneunit(::NDIndex{N}) where {N} = oneunit(NDIndex{N})
98-
Base.oneunit(::Type{NDIndex{N}}) where {N} = _NDIndex(ntuple(_ -> static(1), Val(N)))
99-
100-
@inline function Base.split(i::NDIndex, V::Val)
101-
i, j = split(Tuple(i), V)
102-
return NDIndex(i), NDIndex(j)
103-
end
104-
105-
# arithmetic, min/max
106-
@inline Base.:(-)(i::NDIndex{N}) where {N} = NDIndex{N}(map(-, Tuple(i)))
107-
@inline function Base.:(+)(i1::NDIndex{N}, i2::NDIndex{N}) where {N}
108-
return _NDIndex(map(+, Tuple(i1), Tuple(i2)))
109-
end
110-
@inline function Base.:(-)(i1::NDIndex{N}, i2::NDIndex{N}) where {N}
111-
return _NDIndex(map(-, Tuple(i1), Tuple(i2)))
112-
end
113-
@inline function Base.min(i1::NDIndex{N}, i2::NDIndex{N}) where {N}
114-
return _NDIndex(map(min, Tuple(i1), Tuple(i2)))
115-
end
116-
@inline function Base.max(i1::NDIndex{N}, i2::NDIndex{N}) where {N}
117-
return _NDIndex(map(max, Tuple(i1), Tuple(i2)))
118-
end
119-
@inline Base.:(*)(a::Integer, i::NDIndex{N}) where {N} = _NDIndex(map(x->a*x, Tuple(i)))
120-
@inline Base.:(*)(i::NDIndex, a::Integer) = *(a, i)
121-
122-
Base.CartesianIndex(x::NDIndex) = CartesianIndex(Tuple(x))
123-
124-
# comparison
125-
@inline function Base.isless(x::NDIndex{N}, y::NDIndex{N}) where {N}
126-
return Bool(_isless(static(0), Tuple(x), Tuple(y)))
127-
end
128-
129-
Static.lt(x::NDIndex{N}, y::NDIndex{N}) where {N} = _isless(static(0), Tuple(x), Tuple(y))
130-
131-
_final_isless(c::Int) = c === 1
132-
_final_isless(::StaticInt{N}) where {N} = static(false)
133-
_final_isless(::StaticInt{1}) = static(true)
134-
_isless(c::C, x::Tuple{}, y::Tuple{}) where {C} = _final_isless(c)
135-
function _isless(c::C, x::Tuple, y::Tuple) where {C}
136-
return _isless(icmp(c, x, y), Base.front(x), Base.front(y))
137-
end
138-
icmp(::StaticInt{0}, x::Tuple, y::Tuple) = icmp(last(x), last(y))
139-
icmp(::StaticInt{N}, x::Tuple, y::Tuple) where {N} = static(N)
140-
function icmp(cmp::Int, x::Tuple, y::Tuple)
141-
if cmp === 0
142-
return icmp(Int(last(x)), Int(last(y)))
143-
else
144-
return cmp
145-
end
146-
end
147-
icmp(a, b) = _icmp(lt(a, b), a, b)
148-
_icmp(::True, a, b) = static(1)
149-
_icmp(::False, a, b) = __icmp(Static.eq(a, b))
150-
function _icmp(x::Bool, a, b)
151-
if x
152-
return 1
153-
else
154-
return __icmp(a == b)
155-
end
156-
end
157-
__icmp(::True) = static(0)
158-
__icmp(::False) = static(-1)
159-
function __icmp(x::Bool)
160-
if x
161-
return 0
162-
else
163-
return -1
164-
end
165-
end
166-
167-
# Necessary for compatibility with Base
168-
# In simple cases, we know that we don't need to use axes(A). Optimize those
169-
# until Julia gets smart enough to elide the call on its own:
170-
@inline function Base.to_indices(A, inds, I::Tuple{NDIndex, Vararg{Any}})
171-
return Base.to_indices(A, inds, (Tuple(I[1])..., tail(I)...))
172-
end
173-
# But for arrays of CartesianIndex, we just skip the appropriate number of inds
174-
@inline function Base.to_indices(A, inds, I::Tuple{AbstractArray{NDIndex{N}}, Vararg{Any}}) where N
175-
_, indstail = IteratorsMD.split(inds, Val(N))
176-
return (Base.to_index(A, I[1]), Base.to_indices(A, indstail, tail(I))...)
177-
end
178-

test/ndindex.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
@testset "NDIndex" begin
44

5-
x = NDIndex((1,2,3))
5+
x =
66
y = NDIndex((1,static(2),3))
77
z = NDIndex(static(3), static(3), static(3))
88

@@ -18,7 +18,6 @@ z = NDIndex(static(3), static(3), static(3))
1818
@test @inferred(NDIndex{3,Tuple{Int,Int,Int}}((1,2, 3))) === x
1919
end
2020

21-
@test @inferred(ArrayInterface.known_length(x)) === 3
2221
@test @inferred(length(x)) === 3
2322
@test @inferred(length(typeof(x))) === 3
2423
@test @inferred(y[2]) === 2

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ArrayInterface, Test
22
using Base: setindex
33
using IfElse
4-
using ArrayInterface: StaticInt, True, False
4+
using ArrayInterface: StaticInt, True, False, NDIndex
55
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
66
device, contiguous_axis, contiguous_batch_size, stride_rank, dense_dims, static, NDIndex,
77
is_lazy_conjugate
@@ -17,6 +17,7 @@ using StaticArrays
1717
@test isone(ArrayInterface.known_first(typeof(StaticArrays.SOneTo(7))))
1818
@test ArrayInterface.known_last(typeof(StaticArrays.SOneTo(7))) == 7
1919
@test ArrayInterface.known_length(typeof(StaticArrays.SOneTo(7))) == 7
20+
@test @inferred(ArrayInterface.known_length(NDIndex((1,2,3)))) === 3
2021

2122
using LinearAlgebra, SparseArrays
2223

@@ -899,7 +900,6 @@ end
899900
@test @inferred(ArrayInterface.axes(Array{Float64}(undef, 4, 3)')) === (Base.OneTo(3),Base.OneTo(4))
900901
end
901902

902-
include("ndindex.jl")
903903
include("indexing.jl")
904904
include("dimensions.jl")
905905

0 commit comments

Comments
 (0)