Skip to content

Commit a2ea55b

Browse files
authored
Avoid ReshapedArray error when contiguous_axis is nothing and fix dense_dims (#307)
1 parent 4245ffe commit a2ea55b

File tree

3 files changed

+15
-17
lines changed

3 files changed

+15
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.15"
3+
version = "6.0.16"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/stridelayout.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ function contiguous_axis(::Type{T}) where {T<:PermutedDimsArray}
164164
end
165165
end
166166
function contiguous_axis(::Type{<:Base.ReshapedArray{T, N, A, Tuple{}}}) where {T, N, A}
167-
if isone(-contiguous_axis(A))
167+
c = contiguous_axis(A)
168+
if c !== nothing && isone(-c)
168169
return StaticInt(-1)
169170
elseif dynamic(is_column_major(A) & is_dense(A))
170171
return StaticInt(1)
@@ -455,9 +456,17 @@ _dense_dims(::Type{S}, ::Nothing, ::Val{R}) where {R,N,NP,T,A<:AbstractArray{T,N
455456
end
456457
end
457458

458-
function dense_dims(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
459-
return _reshaped_dense_dims(dense_dims(P), is_column_major(P), Val{N}(), Val{M}())
459+
function dense_dims(T::Type{<:Base.ReshapedArray})
460+
d = dense_dims(parent_type(T))
461+
if d === nothing
462+
return nothing
463+
elseif all(d)
464+
return n_of_x(StaticInt(ndims(T)), True())
465+
else
466+
return n_of_x(StaticInt(ndims(T)), False())
467+
end
460468
end
469+
461470
is_dense(A) = is_dense(typeof(A))
462471
is_dense(::Type{A}) where {A} = _is_dense(dense_dims(A))
463472
_is_dense(::Tuple{False,Vararg}) = False()
@@ -466,19 +475,6 @@ _is_dense(t::Tuple{True}) = True()
466475
_is_dense(t::Tuple{}) = True()
467476
_is_dense(::Nothing) = False()
468477

469-
470-
_reshaped_dense_dims(_, __, ___, ____) = nothing
471-
function _reshaped_dense_dims(dense::Tuple, ::True, ::Val{N}, ::Val{0}) where {N}
472-
if all(dense)
473-
return _all_dense(Val{N}())
474-
else
475-
return nothing
476-
end
477-
end
478-
function _reshaped_dense_dims(dense::Tuple{Static.False}, ::True, ::Val{N}, ::Val{0}) where {N}
479-
return return ntuple(_ -> False(), Val{N}())
480-
end
481-
482478
"""
483479
known_strides(::Type{T}) -> Tuple
484480
known_strides(::Type{T}, dim) -> Union{Int,Nothing}

test/stridelayout.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ end
160160
@test @inferred(ArrayInterface.contiguous_axis((3,4))) === StaticInt(1)
161161
@test @inferred(ArrayInterface.contiguous_axis(rand(4)')) === StaticInt(2)
162162
@test @inferred(ArrayInterface.contiguous_axis(view(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])', :, 1)')) === StaticInt(-1)
163+
@test @inferred(ArrayInterface.contiguous_axis(reshape(DummyZeros(3,4), (4, 3)))) === nothing
163164
@test @inferred(ArrayInterface.contiguous_axis(DummyZeros(3,4))) === nothing
164165
@test @inferred(ArrayInterface.contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
165166
@test @inferred(ArrayInterface.contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing
@@ -259,6 +260,7 @@ end
259260
@test @inferred(ArrayInterface.dense_dims(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == (true,false)
260261
@test @inferred(ArrayInterface.dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]]))) == (false,true,false)
261262
@test @inferred(ArrayInterface.dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:]))) == (false,false,false)
263+
@test @inferred(ArrayInterface.dense_dims(reshape(view(randn(10, 10, 10), 3, :, :), 1, 100))) == (false, false)
262264
# TODO Currently Wrapper can't function the same as Array because Array can change
263265
# the dimensions on reshape. We should be rewrapping the result in `Wrapper` but we
264266
# first need to develop a standard method for reconstructing arrays

0 commit comments

Comments
 (0)