Skip to content

Commit 5d41a76

Browse files
authored
Merge pull request #43360 from JuliaLang/jn/ranges-last
refine and cleanup handling of range arithmetic
2 parents b6bca19 + ff185b7 commit 5d41a76

File tree

5 files changed

+174
-132
lines changed

5 files changed

+174
-132
lines changed

base/range.jl

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
(:)(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)
66

7-
(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop-start, 1), stop)
7+
(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop >= start ? stop - start : start - stop, 1), stop)
88

99
# promote start and stop, leaving step alone
1010
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
@@ -164,7 +164,7 @@ _range(start::Any , step::Any , stop::Any , len::Any ) = range_error
164164
range_length(len::Integer) = OneTo(len)
165165

166166
# Stop as the only argument
167-
range_stop(stop) = range_start_stop(oneunit(stop), stop)
167+
range_stop(stop) = range_start_stop(oftype(stop, 1), stop)
168168
range_stop(stop::Integer) = range_length(stop)
169169

170170
# Stop and length as the only argument
@@ -200,10 +200,17 @@ function range_start_step_length(a::T, step, len::Integer) where {T}
200200
_rangestyle(OrderStyle(T), ArithmeticStyle(T), a, step, len)
201201
end
202202

203-
_rangestyle(::Ordered, ::ArithmeticWraps, a::T, step::S, len::Integer) where {T,S} =
204-
StepRange{typeof(a+zero(step)),S}(a, step, a+step*(len-1))
205-
_rangestyle(::Any, ::Any, a::T, step::S, len::Integer) where {T,S} =
206-
StepRangeLen{typeof(a+zero(step)),T,S}(a, step, len)
203+
function _rangestyle(::Ordered, ::ArithmeticWraps, a, step, len::Integer)
204+
start = a + zero(step)
205+
stop = a + step * (len - 1)
206+
T = typeof(start)
207+
return StepRange{T,typeof(step)}(start, step, convert(T, stop))
208+
end
209+
function _rangestyle(::Any, ::Any, a, step, len::Integer)
210+
start = a + zero(step)
211+
T = typeof(a)
212+
return StepRangeLen{typeof(start),T,typeof(step)}(a, step, len)
213+
end
207214

208215
range_start_step_stop(start, step, stop) = start:step:stop
209216

@@ -306,19 +313,19 @@ struct StepRange{T,S} <: OrdinalRange{T,S}
306313
stop::T
307314

308315
function StepRange{T,S}(start, step, stop) where {T,S}
309-
sta = convert(T, start)
310-
ste = convert(S, step)
311-
sto = convert(T, stop)
312-
new(sta, ste, steprange_last(sta,ste,sto))
316+
start = convert(T, start)
317+
step = convert(S, step)
318+
stop = convert(T, stop)
319+
return new(start, step, steprange_last(start, step, stop))
313320
end
314321
end
315322

316323
# to make StepRange constructor inlineable, so optimizer can see `step` value
317-
function steprange_last(start::T, step, stop) where T
318-
if isa(start,AbstractFloat) || isa(step,AbstractFloat)
324+
function steprange_last(start, step, stop)::typeof(stop)
325+
if isa(start, AbstractFloat) || isa(step, AbstractFloat)
319326
throw(ArgumentError("StepRange should not be used with floating point"))
320327
end
321-
if isa(start,Integer) && !isinteger(start + step)
328+
if isa(start, Integer) && !isinteger(start + step)
322329
throw(ArgumentError("StepRange{<:Integer} cannot have non-integer step"))
323330
end
324331
z = zero(step)
@@ -335,30 +342,28 @@ function steprange_last(start::T, step, stop) where T
335342
absdiff, absstep = stop > start ? (stop - start, step) : (start - stop, -step)
336343

337344
# Compute remainder as a nonnegative number:
338-
if T <: Signed && absdiff < zero(absdiff)
339-
# handle signed overflow with unsigned rem
340-
remain = convert(T, unsigned(absdiff) % absstep)
345+
if absdiff isa Signed && absdiff < zero(absdiff)
346+
# unlikely, but handle the signed overflow case with unsigned rem
347+
remain = convert(typeof(absdiff), unsigned(absdiff) % absstep)
341348
else
342-
remain = absdiff % absstep
349+
remain = convert(typeof(absdiff), absdiff % absstep)
343350
end
344351
# Move `stop` closer to `start` if there is a remainder:
345352
last = stop > start ? stop - remain : stop + remain
346353
end
347354
end
348-
last
355+
return last
349356
end
350357

351-
function steprange_last_empty(start::Integer, step, stop)
352-
# empty range has a special representation where stop = start-1
353-
# this is needed to avoid the wrap-around that can happen computing
354-
# start - step, which leads to a range that looks very large instead
355-
# of empty.
358+
function steprange_last_empty(start::Integer, step, stop)::typeof(stop)
359+
# empty range has a special representation where stop = start-1,
360+
# which simplifies arithmetic for Signed numbers
356361
if step > zero(step)
357-
last = start - oneunit(stop-start)
362+
last = start - oneunit(step)
358363
else
359-
last = start + oneunit(stop-start)
364+
last = start + oneunit(step)
360365
end
361-
last
366+
return last
362367
end
363368
# For types where x+oneunit(x) may not be well-defined use the user-given value for stop
364369
steprange_last_empty(start, step, stop) = stop
@@ -388,18 +393,21 @@ UnitRange{Int64}
388393
struct UnitRange{T<:Real} <: AbstractUnitRange{T}
389394
start::T
390395
stop::T
391-
UnitRange{T}(start, stop) where {T<:Real} = new(start, unitrange_last(start,stop))
396+
UnitRange{T}(start::T, stop::T) where {T<:Real} = new(start, unitrange_last(start, stop))
392397
end
398+
UnitRange{T}(start, stop) where {T<:Real} = UnitRange{T}(convert(T, start), convert(T, stop))
393399
UnitRange(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)
400+
UnitRange(start, stop) = UnitRange(promote(start, stop)...)
394401

395-
unitrange_last(::Bool, stop::Bool) = stop
396-
unitrange_last(start::T, stop::T) where {T<:Integer} =
397-
stop >= start ? stop : convert(T,start-oneunit(start-stop))
398-
unitrange_last(start::T, stop::T) where {T} =
399-
stop >= start ? convert(T,start+floor(stop-start)) :
400-
convert(T,start-oneunit(stop-start))
402+
# if stop and start are integral, we know that their difference is a multiple of 1
403+
unitrange_last(start::Integer, stop::Integer) =
404+
stop >= start ? stop : convert(typeof(stop), start - oneunit(start - stop))
405+
# otherwise, use `floor` as a more efficient way to compute modulus with step=1
406+
unitrange_last(start, stop) =
407+
stop >= start ? convert(typeof(stop), start + floor(stop - start)) :
408+
convert(typeof(stop), start - oneunit(start - stop))
401409

402-
unitrange(x) = UnitRange(x)
410+
unitrange(x::AbstractUnitRange) = UnitRange(x) # convenience conversion for promoting the range type
403411

404412
if isdefined(Main, :Base)
405413
# Constant-fold-able indexing into tuples to functionally expose Base.tail and Base.front
@@ -556,7 +564,7 @@ function LinRange{T}(start, stop, len::Integer) where T
556564
end
557565

558566
function LinRange(start, stop, len::Integer)
559-
T = typeof((stop-start)/len)
567+
T = typeof((zero(stop) - zero(start)) / oneunit(len))
560568
LinRange{T}(start, stop, len)
561569
end
562570

@@ -642,7 +650,7 @@ length(r::AbstractRange) = error("length implementation missing") # catch mistak
642650
size(r::AbstractRange) = (length(r),)
643651

644652
isempty(r::StepRange) =
645-
# steprange_last_empty(r.start, r.step, r.stop) == r.stop
653+
# steprange_last(r.start, r.step, r.stop) == r.stop
646654
(r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
647655
isempty(r::AbstractUnitRange) = first(r) > last(r)
648656
isempty(r::StepRangeLen) = length(r) == 0
@@ -689,9 +697,8 @@ firstindex(::LinRange) = 1
689697
# defined between the relevant types
690698
function checked_length(r::OrdinalRange{T}) where T
691699
s = step(r)
692-
# s != 0, by construction, but avoids the division error later
693700
start = first(r)
694-
if s == zero(s) || isempty(r)
701+
if isempty(r)
695702
return Integer(div(start - start, oneunit(s)))
696703
end
697704
stop = last(r)
@@ -716,9 +723,8 @@ end
716723

717724
function length(r::OrdinalRange{T}) where T
718725
s = step(r)
719-
# s != 0, by construction, but avoids the division error later
720726
start = first(r)
721-
if s == zero(s) || isempty(r)
727+
if isempty(r)
722728
return Integer(div(start - start, oneunit(s)))
723729
end
724730
stop = last(r)
@@ -756,7 +762,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
756762
# (near typemax) for types with known `unsigned` functions
757763
function length(r::OrdinalRange{T}) where T<:bigints
758764
s = step(r)
759-
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
760765
isempty(r) && return zero(T)
761766
diff = last(r) - first(r)
762767
# if |s| > 1, diff might have overflowed, but unsigned(diff)÷s should
@@ -773,7 +778,6 @@ let bigints = Union{Int, UInt, Int64, UInt64, Int128, UInt128}
773778
end
774779
function checked_length(r::OrdinalRange{T}) where T<:bigints
775780
s = step(r)
776-
s == zero(s) && return zero(T) # unreachable, by construction, but avoids the error case here later
777781
isempty(r) && return zero(T)
778782
stop, start = last(r), first(r)
779783
# n.b. !(s isa T)
@@ -800,7 +804,6 @@ let smallints = (Int === Int64 ?
800804
# n.b. !(step isa T)
801805
function length(r::OrdinalRange{<:smallints})
802806
s = step(r)
803-
s == zero(s) && return 0 # unreachable, by construction, but avoids the error case here later
804807
isempty(r) && return 0
805808
return div(Int(last(r)) - Int(first(r)), s) + 1
806809
end
@@ -962,29 +965,30 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
962965
@boundscheck checkbounds(r, s)
963966

964967
if T === Bool
965-
range(first(s) ? first(r) : last(r), length = Integer(last(s)))
968+
return range(first(s) ? first(r) : last(r), length = last(s))
966969
else
967970
f = first(r)
968-
st = oftype(f, f + first(s)-firstindex(r))
969-
return range(st, length=length(s))
971+
start = oftype(f, f + first(s)-firstindex(r))
972+
return range(start, length=length(s))
970973
end
971974
end
972975

973976
function getindex(r::OneTo{T}, s::OneTo) where T
974977
@inline
975978
@boundscheck checkbounds(r, s)
976-
OneTo(T(s.stop))
979+
return OneTo(T(s.stop))
977980
end
978981

979982
function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
980983
@inline
981984
@boundscheck checkbounds(r, s)
982985

983986
if T === Bool
984-
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Integer(last(s)))
987+
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = last(s))
985988
else
986-
st = oftype(first(r), first(r) + s.start-firstindex(r))
987-
return range(st, step=step(s), length=length(s))
989+
f = first(r)
990+
start = oftype(f, f + s.start-firstindex(r))
991+
return range(start, step=step(s), length=length(s))
988992
end
989993
end
990994

@@ -994,19 +998,22 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
994998

995999
if T === Bool
9961000
if length(s) == 0
997-
return range(first(r), step=step(r), length=0)
1001+
start, len = first(r), 0
9981002
elseif length(s) == 1
9991003
if first(s)
1000-
return range(first(r), step=step(r), length=1)
1004+
start, len = first(r), 1
10011005
else
1002-
return range(first(r), step=step(r), length=0)
1006+
start, len = first(r), 0
10031007
end
10041008
else # length(s) == 2
1005-
return range(last(r), step=step(r), length=1)
1009+
start, len = last(r), 1
10061010
end
1011+
return range(start, step=step(r); length=len)
10071012
else
1008-
st = oftype(r.start, r.start + (first(s)-1)*step(r))
1009-
return range(st, step=step(r)*step(s), length=length(s))
1013+
f = r.start
1014+
st = r.step
1015+
start = oftype(f, f + (first(s)-oneunit(first(s)))*st)
1016+
return range(start; step=st*step(s), length=length(s))
10101017
end
10111018
end
10121019

@@ -1235,7 +1242,7 @@ end
12351242
issubset(r::OneTo, s::OneTo) = r.stop <= s.stop
12361243

12371244
issubset(r::AbstractUnitRange{<:Integer}, s::AbstractUnitRange{<:Integer}) =
1238-
isempty(r) || first(r) >= first(s) && last(r) <= last(s)
1245+
isempty(r) || (first(r) >= first(s) && last(r) <= last(s))
12391246

12401247
## linear operations on ranges ##
12411248

stdlib/Dates/src/Dates.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ for more information.
3232
"""
3333
module Dates
3434

35-
import Base: ==, div, fld, mod, rem, gcd, lcm, +, -, *, /, %, broadcast
35+
import Base: ==, isless, div, fld, mod, rem, gcd, lcm, +, -, *, /, %, broadcast
3636
using Printf: @sprintf
3737

3838
using Base.Iterators

stdlib/Dates/src/periods.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ default(p::Union{T,Type{T}}) where {T<:TimePeriod} = T(0)
7070

7171
(-)(x::P) where {P<:Period} = P(-value(x))
7272
==(x::P, y::P) where {P<:Period} = value(x) == value(y)
73-
==(x::Period, y::Period) = (==)(promote(x, y)...)
7473
Base.isless(x::P, y::P) where {P<:Period} = isless(value(x), value(y))
75-
Base.isless(x::Period, y::Period) = isless(promote(x, y)...)
7674

7775
# Period Arithmetic, grouped by dimensionality:
7876
for op in (:+, :-, :lcm, :gcd)
@@ -97,6 +95,11 @@ end
9795
(*)(A::Period, B::AbstractArray) = Broadcast.broadcast_preserving_zero_d(*, A, B)
9896
(*)(A::AbstractArray, B::Period) = Broadcast.broadcast_preserving_zero_d(*, A, B)
9997

98+
for op in (:(==), :isless, :/, :rem, :mod, :lcm, :gcd)
99+
@eval ($op)(x::Period, y::Period) = ($op)(promote(x, y)...)
100+
end
101+
div(x::Period, y::Period, r::RoundingMode) = div(promote(x, y)..., r)
102+
100103
# intfuncs
101104
Base.gcdx(a::T, b::T) where {T<:Period} = ((g, x, y) = gcdx(value(a), value(b)); return T(g), x, y)
102105
Base.abs(a::T) where {T<:Period} = T(abs(value(a)))

stdlib/Dates/src/ranges.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ end
2424
Base.length(r::StepRange{<:TimeType}) = isempty(r) ? Int64(0) : len(r.start, r.stop, r.step) + 1
2525
# Period ranges hook into Int64 overflow detection
2626
Base.length(r::StepRange{<:Period}) = length(StepRange(value(r.start), value(r.step), value(r.stop)))
27+
Base.checked_length(r::StepRange{<:Period}) = Base.checked_length(StepRange(value(r.start), value(r.step), value(r.stop)))
2728

28-
# Overload Base.steprange_last because `rem` is not overloaded for `TimeType`s
29+
# Overload Base.steprange_last because `step::Period` may be a variable amount of time (e.g. for Month and Year)
2930
function Base.steprange_last(start::T, step, stop) where T<:TimeType
30-
if isa(step,AbstractFloat)
31+
if isa(step, AbstractFloat)
3132
throw(ArgumentError("StepRange should not be used with floating point"))
3233
end
3334
z = zero(step)
@@ -47,7 +48,7 @@ function Base.steprange_last(start::T, step, stop) where T<:TimeType
4748
last = stop - remain
4849
end
4950
end
50-
last
51+
return last
5152
end
5253

5354
import Base.in

0 commit comments

Comments
 (0)