Skip to content

Commit 22490b7

Browse files
committed
Fix strides for Adjoint/Transpose of vectors
The previous method took the first stride of the parent vector and just doubled it. This copies what base does.
1 parent 2680f66 commit 22490b7

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

src/stridelayout.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -418,13 +418,10 @@ end
418418
@inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...)
419419
@inline strides(A::AbstractArray) = _strides(A, Base.strides(A), contiguous_axis(A))
420420

421-
@inline function strides(x::LinearAlgebra.Adjoint{T,V}) where {T,V<:AbstractVector{T}}
422-
strd = stride(parent(x), One())
423-
return (strd, strd)
424-
end
425-
@inline function strides(x::LinearAlgebra.Transpose{T,V}) where {T,V<:AbstractVector{T}}
426-
strd = stride(parent(x), One())
427-
return (strd, strd)
421+
function strides(x::VecAdjTrans)
422+
p = parent(x)
423+
st = first(strides(p))
424+
return (static_length(p) * st, st)
428425
end
429426

430427
@generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::StaticInt{C}) where {T,N,C}

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,10 +480,13 @@ end
480480
@test @inferred(ArrayInterface.strides(S)) === (StaticInt(1), StaticInt(2), StaticInt(6))
481481
@test @inferred(ArrayInterface.strides(Sp)) === (StaticInt(6), StaticInt(1), StaticInt(2))
482482
@test @inferred(ArrayInterface.strides(Sp2)) === (StaticInt(6), StaticInt(2), StaticInt(1))
483+
484+
@test @inferred(ArrayInterface.strides(view(Sp2, :, 1, 1)')) === (12, StaticInt(6))
485+
483486
@test @inferred(ArrayInterface.stride(Sp2, StaticInt(1))) === StaticInt(6)
484487
@test @inferred(ArrayInterface.stride(Sp2, StaticInt(2))) === StaticInt(2)
485488
@test @inferred(ArrayInterface.stride(Sp2, StaticInt(3))) === StaticInt(1)
486-
489+
487490
@test @inferred(ArrayInterface.strides(M)) === (StaticInt(1), StaticInt(2), StaticInt(6))
488491
@test @inferred(ArrayInterface.strides(Mp)) === (StaticInt(2), StaticInt(6))
489492
@test @inferred(ArrayInterface.strides(Mp2)) === (StaticInt(1), StaticInt(6))

0 commit comments

Comments
 (0)