Skip to content

Commit 1e6542b

Browse files
authored
Sum for RangeCumsum (#216)
* Sum for RangeCumsum * Don't coerce types in _half
1 parent 2b97036 commit 1e6542b

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

src/cumsum.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ Base.parent(r::RangeCumsum) = r.range
1515
==(a::RangeCumsum, b::RangeCumsum) = a.range == b.range
1616
BroadcastStyle(::Type{<:RangeCumsum{<:Any,RR}}) where RR = BroadcastStyle(RR)
1717

18-
_getindex(r::AbstractUnitRange{<:Integer}, k) = k * (2first(r) + k - 1) ÷ 2
18+
_half(x::Integer) = x ÷ 2
19+
_half(x) = x / 2
20+
21+
function _getindex(r::AbstractRange{<:Real}, k)
22+
v = first(r)
23+
s = step(r)
24+
_half(k * (2v - s + s*k))
25+
end
1926
Base.@propagate_inbounds _getindex(r::AbstractRange, k) = sum(r[range(firstindex(r), length=k)])
2027

2128
Base.@propagate_inbounds function getindex(c::RangeCumsum{<:Any,<:AbstractRange}, k::Integer)
@@ -33,6 +40,13 @@ last(r::RangeCumsum) = sum(r.range)
3340
diff(r::RangeCumsum) = r.range[firstindex(r)+1:end]
3441
isempty(r::RangeCumsum) = isempty(r.range)
3542

43+
function Base.sum(r::RangeCumsum{<:Real})
44+
N = length(r)
45+
v = first(r)
46+
s = step(r.range)
47+
_half((2v-s)*(N*(N+1)÷2) + s*(N*(N+1)*(2N+1)÷6))
48+
end
49+
3650
union(a::RangeCumsum{<:Any,<:OneTo}, b::RangeCumsum{<:Any,<:OneTo}) =
3751
RangeCumsum(OneTo(max(last(a.range), last(b.range))))
3852

test/test_cumsum.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,31 @@ using ArrayLayouts, Test
44

55
include("infinitearrays.jl")
66

7+
cmpop(p) = isinteger(real(first(p))) && isinteger(real(step(p))) ? (==) : ()
8+
79
@testset "RangeCumsum" begin
8-
@testset for p in (Base.OneTo(5), 2:5, 2:2:6, 6:-2:1, -1.0:3.0:5.0, (-1.0:3.0:5.0)*im,
9-
Base.IdentityUnitRange(4:6))
10+
@testset for p in Any[Base.OneTo(5), 2:5, 2:2:6, 6:-2:1, Int8(2):Int8(5),
11+
UnitRange(2.5, 8.5),
12+
-1.0:1.0:10.0, -1.2:1.5:10.0,
13+
(2:5)*im, (-1:3:5)*im, (-1.0:3.0:5.0)*im, (-1.2:3.0:5.2)*(1+im),
14+
Base.IdentityUnitRange(4:6)]
15+
1016
r = RangeCumsum(p)
1117
@test parent(r) == p
1218
@test r == r
19+
cmp = cmpop(p)
20+
if eltype(r) <: Complex
21+
@test sum(r) isa Complex{promote_type(Int, real(eltype(r)))}
22+
end
23+
@test cmp(sum(r), sum(i for i in r))
1324
if axes(r,1) isa Base.OneTo
14-
@test r == cumsum(p)
15-
@test r .+ 1 == cumsum(p) .+ 1
25+
@test cmp(r, cumsum(p))
26+
@test cmp(r .+ 1, cumsum(p) .+ 1)
1627
@test r[Base.OneTo(3)] == r[1:3]
1728
@test @view(r[Base.OneTo(3)]) === r[Base.OneTo(3)] == r[1:3]
1829
@test @view(r[Base.OneTo(3)]) isa RangeCumsum
19-
@test diff(r) == diff(Vector(r))
20-
@test -r == -Vector(r)
30+
@test cmp(diff(r),diff(Vector(r)))
31+
@test cmp(-r, -Vector(r))
2132
end
2233
@test diff(r) == p[firstindex(p)+1:end]
2334
@test last(r) == r[end] == sum(p)
@@ -48,7 +59,7 @@ include("infinitearrays.jl")
4859
@test r * n isa RangeCumsum
4960
@test r * n w * n
5061
end
51-
for p in (Base.OneTo(4), -4:4, -4:2:4, -1.0:3.0:5.0)
62+
@testset for p in (Base.OneTo(4), -4:4, -4:2:4, -1.0:3.0:5.0)
5263
r = RangeCumsum(p)
5364
test_broadcast(3, r)
5465
test_broadcast(3.5, r)

0 commit comments

Comments
 (0)