@@ -16,7 +16,7 @@ import ArrayInterface: allowed_getindex, allowed_setindex!, aos_to_soa, buffer,
16
16
17
17
# ArrayIndex subtypes and methods
18
18
import ArrayInterface: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex,
19
- TridiagonalIndex
19
+ TridiagonalIndex, StrideIndex
20
20
# managing immutables
21
21
import ArrayInterface: ismutable, can_change_size, can_setindex
22
22
# constants
@@ -27,6 +27,11 @@ import ArrayInterface: AbstractDevice, AbstractCPU, CPUPointer, CPUTuple, CheckP
27
27
28
28
import ArrayInterface: known_first, known_step, known_last
29
29
30
+ import ArrayInterface: offsets, axes_types, offset1, indices, known_offsets, stride_rank, strides, dense_dims, contiguous_axis, known_length, contiguous_batch_size,
31
+ contiguous_axis_indicator, is_column_major, _all_dense, AbstractArray2, dimnames, known_dimnames, known_offset1, known_strides, LazyAxis, lazy_axes, BroadcastAxis, broadcast_axis,
32
+ to_axes, find_all_dimnames, to_dims, known_size, deleteat, insert, static_size, is_lazy_conjugate, static_stride, static_strides, has_dimnames, unsafe_reconstruct, static_to_indices,
33
+ to_index, unsafe_setindex!, unsafe_getindex, static_axes, to_axis, is_dense, static_getindex
34
+
30
35
using Static
31
36
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
32
37
permute, invariant_permutation, field_type, reduce_tup, find_first_eq,
@@ -52,24 +57,22 @@ _sub1(@nospecialize x) = x - oneunit(x)
52
57
Tuple{X. parameters... , Y. parameters... }
53
58
end
54
59
55
- abstract type AbstractArray2{T, N} <: AbstractArray{T, N} end
56
-
57
- Base. size (A:: AbstractArray2 ) = map (Int, ArrayInterface. size (A))
58
- Base. size (A:: AbstractArray2 , dim) = Int (ArrayInterface. size (A, dim))
60
+ Base. size (A:: AbstractArray2 ) = map (Int, ArrayInterface. static_size (A))
61
+ Base. size (A:: AbstractArray2 , dim) = Int (ArrayInterface. static_size (A, dim))
59
62
60
63
function Base. axes (A:: AbstractArray2 )
61
- is_forwarding_wrapper (A) && return ArrayInterface. axes (parent (A))
64
+ is_forwarding_wrapper (A) && return ArrayInterface. static_axes (parent (A))
62
65
throw (ArgumentError (" Subtypes of `AbstractArray2` must define an axes method" ))
63
66
end
64
67
function Base. axes (A:: AbstractArray2 , dim:: Union{Symbol, StaticSymbol} )
65
- axes (A, to_dims (A, dim))
68
+ static_axes (A, to_dims (A, dim))
66
69
end
67
70
68
71
function Base. strides (A:: AbstractArray2 )
69
- defines_strides (A) && return map (Int, ArrayInterface. strides (A))
72
+ defines_strides (A) && return map (Int, ArrayInterface. static_strides (A))
70
73
throw (MethodError (Base. strides, (A,)))
71
74
end
72
- Base. strides (A:: AbstractArray2 , dim) = Int (ArrayInterface. strides (A, dim))
75
+ Base. strides (A:: AbstractArray2 , dim) = Int (ArrayInterface. static_strides (A, dim))
73
76
74
77
function Base. IndexStyle (:: Type{T} ) where {T <: AbstractArray2 }
75
78
is_forwarding_wrapper (T) ? IndexStyle (parent_type (T)) : IndexCartesian ()
78
81
function Base. length (A:: AbstractArray2 )
79
82
len = known_length (A)
80
83
if len === nothing
81
- return Int (prod (size (A)))
84
+ return Int (prod (static_size (A)))
82
85
else
83
86
return Int (len)
84
87
end
85
88
end
86
89
87
- @propagate_inbounds Base. getindex (A:: AbstractArray2 , args... ) = getindex (A, args... )
88
- @propagate_inbounds Base. getindex (A:: AbstractArray2 ; kwargs... ) = getindex (A; kwargs... )
90
+ @propagate_inbounds Base. getindex (A:: AbstractArray2 , args... ) = static_getindex (A, args... )
91
+ @propagate_inbounds Base. getindex (A:: AbstractArray2 ; kwargs... ) = static_getindex (A; kwargs... )
89
92
90
93
@propagate_inbounds function Base. setindex! (A:: AbstractArray2 , val, args... )
91
94
return setindex! (A, val, args... )
102
105
@inbounds (CartesianIndices (ntuple (dim -> indices (a, dim), Val (ndims (a))))[i])
103
106
end
104
107
@inline function _to_linear (a, i:: Tuple{IntType, Vararg{IntType}} )
105
- _strides2int (offsets (a), size_to_strides (size (a), static (1 )), i) + static (1 )
108
+ _strides2int (offsets (a), size_to_strides (static_size (a), static (1 )), i) + static (1 )
106
109
end
107
110
108
111
"""
@@ -167,7 +170,7 @@ Returns a new instance of `collection` with `item` inserted into at the given `i
167
170
"""
168
171
Base. @propagate_inbounds function insert (collection, index, item)
169
172
@boundscheck checkbounds (collection, index)
170
- ret = similar (collection, length (collection) + 1 )
173
+ ret = similar (collection, static_length (collection) + 1 )
171
174
@inbounds for i in firstindex (ret): (index - 1 )
172
175
ret[i] = collection[i]
173
176
end
@@ -213,7 +216,7 @@ Base.@propagate_inbounds function deleteat(collection::Tuple{Vararg{Any, N}},
213
216
end
214
217
215
218
function unsafe_deleteat (src:: AbstractVector , index)
216
- dst = similar (src, length (src) - 1 )
219
+ dst = similar (src, static_length (src) - 1 )
217
220
@inbounds for i in indices (dst)
218
221
if i < index
219
222
dst[i] = src[i]
@@ -225,7 +228,7 @@ function unsafe_deleteat(src::AbstractVector, index)
225
228
end
226
229
227
230
@inline function unsafe_deleteat (src:: AbstractVector , inds:: AbstractVector )
228
- dst = similar (src, length (src) - length (inds))
231
+ dst = similar (src, static_length (src) - static_length (inds))
229
232
dst_index = firstindex (dst)
230
233
@inbounds for src_index in indices (src)
231
234
if ! in (src_index, inds)
237
240
end
238
241
239
242
@inline function unsafe_deleteat (src:: Tuple , inds:: AbstractVector )
240
- dst = Vector {eltype(src)} (undef, length (src) - length (inds))
243
+ dst = Vector {eltype(src)} (undef, static_length (src) - static_length (inds))
241
244
dst_index = firstindex (dst)
242
- @inbounds for src_index in static (1 ): length (src)
245
+ @inbounds for src_index in static (1 ): static_length (src)
243
246
if ! in (src_index, inds)
244
247
dst[dst_index] = src[src_index]
245
248
dst_index += one (dst_index)
253
256
@inline function unsafe_deleteat (x:: Tuple , i)
254
257
if i === one (i)
255
258
return tail (x)
256
- elseif i == length (x)
259
+ elseif i == static_length (x)
257
260
return Base. front (x)
258
261
else
259
262
return (first (x), unsafe_deleteat (tail (x), i - one (i))... )
0 commit comments