Skip to content

Commit cf9b7a8

Browse files
authored
make ReinterpretArray more Offset-safe (#58898)
1 parent 55e2bd7 commit cf9b7a8

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

base/reinterpretarray.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ SimdLoop.simd_inner_length(::SCartesianIndices2{K}, ::Any) where K = K
311311
SCartesianIndex2{K}(I1+1, Ilast)
312312
end
313313

314+
_maybe_reshape(::IndexSCartesian2, A::AbstractArray, I...) = _maybe_reshape(IndexCartesian(), A, I...)
314315
_maybe_reshape(::IndexSCartesian2, A::ReshapedReinterpretArray, I...) = A
315316

316317
# fallbacks
@@ -329,11 +330,25 @@ function _getindex(::IndexSCartesian2, A::AbstractArray{T,N}, ind::SCartesianInd
329330
J = _ind2sub(tail(axes(A)), ind.j)
330331
getindex(A, ind.i, J...)
331332
end
333+
334+
function _getindex(::IndexSCartesian2{2}, A::AbstractArray{T,2}, ind::SCartesianIndex2) where {T}
335+
@_propagate_inbounds_meta
336+
J = first(axes(A, 2)) + ind.j - 1
337+
getindex(A, ind.i, J)
338+
end
339+
332340
function _setindex!(::IndexSCartesian2, A::AbstractArray{T,N}, v, ind::SCartesianIndex2) where {T,N}
333341
@_propagate_inbounds_meta
334342
J = _ind2sub(tail(axes(A)), ind.j)
335343
setindex!(A, v, ind.i, J...)
336344
end
345+
346+
function _setindex!(::IndexSCartesian2{2}, A::AbstractArray{T,2}, v, ind::SCartesianIndex2) where {T}
347+
@_propagate_inbounds_meta
348+
J = first(axes(A, 2)) + ind.j - 1
349+
setindex!(A, v, ind.i, J)
350+
end
351+
337352
eachindex(style::IndexSCartesian2, A::AbstractArray) = eachindex(style, parent(A))
338353

339354
## AbstractArray interface

test/reinterpretarray.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,18 @@ tslow(a::AbstractArray) = TSlow(a)
1010
wrapper(a::AbstractArray) = WrapperArray(a)
1111
fcviews(a::AbstractArray) = view(a, ntuple(Returns(:),ndims(a)-1)..., axes(a)[end])
1212
fcviews(a::AbstractArray{<:Any, 0}) = view(a)
13+
offset_nominal(a::AbstractArray) = OffsetArray(a)
14+
offset_maybe(a::AbstractArray) = (eltype(a) <: Real) ? a : OffsetArray(a, (1-ndims(A)):2:(ndims(A)-1)...)
1315
tslow(t::Tuple) = map(tslow, t)
1416
wrapper(t::Tuple) = map(wrapper, t)
1517
fcviews(t::Tuple) = map(fcviews, t)
18+
offset_nominal(t::Tuple) = map(offset_nominal, t)
19+
offset_maybe(t::Tuple) = map(offset_maybe, t)
1620

1721
test_many_wrappers(testf, A, wrappers) = foreach(w -> testf(w(A)), wrappers)
18-
test_many_wrappers(testf, A) = test_many_wrappers(testf, A, (identity, tslow, wrapper, fcviews))
22+
test_many_wrappers(testf, A) = test_many_wrappers(
23+
testf, A, (identity, tslow, wrapper, fcviews, offset_nominal, offset_maybe)
24+
)
1925

2026
A = Int64[1, 2, 3, 4]
2127
Ars = Int64[1 3; 2 4]
@@ -37,10 +43,6 @@ test_many_wrappers(B, (identity, tslow)) do _B
3743
@test @inferred(size(reinterpret(reshape, Int128, _B))) == (3,)
3844
end
3945

40-
test_many_wrappers(C) do Cr
41-
@test reinterpret(reshape, Tuple{Int8, Int}, Cr) == fill((1,1))
42-
end
43-
4446
@test_throws ArgumentError("cannot reinterpret `Int64` as `Vector{Int64}`, type `Vector{Int64}` is not a bits type") reinterpret(Vector{Int64}, A)
4547
@test_throws ArgumentError("cannot reinterpret `Vector{Int32}` as `Int32`, type `Vector{Int32}` is not a bits type") reinterpret(Int32, Av)
4648
@test_throws ArgumentError("cannot reinterpret a zero-dimensional `Int64` array to `Int32` which is of a different size") reinterpret(Int32, reshape([Int64(0)]))
@@ -160,8 +162,10 @@ test_many_wrappers(A3) do A3_
160162
@test A3[2,1,2] == 400
161163
end
162164

163-
test_many_wrappers(C) do Cr
165+
test_many_wrappers(C) do Cr_
166+
Cr = deepcopy(Cr_)
164167
r = reinterpret(reshape, Tuple{Int, Int}, Cr)
168+
@test r == fill((1,1))
165169
r[] = (2,2)
166170
@test r[] === (2,2)
167171
r[1] = (3,3)
@@ -378,6 +382,8 @@ let a = rand(ComplexF32, 5)
378382
r = reinterpret(reshape, Float32, a)
379383
ref = Array(r)
380384

385+
@test all(r .== OffsetArray(r)[:, :, :])
386+
381387
@test r[1, :, 1] == ref[1, :]
382388
@test r[1, :, 1, 1, 1] == ref[1, :]
383389
@test r[1, :, UInt8(1)] == ref[1, :]

0 commit comments

Comments
 (0)