Skip to content

Commit 81c4a3f

Browse files
authored
We don't need to worry about not having reshaped ReinterpretArray now that LTS is v1.6 (#292)
1 parent 36ca34c commit 81c4a3f

File tree

6 files changed

+52
-56
lines changed

6 files changed

+52
-56
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.9"
3+
version = "6.0.10"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/ArrayInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buff
55
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
66
issingular, isstructured, matrix_colors, restructure, lu_instance,
77
safevec, zeromatrix, ColoringAlgorithm,
8-
fast_scalar_indexing, parameterless_type, _is_reshaped, ndims_index, is_splat_index
8+
fast_scalar_indexing, parameterless_type, ndims_index, is_splat_index
99

1010
# ArrayIndex subtypes and methods
1111
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex

src/axes.jl

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -136,39 +136,37 @@ end
136136
end
137137
end
138138

139-
if isdefined(Base, :ReshapedReinterpretArray)
140-
function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N,S}}
141-
if sizeof(S) > sizeof(T)
142-
return merge_tuple_type(Tuple{SOneTo{div(sizeof(S), sizeof(T))}}, axes_types(parent_type(A)))
143-
elseif sizeof(S) < sizeof(T)
144-
P = parent_type(A)
145-
return eachop_tuple(field_type, tail(nstatic(Val(ndims(P)))), axes_types(P))
146-
else
147-
return axes_types(parent_type(A))
148-
end
139+
function axes_types(::Type{A}) where {T,N,S,A<:Base.ReshapedReinterpretArray{T,N,S}}
140+
if sizeof(S) > sizeof(T)
141+
return merge_tuple_type(Tuple{SOneTo{div(sizeof(S), sizeof(T))}}, axes_types(parent_type(A)))
142+
elseif sizeof(S) < sizeof(T)
143+
P = parent_type(A)
144+
return eachop_tuple(field_type, tail(nstatic(Val(ndims(P)))), axes_types(P))
145+
else
146+
return axes_types(parent_type(A))
149147
end
150-
@inline function axes(A::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
151-
if sizeof(S) > sizeof(T)
152-
return (SOneTo(div(sizeof(S), sizeof(T))), axes(parent(A))...)
153-
elseif sizeof(S) < sizeof(T)
154-
return tail(axes(parent(A)))
155-
else
156-
return axes(parent(A))
157-
end
148+
end
149+
@inline function axes(A::Base.ReshapedReinterpretArray{T,N,S}) where {T,N,S}
150+
if sizeof(S) > sizeof(T)
151+
return (SOneTo(div(sizeof(S), sizeof(T))), axes(parent(A))...)
152+
elseif sizeof(S) < sizeof(T)
153+
return tail(axes(parent(A)))
154+
else
155+
return axes(parent(A))
158156
end
159-
@inline function axes(A::Base.ReshapedReinterpretArray{T,N,S}, dim) where {T,N,S}
160-
d = to_dims(A, dim)
161-
if sizeof(S) > sizeof(T)
162-
if d == 1
163-
return SOneTo(div(sizeof(S), sizeof(T)))
164-
else
165-
return axes(parent(A), d - static(1))
166-
end
167-
elseif sizeof(S) < sizeof(T)
168-
return axes(parent(A), d - static(1))
157+
end
158+
@inline function axes(A::Base.ReshapedReinterpretArray{T,N,S}, dim) where {T,N,S}
159+
d = to_dims(A, dim)
160+
if sizeof(S) > sizeof(T)
161+
if d == 1
162+
return SOneTo(div(sizeof(S), sizeof(T)))
169163
else
170-
return axes(parent(A), d)
164+
return axes(parent(A), d - static(1))
171165
end
166+
elseif sizeof(S) < sizeof(T)
167+
return axes(parent(A), d - static(1))
168+
else
169+
return axes(parent(A), d)
172170
end
173171
end
174172

src/dimensions.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(I
6262
out
6363
end
6464
from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = static(Val(I))
65-
function from_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
66-
if !_is_reshaped(R) || sizeof(S) === sizeof(T)
65+
function from_parent_dims(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
66+
if !IsReshaped || sizeof(S) === sizeof(T)
6767
return nstatic(Val(ndims(A)))
6868
elseif sizeof(S) > sizeof(T)
6969
return tail(nstatic(Val(ndims(A) + 1)))
@@ -115,9 +115,9 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
115115
end
116116
out
117117
end
118-
function to_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
118+
function to_parent_dims(::Type{<:ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
119119
pdims = nstatic(Val(ndims(A)))
120-
if !_is_reshaped(R) || sizeof(S) === sizeof(T)
120+
if !IsReshaped || sizeof(S) === sizeof(T)
121121
return pdims
122122
elseif sizeof(S) > sizeof(T)
123123
return (Zero(), pdims...,)

src/size.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ _sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = length(getfield(x, dim))
2828
@inline function size(B::PermutedDimsArray{T,N,I1}) where {T,N,I1}
2929
permute(size(parent(B)), static(I1))
3030
end
31-
function size(a::ReinterpretArray{T,N,S,A}) where {T,N,S,A}
31+
function size(a::ReinterpretArray{T,N,S,A,IsReshaped}) where {T,N,S,A,IsReshaped}
3232
psize = size(parent(a))
33-
if _is_reshaped(typeof(a))
33+
if IsReshaped
3434
if sizeof(S) === sizeof(T)
3535
return psize
3636
elseif sizeof(S) > sizeof(T)

test/dimensions.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,21 @@
3636
@test_throws DimensionMismatch ArrayInterface.from_parent_dims(typeof(vadj), 0)
3737
@test_throws DimensionMismatch ArrayInterface.from_parent_dims(typeof(vadj), static(0))
3838

39-
if VERSION v"1.6.0-DEV.1581"
40-
colormat = reinterpret(reshape, Float64, [(R=rand(), G=rand(), B=rand()) for i 1:100])
41-
@test @inferred(ArrayInterface.from_parent_dims(typeof(colormat))) === (static(2),)
42-
@test @inferred(ArrayInterface.to_parent_dims(typeof(colormat))) === (static(0), static(1),)
43-
44-
Rr = reinterpret(reshape, Int32, ones(4))
45-
@test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(2),)
46-
@test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(0), static(1),)
47-
48-
Rr = reinterpret(reshape, Int64, ones(4))
49-
@test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(1),)
50-
@test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(1),)
51-
52-
Sr = reinterpret(reshape, Complex{Int64}, zeros(2, 3, 4))
53-
@test @inferred(ArrayInterface.from_parent_dims(typeof(Sr))) === (static(0), static(1), static(2))
54-
@test @inferred(ArrayInterface.to_parent_dims(typeof(Sr))) === (static(2), static(3))
55-
end
39+
colormat = reinterpret(reshape, Float64, [(R=rand(), G=rand(), B=rand()) for i 1:100])
40+
@test @inferred(ArrayInterface.from_parent_dims(typeof(colormat))) === (static(2),)
41+
@test @inferred(ArrayInterface.to_parent_dims(typeof(colormat))) === (static(0), static(1),)
42+
43+
Rr = reinterpret(reshape, Int32, ones(4))
44+
@test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(2),)
45+
@test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(0), static(1),)
46+
47+
Rr = reinterpret(reshape, Int64, ones(4))
48+
@test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(1),)
49+
@test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(1),)
50+
51+
Sr = reinterpret(reshape, Complex{Int64}, zeros(2, 3, 4))
52+
@test @inferred(ArrayInterface.from_parent_dims(typeof(Sr))) === (static(0), static(1), static(2))
53+
@test @inferred(ArrayInterface.to_parent_dims(typeof(Sr))) === (static(2), static(3))
5654
end
5755

5856
@testset "order_named_inds" begin
@@ -139,7 +137,7 @@ end
139137
x[x=1] = [2, 3]
140138
@test @inferred(getindex(x, x=1)) == [2, 3]
141139
y = NamedDimsWrapper((:x, static(:y)), ones(2, 2))
142-
# FIXME this doesn't correctly infer the output because it can't infer
140+
# FIXME this doesn't correctly infer the output because it can't infer
143141
@test getindex(y, x=1) == [1, 1]
144142
end
145143

@@ -165,4 +163,4 @@ end
165163
@test @inferred(ArrayInterface.dense_dims(u_view)) == (False(),)
166164
@test @inferred(ArrayInterface.dense_dims(u_reshaped_view1)) == (False(), False())
167165
@test @inferred(ArrayInterface.dense_dims(u_reshaped_view2)) == (False(), False())
168-
end
166+
end

0 commit comments

Comments
 (0)