Skip to content

Commit cb70983

Browse files
committed
Add some ReshapedArray of adjoint vector methods; somewhat hacky (special-case) fix for #215. Add a test.
1 parent b5fab6b commit cb70983

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/stridelayout.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,18 @@ function contiguous_axis(::Type{T}) where {T<:PermutedDimsArray}
151151
return from_parent_dims(T, c)
152152
end
153153
end
154+
function contiguous_axis(::Type{Base.ReshapedArray{T, 1, A, Tuple{}}}) where {T, A}
155+
IfElse.ifelse(is_column_major(A) & is_dense(A), static(1), nothing)
156+
end
157+
function contiguous_axis(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Adjoint{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
158+
IfElse.ifelse(is_column_major(A) & is_dense(A), static(1), nothing)
159+
end
160+
function contiguous_axis(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Transpose{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
161+
IfElse.ifelse(is_column_major(A) & is_dense(A), static(1), nothing)
162+
end
163+
function stride_rank(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Transpose{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
164+
IfElse.ifelse(is_dense(A), (static(1),), nothing)
165+
end
154166
function contiguous_axis(::Type{T}) where {T<:SubArray}
155167
return _contiguous_axis(T, contiguous_axis(parent_type(T)))
156168
end
@@ -267,6 +279,13 @@ end
267279
function stride_rank(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
268280
_reshaped_striderank(is_column_major(P), Val{N}(), Val{M}())
269281
end
282+
function stride_rank(::Type{Base.ReshapedArray{T, 1, A, Tuple{}}}) where {T, A}
283+
IfElse.ifelse(is_column_major(A) & is_dense(A), (static(1),), nothing)
284+
end
285+
function stride_rank(::Type{Base.ReshapedArray{T, 1, LinearAlgebra.Adjoint{T, A}, Tuple{}}}) where {T, A <: AbstractVector{T}}
286+
IfElse.ifelse(is_dense(A), (static(1),), nothing)
287+
end
288+
270289
_reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N))
271290
_reshaped_striderank(_, __, ___) = nothing
272291

@@ -425,6 +444,14 @@ end
425444
function dense_dims(::Type{Base.ReshapedArray{T, N, P, Tuple{Vararg{Base.SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
426445
return _reshaped_dense_dims(dense_dims(P), is_column_major(P), Val{N}(), Val{M}())
427446
end
447+
is_dense(A) = is_dense(typeof(A))
448+
is_dense(::Type{A}) where {A} = _is_dense(dense_dims(A))
449+
_is_dense(::Tuple{False,Vararg}) = False()
450+
_is_dense(t::Tuple{True,Vararg}) = _is_dense(Base.tail(t))
451+
_is_dense(t::Tuple{True}) = True()
452+
_is_dense(t::Tuple{}) = True()
453+
454+
428455
_reshaped_dense_dims(_, __, ___, ____) = nothing
429456
function _reshaped_dense_dims(dense::D, ::True, ::Val{N}, ::Val{0}) where {D,N}
430457
if all(dense)

test/array_index.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,8 @@ end
1818
@test @inferred(ArrayInterface.contiguous_axis(ArrayInterface.StrideIndex{2,(1,2),nothing,NTuple{2,Int},NTuple{2,Int}})) == nothing
1919
@test @inferred(ArrayInterface.stride_rank(ap_index)) == (1, 3)
2020

21+
let v = Float64.(1:10)'
22+
sv = @view(v[1:5])'
23+
ArrayInterface.StrideIndex(sv) === ArrayInterface.StrideIndex{2, (2, 1), 2}((StaticInt(1), StaticInt(1)), (StaticInt(1), StaticInt(1)))
24+
end
25+

0 commit comments

Comments
 (0)