Skip to content

Commit 6e37ea4

Browse files
committed
Replace nstatic with new ntuple support
1 parent 6da0033 commit 6e37ea4

File tree

7 files changed

+32
-29
lines changed

7 files changed

+32
-29
lines changed

lib/ArrayInterfaceOffsetArrays/src/ArrayInterfaceOffsetArrays.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ function _offset_axis_type(::Type{T}, dim::StaticInt{D}) where {T,D}
2525
OffsetArrays.IdOffsetRange{Int,ArrayInterface.axes_types(T, dim)}
2626
end
2727
function ArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
28-
Static.eachop_tuple(_offset_axis_type, Static.nstatic(Val(ndims(T))), ArrayInterface.parent_type(T))
28+
Static.eachop_tuple(
29+
_offset_axis_type,
30+
ntuple(static, StaticInt(ndims(T))),
31+
ArrayInterface.parent_type(T)
32+
)
2933
end
3034
ArrayInterface.strides(A::OffsetArray) = ArrayInterface.strides(parent(A))
3135
function ArrayInterface.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}

lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ ArrayInterface.device(::Type{<:StaticArrays.MArray}) = ArrayInterface.CPUPointer
4040
ArrayInterface.device(::Type{<:StaticArrays.SArray}) = ArrayInterface.CPUTuple()
4141
ArrayInterface.contiguous_axis(::Type{<:StaticArrays.StaticArray}) = StaticInt{1}()
4242
ArrayInterface.contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = StaticInt{0}()
43-
ArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}} = Static.nstatic(Val(N))
43+
function ArrayInterface.stride_rank(::Type{T}) where {N,T<:StaticArray{<:Any,<:Any,N}}
44+
ntuple(static, StaticInt(N))
45+
end
4446
function ArrayInterface.dense_dims(::Type{<:StaticArray{S,T,N}}) where {S,T,N}
4547
ArrayInterface._all_dense(Val(N))
4648
end

src/axes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565
end
6666

6767
function axes_types(::Type{T}) where {T<:ReinterpretArray}
68-
eachop_tuple(_non_reshaped_axis_type, nstatic(Val(ndims(T))), T)
68+
eachop_tuple(_non_reshaped_axis_type, ntuple(static, StaticInt(ndims(T))), T)
6969
end
7070

7171
function _non_reshaped_axis_type(::Type{A}, d::StaticInt{D}) where {A,D}
@@ -146,7 +146,7 @@ function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N
146146
return merge_tuple_type(Tuple{SOneTo{div(sizeof(S), sizeof(T))}}, axes_types(parent_type(A)))
147147
elseif sizeof(S) < sizeof(T)
148148
P = parent_type(A)
149-
return eachop_tuple(field_type, tail(nstatic(Val(ndims(P)))), axes_types(P))
149+
return eachop_tuple(field_type, tail(ntuple(static, StaticInt(ndims(P)))), axes_types(P))
150150
else
151151
return axes_types(parent_type(A))
152152
end

src/dimensions.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@ function throw_dim_error(@nospecialize(x), @nospecialize(dim))
33
throw(DimensionMismatch("$x does not have dimension corresponding to $dim"))
44
end
55

6-
#julia> @btime ArrayInterfaceCore.is_increasing(ArrayInterfaceCore.nstatic(Val(10)))
7-
# 0.045 ns (0 allocations: 0 bytes)
8-
#ArrayInterfaceCore.True()
96
function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y},Vararg}) where {X, Y}
107
if X <= Y
118
return is_increasing(tail(perm))
@@ -30,7 +27,7 @@ is_increasing(::Tuple{}) = True()
3027
Returns the mapping from parent dimensions to child dimensions.
3128
"""
3229
from_parent_dims(x) = from_parent_dims(typeof(x))
33-
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
30+
from_parent_dims(::Type{T}) where {T} = ntuple(static, StaticInt(ndims(T)))
3431
from_parent_dims(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2),)
3532
from_parent_dims(::Type{T}) where {T<:MatAdjTrans} = (StaticInt(2), One())
3633
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(I)
@@ -51,11 +48,11 @@ end
5148
from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = static(Val(I))
5249
function from_parent_dims(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
5350
if !IsReshaped || sizeof(S) === sizeof(T)
54-
return nstatic(Val(ndims(A)))
51+
return ntuple(static, StaticInt(ndims(A)))
5552
elseif sizeof(S) > sizeof(T)
56-
return tail(nstatic(Val(ndims(A) + 1)))
53+
return tail(ntuple(static, StaticInt(ndims(A) + 1)))
5754
else # sizeof(S) < sizeof(T)
58-
return (Zero(), nstatic(Val(N))...)
55+
return (Zero(), ntuple(static, StaticInt(N))...)
5956
end
6057
end
6158

@@ -86,7 +83,7 @@ end
8683
Returns the mapping from child dimensions to parent dimensions.
8784
"""
8885
to_parent_dims(x) = to_parent_dims(typeof(x))
89-
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
86+
to_parent_dims(::Type{T}) where {T} = ntuple(static, StaticInt(ndims(T)))
9087
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
9188
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = static(Val(I))
9289
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
@@ -103,7 +100,7 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
103100
out
104101
end
105102
function to_parent_dims(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
106-
pdims = nstatic(Val(ndims(A)))
103+
pdims = ntuple(static, StaticInt(ndims(A)))
107104
if !IsReshaped || sizeof(S) === sizeof(T)
108105
return pdims
109106
elseif sizeof(S) > sizeof(T)
@@ -266,7 +263,7 @@ to_dims(x, @nospecialize(dim::CanonicalInt)) = dim
266263
to_dims(x, dim::Integer) = Int(dim)
267264
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)
268265
function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N}
269-
eachop(_to_dims, nstatic(Val(N)), dimnames(x), dims)
266+
eachop(_to_dims, ntuple(static, StaticInt(N)), dimnames(x), dims)
270267
end
271268
@inline _to_dims(x::Tuple, d::Tuple, n::StaticInt{N}) where {N} = _to_dim(x, getfield(d, N))
272269
@inline function _to_dim(x::Tuple, d::Union{Symbol,StaticSymbol})

src/indexing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ to_indices(A, ::Tuple{}) = ()
9696
inds,
9797
IndexStyle(A),
9898
static(ndims(A)),
99-
eachop(_ndims_index, nstatic(Val(known_length(I))), I),
100-
eachop(_is_splat, nstatic(Val(known_length(I))), I)
99+
eachop(_ndims_index, ntuple(static, StaticInt(known_length(I))), I),
100+
eachop(_is_splat, ntuple(static, StaticInt(known_length(I))), I)
101101
)
102102
end
103103
@generated function _to_indices(A, inds::I, ::S, ::StaticInt{N}, ::NDI, ::IS) where {I,S,N,NDI,IS}

src/size.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ size(x::Iterators.Enumerate) = size(getfield(x, :itr))
5656
size(x::Iterators.Accumulate) = size(getfield(x, :itr))
5757
size(x::Iterators.Pairs) = size(getfield(x, :itr))
5858
@inline function size(x::Iterators.ProductIterator)
59-
eachop(_sub_size, nstatic(Val(ndims(x))), getfield(x, :iterators))
59+
eachop(_sub_size, ntuple(static, StaticInt(ndims(x))), getfield(x, :iterators))
6060
end
6161

6262
size(a, dim) = size(a, to_dims(a, dim))
@@ -100,7 +100,7 @@ known_size(x) = known_size(typeof(x))
100100
end
101101
end
102102
function _maybe_known_size(::Base.HasShape{N}, ::Type{T}) where {N,T}
103-
eachop(_known_size, nstatic(Val(N)), axes_types(T))
103+
eachop(_known_size, ntuple(static, StaticInt(N)), axes_types(T))
104104
end
105105
_maybe_known_size(::Base.IteratorSize, ::Type{T}) where {T} = (known_length(T),)
106106
function known_size(::Type{T}) where {T<:AbstractRange}
@@ -113,7 +113,7 @@ known_size(::Type{<:Iterators.Enumerate{I}}) where {I} = known_size(I)
113113
known_size(::Type{<:Iterators.Accumulate{<:Any,I}}) where {I} = known_size(I)
114114
known_size(::Type{<:Iterators.Pairs{<:Any,<:Any,I}}) where {I} = known_size(I)
115115
@inline function known_size(::Type{<:Iterators.ProductIterator{T}}) where {T}
116-
eachop(_known_size, nstatic(Val(known_length(T))), T)
116+
eachop(_known_size, ntuple(static, StaticInt(known_length(T))), T)
117117
end
118118

119119
# 1. `Zip` doesn't check that its collections are compatible (same size) at construction,
@@ -123,7 +123,7 @@ end
123123
# trailing dimensions (which must be of size 1), to `static(1)`. We want to stick to
124124
# `Nothing` and `Int` types, so we do one last pass to ensure everything is dynamic
125125
@inline function known_size(::Type{<:Iterators.Zip{T}}) where {T}
126-
dynamic(reduce_tup(Static._promote_shape, eachop(_unzip_size, nstatic(Val(known_length(T))), T)))
126+
dynamic(reduce_tup(Static._promote_shape, eachop(_unzip_size, ntuple(static, StaticInt(known_length(T))), T)))
127127
end
128128
_unzip_size(::Type{T}, n::StaticInt{N}) where {T,N} = known_size(field_type(T, n))
129129
_known_size(::Type{T}, dim::StaticInt) where {T} = known_length(field_type(T, dim))

src/stridelayout.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ stride_preserving_index(::Type{T}) where {T<:AbstractRange} = True()
2525
stride_preserving_index(::Type{T}) where {T<:Int} = True()
2626
stride_preserving_index(::Type{T}) where {T} = False()
2727
function stride_preserving_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
28-
if all(eachop(_stride_preserving_index, nstatic(Val(N)), T))
28+
if all(eachop(_stride_preserving_index, ntuple(static, StaticInt(N)), T))
2929
return True()
3030
else
3131
return False()
@@ -54,7 +54,7 @@ end
5454

5555
known_offsets(x) = known_offsets(typeof(x))
5656
function known_offsets(::Type{T}) where {T}
57-
return eachop(_known_offsets, nstatic(Val(ndims(T))), axes_types(T))
57+
return eachop(_known_offsets, ntuple(static, StaticInt(ndims(T))), axes_types(T))
5858
end
5959
_known_offsets(::Type{T}, dim::StaticInt) where {T} = known_first(field_type(T, dim))
6060

@@ -71,7 +71,7 @@ For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1)
7171
@inline offsets(x, i) = static_first(indices(x, i))
7272
offsets(::Tuple) = (One(),)
7373
offsets(x::StrideIndex) = getfield(x, :offsets)
74-
offsets(x) = eachop(_offsets, nstatic(Val(ndims(x))), x)
74+
offsets(x) = eachop(_offsets, ntuple(static, StaticInt(ndims(x))), x)
7575
function _offsets(x::X, dim::StaticInt{D}) where {X,D}
7676
start = known_first(axes_types(X, dim))
7777
if start === nothing
@@ -216,7 +216,7 @@ end
216216
contiguous_axis_indicator(::A) where {A<:AbstractArray} = contiguous_axis_indicator(A)
217217
contiguous_axis_indicator(::Nothing, ::Val) = nothing
218218
function contiguous_axis_indicator(c::StaticInt{N}, dim::Val{D}) where {N,D}
219-
return map(i -> eq(c, i), nstatic(dim))
219+
map(i -> eq(c, i), ntuple(static, dim))
220220
end
221221

222222
function rank_to_sortperm(R::Tuple{Vararg{StaticInt,N}}) where {N}
@@ -233,8 +233,8 @@ stride_rank(x) = stride_rank(typeof(x))
233233
function stride_rank(::Type{T}) where {T}
234234
is_forwarding_wrapper(T) ? stride_rank(parent_type(T)) : nothing
235235
end
236-
stride_rank(::Type{<:DenseArray{T,N}}) where {T,N} = nstatic(Val(N))
237-
stride_rank(::Type{BitArray{N}}) where {N} = nstatic(Val(N))
236+
stride_rank(::Type{<:DenseArray{T,N}}) where {T,N} = ntuple(static, StaticInt(N))
237+
stride_rank(::Type{BitArray{N}}) where {N} = ntuple(static, StaticInt(N))
238238
stride_rank(::Type{<:AbstractRange}) = (One(),)
239239
stride_rank(::Type{<:Tuple}) = (One(),)
240240

@@ -257,7 +257,7 @@ _stride_rank(::Type{T}, r::Tuple) where {T<:SubArray} = permute(r, to_parent_dim
257257

258258
stride_rank(x, i) = stride_rank(x)[i]
259259
function stride_rank(::Type{R}) where {T,N,S,A<:Array{S},R<:Base.ReinterpretArray{T,N,S,A}}
260-
return nstatic(Val(N))
260+
return ntuple(static, StaticInt(N))
261261
end
262262
@inline function stride_rank(::Type{A}) where {NB,NA,B<:AbstractArray{<:Any,NB},A<:Base.ReinterpretArray{<:Any,NA,<:Any,B,true}}
263263
NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
@@ -304,7 +304,7 @@ function stride_rank(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Transpose{T,
304304
IfElse.ifelse(is_dense(A), (static(1),), nothing)
305305
end
306306

307-
_reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N))
307+
_reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = ntuple(static, StaticInt(N))
308308
_reshaped_striderank(_, __, ___) = nothing
309309

310310
"""
@@ -466,7 +466,7 @@ function dense_dims(T::Type{<:Base.ReshapedArray})
466466
return n_of_x(StaticInt(ndims(T)), False())
467467
end
468468
end
469-
469+
470470
is_dense(A) = is_dense(typeof(A))
471471
is_dense(::Type{A}) where {A} = _is_dense(dense_dims(A))
472472
_is_dense(::Tuple{False,Vararg}) = False()

0 commit comments

Comments
 (0)