Skip to content

Commit a68ae8e

Browse files
authored
Merge pull request #72 from Tokazama/OptionallyStaticUnitRange-fixes
Fixe BoundsError and length on OptionallyStaticUnitRange
2 parents 1e1f4a5 + fb321e2 commit a68ae8e

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.13.0"
3+
version = "2.13.1"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/ranges.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,23 @@ unsafe_length_one_to(lst::Int) = lst
105105
unsafe_length_one_to(::StaticInt{L}) where {L} = lst
106106

107107
Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
108-
if known_first(r) === oneunit(r)
108+
if known_first(r) === oneunit(eltype(r))
109109
return get_index_one_to(r, i)
110110
else
111111
return get_index_unit_range(r, i)
112112
end
113113
end
114114

115115
@inline function get_index_one_to(r, i)
116-
@boundscheck if ((i > 0) & (i <= last(r)))
116+
@boundscheck if ((i < 1) || (i > last(r)))
117117
throw(BoundsError(r, i))
118118
end
119119
return convert(eltype(r), i)
120120
end
121121

122122
@inline function get_index_unit_range(r, i)
123123
val = first(r) + (i - 1)
124-
@boundscheck if i > 0 && val <= last(r) && val >= first(r)
124+
@boundscheck if (i < 1) || (val > last(r) && val < first(r))
125125
throw(BoundsError(r, i))
126126
end
127127
return convert(eltype(r), val)
@@ -169,15 +169,15 @@ function Base.length(r::OptionallyStaticUnitRange)
169169
if isempty(r)
170170
return 0
171171
else
172-
if known_first(r) === 0
172+
if known_first(r) === 1
173173
return unsafe_length_one_to(last(r))
174174
else
175175
return unsafe_length_unit_range(first(r), last(r))
176176
end
177177
end
178178
end
179179

180-
unsafe_length_unit_range(start::Integer, stop::Integer) = Int(start - stop + 1)
180+
unsafe_length_unit_range(start::Integer, stop::Integer) = Int((stop - start) + 1)
181181

182182
"""
183183
indices(x[, d])

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,15 @@ end
200200
@test isnothing(@inferred(ArrayInterface.known_step(typeof(1:0.2:4))))
201201
@test isone(@inferred(ArrayInterface.known_step(1:4)))
202202
@test isone(@inferred(ArrayInterface.known_step(typeof(1:4))))
203+
204+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 0))) == 0
205+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 10))) == 10
206+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10))) == 10
207+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(0), 10))) == 11
208+
@test @inferred(getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10), 1)) == 1
209+
@test @inferred(getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(0), 10), 1)) == 0
210+
@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10), 0)
211+
@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10), 0)
203212
end
204213

205214
@testset "Memory Layout" begin

0 commit comments

Comments
 (0)