Skip to content

Commit b1ce7bb

Browse files
authored
Merge pull request #301 from JuliaArrays/nonreshapereinterpretarray
fix non-reshapred reinterpretarray
2 parents 5569c5e + 0e7b4e4 commit b1ce7bb

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.14"
3+
version = "6.0.15"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/stridelayout.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,15 @@ end
261261
@inline function stride_rank(::Type{A}) where {NB,NA,B<:AbstractArray{<:Any,NB},A<:Base.ReinterpretArray{<:Any,NA,<:Any,B,true}}
262262
NA == NB ? stride_rank(B) : _stride_rank_reinterpret(stride_rank(B), gt(StaticInt{NB}(), StaticInt{NA}()))
263263
end
264+
@inline function stride_rank(::Type{A}) where {N,B<:AbstractArray{<:Any,N},A<:Base.ReinterpretArray{<:Any,N,<:Any,B,false}}
265+
stride_rank(B)
266+
end
267+
264268
@inline _stride_rank_reinterpret(sr, ::False) = (One(), map(Base.Fix2(+, One()), sr)...)
265269
@inline _stride_rank_reinterpret(sr::Tuple{One,Vararg}, ::True) = map(Base.Fix2(-, One()), tail(sr))
270+
function contiguous_axis(::Type{R}) where {T,N,S,B<:AbstractArray{S,N},R<:ReinterpretArray{T,N,S,B,false}}
271+
contiguous_axis(B)
272+
end
266273
# if the leading dim's `stride_rank` is not one, then that means the individual elements are split across an axis, which ArrayInterface
267274
# doesn't currently have a means of representing.
268275
@inline function contiguous_axis(::Type{A}) where {NB,NA,B<:AbstractArray{<:Any,NB},A<:Base.ReinterpretArray{<:Any,NA,<:Any,B,true}}

test/stridelayout.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ end
108108
@test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5)
109109
Ac2t_static = reinterpret(reshape, Tuple{Float64,Float64}, view(MArray(rand(ComplexF64, 5, 7)), 2:4, 3:6));
110110
@test @inferred(ArrayInterface.strides(Ac2t_static)) === (StaticInt(1), StaticInt(5))
111+
112+
a = rand(Float32, 100, 2);
113+
b = reinterpret(Float64, view(a,:,1));
114+
@test @inferred(ArrayInterface.contiguous_axis(a)) === StaticInt(1)
115+
@test @inferred(ArrayInterface.stride_rank(b)) === (StaticInt(1),)
111116
end
112117

113118
@testset "Memory Layout" begin

0 commit comments

Comments
 (0)