Skip to content

Commit 713b86e

Browse files
committed
Support atsign-unsafe indexing with trailing 1s
1 parent ba85845 commit 713b86e

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

src/OffsetArrays.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ offset{N}(offsets::NTuple{N,Int}, inds::NTuple{N,Int}) = _offset((), offsets, in
121121
_offset(out, ::Tuple{}, ::Tuple{}) = out
122122
@inline _offset(out, offsets, inds) = _offset((out..., inds[1]-offsets[1]), Base.tail(offsets), Base.tail(inds))
123123

124+
# Support trailing 1s
125+
@inline offset(offsets::Tuple{Vararg{Int}}, inds::Tuple{Vararg{Int}}) = (offset(offsets, Base.front(inds))..., inds[end])
126+
offset(offsets::Tuple{}, inds::Tuple{}) = ()
127+
offset(offsets::Tuple{Vararg{Int}}, inds::Tuple{}) = error("inds cannot be shorter than offsets")
128+
124129
indexoffset(r::Range) = first(r) - 1
125130
indexoffset(i::Integer) = 0
126131

@@ -152,6 +157,34 @@ end
152157
@inline unsafe_getindex(a::AbstractArray, I...) = (@inbounds ret = a[I...]; ret)
153158
@inline unsafe_setindex!(a::AbstractArray, val, I...) = (@inbounds a[I...] = val; val)
154159

160+
# Linear indexing
161+
@inline unsafe_getindex(a::OffsetArray, i::Int) = _unsafe_getindex(Base.linearindexing(a), a, i)
162+
@inline unsafe_setindex!(a::OffsetArray, val, i::Int) = _unsafe_setindex!(Base.linearindexing(a), a, val, i)
163+
for T in (LinearFast, LinearSlow) # ambiguity-resolution requires specificity for both
164+
@eval begin
165+
@inline function _unsafe_getindex(::$T, a::OffsetVector, i::Int)
166+
@inbounds ret = parent(a)[offset(a.offsets, (i,))[1]]
167+
ret
168+
end
169+
@inline function _unsafe_setindex!(::$T, a::OffsetVector, val, i::Int)
170+
@inbounds parent(a)[offset(a.offsets, (i,))[1]] = val
171+
val
172+
end
173+
end
174+
end
175+
@inline function _unsafe_getindex(::LinearFast, a::OffsetArray, i::Int)
176+
@inbounds ret = parent(a)[i]
177+
ret
178+
end
179+
@inline _unsafe_getindex(::LinearSlow, a::OffsetArray, i::Int) =
180+
unsafe_getindex(a, ind2sub(indices(a), i)...)
181+
@inline function _unsafe_setindex!(::LinearFast, a::OffsetArray, val, i::Int)
182+
@inbounds parent(a)[i] = val
183+
val
184+
end
185+
@inline _unsafe_setindex!(::LinearSlow, a::OffsetArray, val, i::Int) =
186+
unsafe_setindex!(a, val, ind2sub(indices(a), i)...)
187+
155188
@inline unsafe_getindex(a::OffsetArray, I::Int...) = unsafe_getindex(parent(a), offset(a.offsets, I)...)
156189
@inline unsafe_setindex!(a::OffsetArray, val, I::Int...) = unsafe_setindex!(parent(a), val, offset(a.offsets, I)...)
157190
@inline unsafe_getindex(a::OffsetArray, I...) = unsafe_getindex(a, Base.IteratorsMD.flatten(I)...)

test/runtests.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,25 @@ S = OffsetArray(view(A0, 1:2, 1:2), (-1,2)) # LinearSlow
3939
@test_throws DimensionMismatch OffsetArray(A0, 0:1, 2:4)
4040

4141
# Scalar indexing
42-
@test A[0,3] == A[1] == S[0,3] == S[1] == 1
43-
@test A[1,3] == A[2] == S[1,3] == S[2] == 2
44-
@test A[0,4] == A[3] == S[0,4] == S[3] == 3
45-
@test A[1,4] == A[4] == S[1,4] == S[4] == 4
42+
@test A[0,3] == A[0,3,1] == A[1] == S[0,3] == S[0,3,1] == S[1] == 1
43+
@test A[1,3] == A[1,3,1] == A[2] == S[1,3] == S[1,3,1] == S[2] == 2
44+
@test A[0,4] == A[0,4,1] == A[3] == S[0,4] == S[0,4,1] == S[3] == 3
45+
@test A[1,4] == A[1,4,1] == A[4] == S[1,4] == S[1,4,1] == S[4] == 4
46+
@test @unsafe(A[0,3]) == @unsafe(A[0,3,1]) == @unsafe(A[1]) == @unsafe(S[0,3]) == @unsafe(S[0,3,1]) == @unsafe(S[1]) == 1
47+
@test @unsafe(A[1,3]) == @unsafe(A[1,3,1]) == @unsafe(A[2]) == @unsafe(S[1,3]) == @unsafe(S[1,3,1]) == @unsafe(S[2]) == 2
48+
@test @unsafe(A[0,4]) == @unsafe(A[0,4,1]) == @unsafe(A[3]) == @unsafe(S[0,4]) == @unsafe(S[0,4,1]) == @unsafe(S[3]) == 3
49+
@test @unsafe(A[1,4]) == @unsafe(A[1,4,1]) == @unsafe(A[4]) == @unsafe(S[1,4]) == @unsafe(S[1,4,1]) == @unsafe(S[4]) == 4
4650
@test_throws BoundsError A[1,1]
4751
@test_throws BoundsError S[1,1]
52+
@test_throws BoundsError A[0,3,2]
53+
@test_throws BoundsError A[0,3,0]
54+
Ac = copy(A)
55+
Ac[0,3] = 10
56+
@test Ac[0,3] == 10
57+
Ac[0,3,1] = 11
58+
@test Ac[0,3] == 11
59+
@unsafe Ac[0,3,1] = 12
60+
@test Ac[0,3] == 12
4861

4962
# Vector indexing
5063
@test A[:, 3] == S[:, 3] == OffsetArray([1,2], (A.offsets[1],))
@@ -63,8 +76,15 @@ S = OffsetArray(view(A0, 1:2, 1:2), (-1,2)) # LinearSlow
6376

6477
# CartesianIndexing
6578
@test A[CartesianIndex((0,3))] == S[CartesianIndex((0,3))] == 1
79+
@test A[CartesianIndex((0,3)),1] == S[CartesianIndex((0,3)),1] == 1
80+
@test @unsafe(A[CartesianIndex((0,3))]) == @unsafe(S[CartesianIndex((0,3))]) == 1
81+
@test @unsafe(A[CartesianIndex((0,3)),1]) == @unsafe(S[CartesianIndex((0,3)),1]) == 1
6682
@test_throws BoundsError A[CartesianIndex(1,1)]
83+
@test_throws BoundsError A[CartesianIndex(1,1),0]
84+
@test_throws BoundsError A[CartesianIndex(1,1),2]
6785
@test_throws BoundsError S[CartesianIndex(1,1)]
86+
@test_throws BoundsError S[CartesianIndex(1,1),0]
87+
@test_throws BoundsError S[CartesianIndex(1,1),2]
6888
@test eachindex(A) == 1:4
6989
@test eachindex(S) == CartesianRange((0:1,3:4))
7090

0 commit comments

Comments
 (0)