|
1 | 1 |
|
2 |
| -""" |
3 |
| - NDIndex(i, j, k...) -> I |
4 |
| - NDIndex((i, j, k...)) -> I |
5 |
| -
|
6 |
| -A multidimensional index that refers to a single element. Each dimension is represented by |
7 |
| -a single `Int` or `StaticInt`. |
8 |
| -
|
9 |
| -```julia |
10 |
| -julia> using ArrayInterface: NDIndex |
11 |
| -
|
12 |
| -julia> using Static |
13 |
| -
|
14 |
| -julia> i = NDIndex(static(1), 2, static(3)) |
15 |
| -NDIndex(static(1), 2, static(3)) |
16 |
| -
|
17 |
| -julia> i[static(1)] |
18 |
| -static(1) |
19 |
| -
|
20 |
| -julia> i[1] |
21 |
| -1 |
22 |
| -
|
23 |
| -``` |
24 |
| -""" |
25 |
| -struct NDIndex{N,I<:Tuple{Vararg{Any,N}}} <: AbstractCartesianIndex{N} |
26 |
| - index::I |
27 |
| - |
28 |
| - global _NDIndex(index::Tuple{Vararg{Any,N}}) where {N} = new{N,typeof(index)}(index) |
29 |
| - |
30 |
| - function NDIndex{N,I}(index::I) where {N,I<:Tuple{Vararg{Integer,N}}} |
31 |
| - for i in index |
32 |
| - (i isa Int) || i isa StaticInt || throw(ArgumentError("NDIndex does not support values of type $(typeof(i))")) |
33 |
| - end |
34 |
| - return new{N,I}(index) |
35 |
| - end |
36 |
| - |
37 |
| - NDIndex{N}(index::Tuple) where {N} = _ndindex(static(N), _flatten(index...)) |
38 |
| - NDIndex{N}(index...) where {N} = _ndindex(static(N), _flatten(index...)) |
39 |
| - |
40 |
| - NDIndex{0}(::Tuple{}) = new{0,Tuple{}}(()) |
41 |
| - NDIndex{0}() = NDIndex{0}(()) |
42 |
| - |
43 |
| - NDIndex(index::Tuple) = _NDIndex(_flatten(index...)) |
44 |
| - NDIndex(index...) = _NDIndex(_flatten(index...)) |
45 |
| -end |
46 |
| - |
47 |
| -_ndindex(n::StaticInt{N}, i::Tuple{Vararg{Union{Int,StaticInt},N}}) where {N} = _NDIndex(i) |
48 |
| -function _ndindex(n::StaticInt{N}, i::Tuple{Vararg{Any,M}}) where {N,M} |
49 |
| - M > N && throw(ArgumentError("input tuple of length $M, requested $N")) |
50 |
| - return _NDIndex(_fill_to_length(i, n)) |
51 |
| -end |
52 |
| -_fill_to_length(x::Tuple{Vararg{Any,N}}, n::StaticInt{N}) where {N} = x |
53 |
| -@inline function _fill_to_length(x::Tuple{Vararg{Any,M}}, n::StaticInt{N}) where {M,N} |
54 |
| - return _fill_to_length((x..., static(1)), n) |
55 |
| -end |
56 |
| - |
57 |
| -_flatten(i::StaticInt{N}) where {N} = (i,) |
58 |
| -_flatten(i::Integer) = (Int(i),) |
59 |
| -_flatten(i::Base.AbstractCartesianIndex) = _flatten(Tuple(i)...) |
60 |
| -@inline _flatten(i::Integer, I...) = (canonicalize(i), _flatten(I...)...) |
61 |
| -@inline function _flatten(i::Base.AbstractCartesianIndex, I...) |
62 |
| - return (_flatten(Tuple(i)...)..., _flatten(I...)...) |
63 |
| - end |
64 |
| -Base.Tuple(index::NDIndex) = index.index |
65 |
| - |
66 |
| -Static.dynamic(x::NDIndex) = CartesianIndex(dynamic(Tuple(x))) |
67 |
| -Static.static(x::CartesianIndex) = _NDIndex(static(Tuple(x))) |
68 |
| -Static.known(::Type{NDIndex{N,I}}) where {N,I} = known(I) |
69 |
| - |
70 |
| -Base.show(io::IO, i::NDIndex) = (print(io, "NDIndex"); show(io, Tuple(i))) |
71 |
| - |
72 |
| -# length |
73 |
| -Base.length(::NDIndex{N}) where {N} = N |
74 |
| -Base.length(::Type{NDIndex{N,I}}) where {N,I} = N |
75 |
| -known_length(::Type{NDIndex{N,I}}) where {N,I} = N |
76 |
| - |
77 |
| -# indexing |
78 |
| -@propagate_inbounds function getindex(x::NDIndex{N,T}, i::Int)::Int where {N,T} |
79 |
| - return Int(getfield(Tuple(x), i)) |
80 |
| -end |
81 |
| -@propagate_inbounds function getindex(x::NDIndex{N,T}, i::StaticInt{I}) where {N,T,I} |
82 |
| - return getfield(Tuple(x), I) |
83 |
| -end |
84 |
| -@propagate_inbounds Base.getindex(x::NDIndex, i::Integer) = ArrayInterface.getindex(x, i) |
85 |
| - |
86 |
| -# Base.get(A::AbstractArray, I::CartesianIndex, default) = get(A, I.I, default) |
87 |
| -# eltype(::Type{T}) where {T<:CartesianIndex} = eltype(fieldtype(T, :I)) |
88 |
| - |
89 |
| -Base.setindex(x::NDIndex, i, j) = NDIndex(Base.setindex(Tuple(x), i, j)) |
90 |
| - |
91 |
| -# equality |
92 |
| -Base.:(==)(x::NDIndex{N}, y::NDIndex{N}) where N = Tuple(x) == Tuple(y) |
93 |
| - |
94 |
| -# zeros and ones |
95 |
| -Base.zero(::NDIndex{N}) where {N} = zero(NDIndex{N}) |
96 |
| -Base.zero(::Type{NDIndex{N}}) where {N} = _NDIndex(ntuple(_ -> static(0), Val(N))) |
97 |
| -Base.oneunit(::NDIndex{N}) where {N} = oneunit(NDIndex{N}) |
98 |
| -Base.oneunit(::Type{NDIndex{N}}) where {N} = _NDIndex(ntuple(_ -> static(1), Val(N))) |
99 |
| - |
100 |
| -@inline function Base.split(i::NDIndex, V::Val) |
101 |
| - i, j = split(Tuple(i), V) |
102 |
| - return NDIndex(i), NDIndex(j) |
103 |
| -end |
104 |
| - |
105 |
| -# arithmetic, min/max |
106 |
| -@inline Base.:(-)(i::NDIndex{N}) where {N} = NDIndex{N}(map(-, Tuple(i))) |
107 |
| -@inline function Base.:(+)(i1::NDIndex{N}, i2::NDIndex{N}) where {N} |
108 |
| - return _NDIndex(map(+, Tuple(i1), Tuple(i2))) |
109 |
| -end |
110 |
| -@inline function Base.:(-)(i1::NDIndex{N}, i2::NDIndex{N}) where {N} |
111 |
| - return _NDIndex(map(-, Tuple(i1), Tuple(i2))) |
112 |
| -end |
113 |
| -@inline function Base.min(i1::NDIndex{N}, i2::NDIndex{N}) where {N} |
114 |
| - return _NDIndex(map(min, Tuple(i1), Tuple(i2))) |
115 |
| -end |
116 |
| -@inline function Base.max(i1::NDIndex{N}, i2::NDIndex{N}) where {N} |
117 |
| - return _NDIndex(map(max, Tuple(i1), Tuple(i2))) |
118 |
| -end |
119 |
| -@inline Base.:(*)(a::Integer, i::NDIndex{N}) where {N} = _NDIndex(map(x->a*x, Tuple(i))) |
120 |
| -@inline Base.:(*)(i::NDIndex, a::Integer) = *(a, i) |
121 |
| - |
122 |
| -Base.CartesianIndex(x::NDIndex) = CartesianIndex(Tuple(x)) |
123 |
| - |
124 |
| -# comparison |
125 |
| -@inline function Base.isless(x::NDIndex{N}, y::NDIndex{N}) where {N} |
126 |
| - return Bool(_isless(static(0), Tuple(x), Tuple(y))) |
127 |
| -end |
128 |
| - |
129 |
| -Static.lt(x::NDIndex{N}, y::NDIndex{N}) where {N} = _isless(static(0), Tuple(x), Tuple(y)) |
130 |
| - |
131 |
| -_final_isless(c::Int) = c === 1 |
132 |
| -_final_isless(::StaticInt{N}) where {N} = static(false) |
133 |
| -_final_isless(::StaticInt{1}) = static(true) |
134 |
| -_isless(c::C, x::Tuple{}, y::Tuple{}) where {C} = _final_isless(c) |
135 |
| -function _isless(c::C, x::Tuple, y::Tuple) where {C} |
136 |
| - return _isless(icmp(c, x, y), Base.front(x), Base.front(y)) |
137 |
| -end |
138 |
| -icmp(::StaticInt{0}, x::Tuple, y::Tuple) = icmp(last(x), last(y)) |
139 |
| -icmp(::StaticInt{N}, x::Tuple, y::Tuple) where {N} = static(N) |
140 |
| -function icmp(cmp::Int, x::Tuple, y::Tuple) |
141 |
| - if cmp === 0 |
142 |
| - return icmp(Int(last(x)), Int(last(y))) |
143 |
| - else |
144 |
| - return cmp |
145 |
| - end |
146 |
| -end |
147 |
| -icmp(a, b) = _icmp(lt(a, b), a, b) |
148 |
| -_icmp(::True, a, b) = static(1) |
149 |
| -_icmp(::False, a, b) = __icmp(Static.eq(a, b)) |
150 |
| -function _icmp(x::Bool, a, b) |
151 |
| - if x |
152 |
| - return 1 |
153 |
| - else |
154 |
| - return __icmp(a == b) |
155 |
| - end |
156 |
| -end |
157 |
| -__icmp(::True) = static(0) |
158 |
| -__icmp(::False) = static(-1) |
159 |
| -function __icmp(x::Bool) |
160 |
| - if x |
161 |
| - return 0 |
162 |
| - else |
163 |
| - return -1 |
164 |
| - end |
165 |
| -end |
166 |
| - |
167 |
| -# Necessary for compatibility with Base |
168 |
| -# In simple cases, we know that we don't need to use axes(A). Optimize those |
169 |
| -# until Julia gets smart enough to elide the call on its own: |
170 |
| -@inline function Base.to_indices(A, inds, I::Tuple{NDIndex, Vararg{Any}}) |
171 |
| - return Base.to_indices(A, inds, (Tuple(I[1])..., tail(I)...)) |
172 |
| -end |
173 |
| -# But for arrays of CartesianIndex, we just skip the appropriate number of inds |
174 |
| -@inline function Base.to_indices(A, inds, I::Tuple{AbstractArray{NDIndex{N}}, Vararg{Any}}) where N |
175 |
| - _, indstail = IteratorsMD.split(inds, Val(N)) |
176 |
| - return (Base.to_index(A, I[1]), Base.to_indices(A, indstail, tail(I))...) |
177 |
| -end |
178 |
| - |
0 commit comments