Skip to content

Commit 7bef132

Browse files
authored
Merge pull request #216 from JuliaArrays/ReshapedArrayofadjiontvectors
Add some ReshapedArray of adjoint vector methods
2 parents b5fab6b + db2a4ff commit 7bef132

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

src/stridelayout.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ function contiguous_axis(::Type{T}) where {T<:PermutedDimsArray}
151151
return from_parent_dims(T, c)
152152
end
153153
end
154+
function contiguous_axis(::Type{Base.ReshapedArray{T, 1, A, Tuple{}}}) where {T, A}
155+
IfElse.ifelse(is_column_major(A) & is_dense(A), static(1), nothing)
156+
end
157+
function contiguous_axis(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Adjoint{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
158+
IfElse.ifelse(is_column_major(A) & is_dense(A), static(1), nothing)
159+
end
160+
function contiguous_axis(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Transpose{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
161+
IfElse.ifelse(is_column_major(A) & is_dense(A), static(1), nothing)
162+
end
154163
function contiguous_axis(::Type{T}) where {T<:SubArray}
155164
return _contiguous_axis(T, contiguous_axis(parent_type(T)))
156165
end
@@ -267,6 +276,16 @@ end
267276
function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
268277
_reshaped_striderank(is_column_major(P), Val{N}(), Val{M}())
269278
end
279+
function stride_rank(::Type{Base.ReshapedArray{T, 1, A, Tuple{}}}) where {T, A}
280+
IfElse.ifelse(is_column_major(A) & is_dense(A), (static(1),), nothing)
281+
end
282+
function stride_rank(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Adjoint{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
283+
IfElse.ifelse(is_dense(A), (static(1),), nothing)
284+
end
285+
function stride_rank(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Transpose{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
286+
IfElse.ifelse(is_dense(A), (static(1),), nothing)
287+
end
288+
270289
_reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N))
271290
_reshaped_striderank(_, __, ___) = nothing
272291

@@ -425,6 +444,14 @@ end
425444
function dense_dims(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
426445
return _reshaped_dense_dims(dense_dims(P), is_column_major(P), Val{N}(), Val{M}())
427446
end
447+
is_dense(A) = is_dense(typeof(A))
448+
is_dense(::Type{A}) where {A} = _is_dense(dense_dims(A))
449+
_is_dense(::Tuple{False,Vararg}) = False()
450+
_is_dense(t::Tuple{True,Vararg}) = _is_dense(Base.tail(t))
451+
_is_dense(t::Tuple{True}) = True()
452+
_is_dense(t::Tuple{}) = True()
453+
454+
428455
_reshaped_dense_dims(_, __, ___, ____) = nothing
429456
function _reshaped_dense_dims(dense::D, ::True, ::Val{N}, ::Val{0}) where {D,N}
430457
if all(dense)

test/array_index.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,10 @@ end
1818
@test @inferred(ArrayInterface.contiguous_axis(ArrayInterface.StrideIndex{2,(1,2),nothing,NTuple{2,Int},NTuple{2,Int}})) == nothing
1919
@test @inferred(ArrayInterface.stride_rank(ap_index)) == (1, 3)
2020

21+
let v = Float64.(1:10)', v2 = transpose(parent(v))
22+
sv = @view(v[1:5])'
23+
sv2 = @view(v2[1:5])'
24+
@test @inferred(ArrayInterface.StrideIndex(sv)) === @inferred(ArrayInterface.StrideIndex(sv2)) === ArrayInterface.StrideIndex{2, (2, 1), 2}((StaticInt(1), StaticInt(1)), (StaticInt(1), StaticInt(1)))
25+
@test @inferred(ArrayInterface.stride_rank(parent(sv))) === @inferred(ArrayInterface.stride_rank(parent(sv2))) === (StaticInt(1),)
26+
end
27+

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,10 @@ ArrayInterface.parent_type(::Type{DenseWrapper{T,N,P}}) where {T,N,P} = P
422422
@test @inferred(dense_dims(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
423423
@test @inferred(dense_dims(view(DummyZeros(3,4), :, 1))) === nothing
424424
@test @inferred(dense_dims(view(DummyZeros(3,4), :, 1)')) === nothing
425+
@test @inferred(ArrayInterface.is_dense(A)) === @inferred(ArrayInterface.is_dense(A)) === @inferred(ArrayInterface.is_dense(PermutedDimsArray(A,(3,1,2)))) === @inferred(ArrayInterface.is_dense(Array{Float64,0}(undef))) === True()
426+
@test @inferred(ArrayInterface.is_dense(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) === @inferred(ArrayInterface.is_dense(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]]))) === @inferred(ArrayInterface.is_dense(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:]))) === False()
425427

428+
426429
C = Array{Int8}(undef, 2,2,2,2);
427430
doubleperm = PermutedDimsArray(PermutedDimsArray(C,(4,2,3,1)), (4,2,1,3));
428431
@test collect(strides(C))[collect(stride_rank(doubleperm))] == collect(strides(doubleperm))

0 commit comments

Comments
 (0)