Skip to content

Commit ff185b7

Browse files
committed
refine and cleanup handling of range arithmetic
Try to be more careful about which types we use for arguments and return values and comparisons in intermediate computations. Not expected to change nominal behaviors, but may improve some unusual ranges that require some conversions or are near over/underflow. And use convert(T,1) rather than oneunit(T) to support fewer types, as we want the default step to be a unitless 1 (e.g., not Nanosecond(1)). Replaces #43058
1 parent bdf9ead commit ff185b7

File tree

3 files changed

+168
-129
lines changed

3 files changed

+168
-129
lines changed

base/range.jl

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
(:)(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)
66

7-
(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop-start, 1), stop)
7+
(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop >= start ? stop - start : start - stop, 1), stop)
88

99
# promote start and stop, leaving step alone
1010
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
@@ -164,7 +164,7 @@ _range(start::Any , step::Any , stop::Any , len::Any ) = range_error
164164
range_length(len::Integer) = OneTo(len)
165165

166166
# Stop as the only argument
167-
range_stop(stop) = range_start_stop(oneunit(stop), stop)
167+
range_stop(stop) = range_start_stop(oftype(stop, 1), stop)
168168
range_stop(stop::Integer) = range_length(stop)
169169

170170
# Stop and length as the only argument
@@ -200,10 +200,17 @@ function range_start_step_length(a::T, step, len::Integer) where {T}
200200
_rangestyle(OrderStyle(T), ArithmeticStyle(T), a, step, len)
201201
end
202202

203-
_rangestyle(::Ordered, ::ArithmeticWraps, a::T, step::S, len::Integer) where {T,S} =
204-
StepRange{typeof(a+zero(step)),S}(a, step, a+step*(len-1))
205-
_rangestyle(::Any, ::Any, a::T, step::S, len::Integer) where {T,S} =
206-
StepRangeLen{typeof(a+zero(step)),T,S}(a, step, len)
203+
function _rangestyle(::Ordered, ::ArithmeticWraps, a, step, len::Integer)
204+
start = a + zero(step)
205+
stop = a + step * (len - 1)
206+
T = typeof(start)
207+
return StepRange{T,typeof(step)}(start, step, convert(T, stop))
208+
end
209+
function _rangestyle(::Any, ::Any, a, step, len::Integer)
210+
start = a + zero(step)
211+
T = typeof(a)
212+
return StepRangeLen{typeof(start),T,typeof(step)}(a, step, len)
213+
end
207214

208215
range_start_step_stop(start, step, stop) = start:step:stop
209216

@@ -306,19 +313,19 @@ struct StepRange{T,S} <: OrdinalRange{T,S}
306313
stop::T
307314

308315
function StepRange{T,S}(start, step, stop) where {T,S}
309-
sta = convert(T, start)
310-
ste = convert(S, step)
311-
sto = convert(T, stop)
312-
new(sta, ste, steprange_last(sta,ste,sto))
316+
start = convert(T, start)
317+
step = convert(S, step)
318+
stop = convert(T, stop)
319+
return new(start, step, steprange_last(start, step, stop))
313320
end
314321
end
315322

316323
# to make StepRange constructor inlineable, so optimizer can see `step` value
317-
function steprange_last(start::T, step, stop) where T
318-
if isa(start,AbstractFloat) || isa(step,AbstractFloat)
324+
function steprange_last(start, step, stop)::typeof(stop)
325+
if isa(start, AbstractFloat) || isa(step, AbstractFloat)
319326
throw(ArgumentError("StepRange should not be used with floating point"))
320327
end
321-
if isa(start,Integer) && !isinteger(start + step)
328+
if isa(start, Integer) && !isinteger(start + step)
322329
throw(ArgumentError("StepRange{<:Integer} cannot have non-integer step"))
323330
end
324331
z = zero(step)
@@ -335,30 +342,28 @@ function steprange_last(start::T, step, stop) where T
335342
absdiff, absstep = stop > start ? (stop - start, step) : (start - stop, -step)
336343

337344
# Compute remainder as a nonnegative number:
338-
if T <: Signed && absdiff < zero(absdiff)
339-
# handle signed overflow with unsigned rem
340-
remain = convert(T, unsigned(absdiff) % absstep)
345+
if absdiff isa Signed && absdiff < zero(absdiff)
346+
# unlikely, but handle the signed overflow case with unsigned rem
347+
remain = convert(typeof(absdiff), unsigned(absdiff) % absstep)
341348
else
342-
remain = absdiff % absstep
349+
remain = convert(typeof(absdiff), absdiff % absstep)
343350
end
344351
# Move `stop` closer to `start` if there is a remainder:
345352
last = stop > start ? stop - remain : stop + remain
346353
end
347354
end
348-
last
355+
return last
349356
end
350357

351-
function steprange_last_empty(start::Integer, step, stop)
352-
# empty range has a special representation where stop = start-1
353-
# this is needed to avoid the wrap-around that can happen computing
354-
# start - step, which leads to a range that looks very large instead
355-
# of empty.
358+
function steprange_last_empty(start::Integer, step, stop)::typeof(stop)
359+
# empty range has a special representation where stop = start-1,
360+
# which simplifies arithmetic for Signed numbers
356361
if step > zero(step)
357-
last = start - oneunit(stop-start)
362+
last = start - oneunit(step)
358363
else
359-
last = start + oneunit(stop-start)
364+
last = start + oneunit(step)
360365
end
361-
last
366+
return last
362367
end
363368
# For types where x+oneunit(x) may not be well-defined use the user-given value for stop
364369
steprange_last_empty(start, step, stop) = stop
@@ -388,18 +393,21 @@ UnitRange{Int64}
388393
struct UnitRange{T<:Real} <: AbstractUnitRange{T}
389394
start::T
390395
stop::T
391-
UnitRange{T}(start, stop) where {T<:Real} = new(start, unitrange_last(start,stop))
396+
UnitRange{T}(start::T, stop::T) where {T<:Real} = new(start, unitrange_last(start, stop))
392397
end
398+
UnitRange{T}(start, stop) where {T<:Real} = UnitRange{T}(convert(T, start), convert(T, stop))
393399
UnitRange(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)
400+
UnitRange(start, stop) = UnitRange(promote(start, stop)...)
394401

395-
unitrange_last(::Bool, stop::Bool) = stop
396-
unitrange_last(start::T, stop::T) where {T<:Integer} =
397-
stop >= start ? stop : convert(T,start-oneunit(start-stop))
398-
unitrange_last(start::T, stop::T) where {T} =
399-
stop >= start ? convert(T,start+floor(stop-start)) :
400-
convert(T,start-oneunit(stop-start))
402+
# if stop and start are integral, we know that their difference is a multiple of 1
403+
unitrange_last(start::Integer, stop::Integer) =
404+
stop >= start ? stop : convert(typeof(stop), start - oneunit(start - stop))
405+
# otherwise, use `floor` as a more efficient way to compute modulus with step=1
406+
unitrange_last(start, stop) =
407+
stop >= start ? convert(typeof(stop), start + floor(stop - start)) :
408+
convert(typeof(stop), start - oneunit(start - stop))
401409

402-
unitrange(x) = UnitRange(x)
410+
unitrange(x::AbstractUnitRange) = UnitRange(x) # convenience conversion for promoting the range type
403411

404412
if isdefined(Main, :Base)
405413
# Constant-fold-able indexing into tuples to functionally expose Base.tail and Base.front
@@ -556,7 +564,7 @@ function LinRange{T}(start, stop, len::Integer) where T
556564
end
557565

558566
function LinRange(start, stop, len::Integer)
559-
T = typeof((stop-start)/len)
567+
T = typeof((zero(stop) - zero(start)) / oneunit(len))
560568
LinRange{T}(start, stop, len)
561569
end
562570

@@ -642,7 +650,7 @@ length(r::AbstractRange) = error("length implementation missing") # catch mistak
642650
size(r::AbstractRange) = (length(r),)
643651

644652
isempty(r::StepRange) =
645-
# steprange_last_empty(r.start, r.step, r.stop) == r.stop
653+
# steprange_last(r.start, r.step, r.stop) == r.stop
646654
(r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
647655
isempty(r::AbstractUnitRange) = first(r) > last(r)
648656
isempty(r::StepRangeLen) = length(r) == 0
@@ -689,9 +697,8 @@ firstindex(::LinRange) = 1
689697
# defined between the relevant types
690698
function checked_length(r::OrdinalRange{T}) where T
691699
s = step(r)
692-
# s != 0, by construction, but avoids the division error later
693700
start = first(r)
694-
if s == zero(s) || isempty(r)
701+
if isempty(r)
695702
return Integer(div(start - start, oneunit(s)))
696703
end
697704
stop = last(r)
@@ -716,9 +723,8 @@ end
716723

717724
function length(r::OrdinalRange{T}) where T
718725
s = step(r)
719-
# s != 0, by construction, but avoids the division error later
720726
start = first(r)
721-
if s == zero(s) || isempty(r)
727+
if isempty(r)
722728
return Integer(div(start - start, oneunit(s)))
723729
end
724730
stop = last(r)
@@ -756,7 +762,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
756762
# (near typemax) for types with known `unsigned` functions
757763
function length(r::OrdinalRange{T}) where T<:bigints
758764
s = step(r)
759-
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
760765
isempty(r) && return zero(T)
761766
diff = last(r) - first(r)
762767
# if |s| > 1, diff might have overflowed, but unsigned(diff)÷s should
@@ -773,7 +778,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
773778
end
774779
function checked_length(r::OrdinalRange{T}) where T<:bigints
775780
s = step(r)
776-
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
777781
isempty(r) && return zero(T)
778782
stop, start = last(r), first(r)
779783
# n.b. !(s isa T)
@@ -800,7 +804,6 @@ let smallints = (Int === Int64 ?
800804
# n.b. !(step isa T)
801805
function length(r::OrdinalRange{<:smallints})
802806
s = step(r)
803-
s == zero(s) && return 0 # unreachable, by construction, but avoids the error case here later
804807
isempty(r) && return 0
805808
return div(Int(last(r)) - Int(first(r)), s) + 1
806809
end
@@ -962,29 +965,30 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
962965
@boundscheck checkbounds(r, s)
963966

964967
if T === Bool
965-
range(first(s) ? first(r) : last(r), length = Integer(last(s)))
968+
return range(first(s) ? first(r) : last(r), length = last(s))
966969
else
967970
f = first(r)
968-
st = oftype(f, f + first(s)-firstindex(r))
969-
return range(st, length=length(s))
971+
start = oftype(f, f + first(s)-firstindex(r))
972+
return range(start, length=length(s))
970973
end
971974
end
972975

973976
function getindex(r::OneTo{T}, s::OneTo) where T
974977
@inline
975978
@boundscheck checkbounds(r, s)
976-
OneTo(T(s.stop))
979+
return OneTo(T(s.stop))
977980
end
978981

979982
function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
980983
@inline
981984
@boundscheck checkbounds(r, s)
982985

983986
if T === Bool
984-
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Integer(last(s)))
987+
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = last(s))
985988
else
986-
st = oftype(first(r), first(r) + s.start-firstindex(r))
987-
return range(st, step=step(s), length=length(s))
989+
f = first(r)
990+
start = oftype(f, f + s.start-firstindex(r))
991+
return range(start, step=step(s), length=length(s))
988992
end
989993
end
990994

@@ -994,19 +998,22 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
994998

995999
if T === Bool
9961000
if length(s) == 0
997-
return range(first(r), step=step(r), length=0)
1001+
start, len = first(r), 0
9981002
elseif length(s) == 1
9991003
if first(s)
1000-
return range(first(r), step=step(r), length=1)
1004+
start, len = first(r), 1
10011005
else
1002-
return range(first(r), step=step(r), length=0)
1006+
start, len = first(r), 0
10031007
end
10041008
else # length(s) == 2
1005-
return range(last(r), step=step(r), length=1)
1009+
start, len = last(r), 1
10061010
end
1011+
return range(start, step=step(r); length=len)
10071012
else
1008-
st = oftype(r.start, r.start + (first(s)-1)*step(r))
1009-
return range(st, step=step(r)*step(s), length=length(s))
1013+
f = r.start
1014+
st = r.step
1015+
start = oftype(f, f + (first(s)-oneunit(first(s)))*st)
1016+
return range(start; step=st*step(s), length=length(s))
10101017
end
10111018
end
10121019

@@ -1235,7 +1242,7 @@ end
12351242
issubset(r::OneTo, s::OneTo) = r.stop <= s.stop
12361243

12371244
issubset(r::AbstractUnitRange{<:Integer}, s::AbstractUnitRange{<:Integer}) =
1238-
isempty(r) || first(r) >= first(s) && last(r) <= last(s)
1245+
isempty(r) || (first(r) >= first(s) && last(r) <= last(s))
12391246

12401247
## linear operations on ranges ##
12411248

stdlib/Dates/src/ranges.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ end
2424
Base.length(r::StepRange{<:TimeType}) = isempty(r) ? Int64(0) : len(r.start, r.stop, r.step) + 1
2525
# Period ranges hook into Int64 overflow detection
2626
Base.length(r::StepRange{<:Period}) = length(StepRange(value(r.start), value(r.step), value(r.stop)))
27+
Base.checked_length(r::StepRange{<:Period}) = Base.checked_length(StepRange(value(r.start), value(r.step), value(r.stop)))
2728

28-
# Overload Base.steprange_last because `rem` is not overloaded for `TimeType`s
29+
# Overload Base.steprange_last because `step::Period` may be a variable amount of time (e.g. for Month and Year)
2930
function Base.steprange_last(start::T, step, stop) where T<:TimeType
30-
if isa(step,AbstractFloat)
31+
if isa(step, AbstractFloat)
3132
throw(ArgumentError("StepRange should not be used with floating point"))
3233
end
3334
z = zero(step)
@@ -47,7 +48,7 @@ function Base.steprange_last(start::T, step, stop) where T<:TimeType
4748
last = stop - remain
4849
end
4950
end
50-
last
51+
return last
5152
end
5253

5354
import Base.in

0 commit comments

Comments
 (0)