Skip to content

Commit bf061ab

Browse files
committed
removed ambiguity in _first and _last, added tests
1 parent 8d031cf commit bf061ab

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

src/ParallelUtilities.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ struct ProductSplit{T,N,Q}
7373
end
7474
Base.eltype(::ProductSplit{T}) where {T} = T
7575

76-
function _cumprod(len)
76+
function _cumprod(len::Tuple)
7777
(0,_cumprod(first(len),Base.tail(len))...)
7878
end
7979

80-
_cumprod(::Int,::Tuple{}) = ()
80+
@inline _cumprod(::Int,::Tuple{}) = ()
8181
function _cumprod(n::Int,tl::Tuple)
8282
(n,_cumprod(n*first(tl),Base.tail(tl))...)
8383
end
@@ -107,7 +107,7 @@ end
107107
@boundscheck (1 <= ind <= length(first(t))) || throw(BoundsError(first(t),ind))
108108
(@inbounds first(t)[ind],_first(Base.tail(t),rest...)...)
109109
end
110-
@inline _first(::Tuple{},rest...) = ()
110+
@inline _first(::Tuple{}) = ()
111111

112112
@inline Base.@propagate_inbounds function Base.last(ps::ProductSplit)
113113
isempty(ps) ? nothing : _last(ps.iterators,childindex(ps,ps.lastind)...)
@@ -117,7 +117,7 @@ end
117117
@boundscheck (1 <= ind <= length(first(t))) || throw(BoundsError(first(t),ind))
118118
(@inbounds first(t)[ind],_last(Base.tail(t),rest...)...)
119119
end
120-
@inline _last(::Tuple{},rest...) = ()
120+
@inline _last(::Tuple{}) = ()
121121

122122
@inline Base.length(ps::ProductSplit) = ps.lastind - ps.firstind + 1
123123
@inline Base.lastindex(ps::ProductSplit) = ps.lastind - ps.firstind + 1
@@ -371,7 +371,7 @@ _infullrange(val::T,ps::ProductSplit{T}) where {T} = _infullrange(val,ps.iterato
371371
function _infullrange(val,t::Tuple)
372372
first(val) in first(t) && _infullrange(Base.tail(val),Base.tail(t))
373373
end
374-
_infullrange(::Tuple{},::Tuple{}) = true
374+
@inline _infullrange(::Tuple{},::Tuple{}) = true
375375

376376
# This struct is just a wrapper to flip the tuples before comparing
377377
struct ReverseLexicographicTuple{T}

test/runtests.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ end
3636
@test collect(ps) == collect(split_product_across_processors_iterators(iters,np,p))
3737
@test ntasks(ps) == ntasks_total
3838
@test ntasks(ps.iterators) == ntasks_total
39+
@test eltype(ps) == Tuple{map(eltype,iters)...}
3940
end
4041

4142
@test_throws ParallelUtilities.ProcessorNumberError ProductSplit(iters,npmax,npmax+1)
@@ -45,6 +46,13 @@ end
4546
@test_throws ArgumentError ProductSplit((),2,1)
4647
end
4748

49+
@testset "cumprod" begin
50+
@test ParallelUtilities._cumprod(1,()) == ()
51+
@test ParallelUtilities._cumprod(1,(2,)) == (1,)
52+
@test ParallelUtilities._cumprod(1,(2,3)) == (1,2)
53+
@test ParallelUtilities._cumprod(1,(2,3,4)) == (1,2,6)
54+
end
55+
4856
@testset "1D" begin
4957
iters = (1:10,)
5058
checkPSconstructor(iters)
@@ -92,6 +100,9 @@ end
92100

93101
@testset "firstlast" begin
94102
@testset "first" begin
103+
104+
@test ParallelUtilities._first(()) == ()
105+
95106
for iters in [(1:10,),(1:10,4:6),(1:10,4:6,1:4),(1:2:10,4:1:6)],
96107
np=1:5ntasks(iters)
97108

@@ -104,6 +115,9 @@ end
104115
@test first(ps) === nothing
105116
end
106117
@testset "last" begin
118+
119+
@test ParallelUtilities._last(()) == ()
120+
107121
for iters in [(1:10,),(1:10,4:6),(1:10,4:6,1:4),(1:2:10,4:1:6)],
108122
np=1:5ntasks(iters)
109123

@@ -144,6 +158,8 @@ end
144158
end
145159

146160
@testset "extremadims" begin
161+
ps = ProductSplit((1:10,),2,1)
162+
@test ParallelUtilities._extremadims(ps,1,()) == ()
147163
for iters in [(1:10,),(1:10,4:6),(1:10,4:6,1:4),(1:2:10,4:1:6)]
148164
dims = length(iters)
149165
for np = 1:5ntasks(iters), proc_id = 1:np
@@ -190,6 +206,8 @@ end
190206
for iters in [(1:10,),(1:10,4:6),(1:10,4:6,1:4),(1:2:10,4:1:6)]
191207
checkifpresent(iters)
192208
end
209+
210+
@test ParallelUtilities._infullrange((),())
193211
end
194212

195213
@testset "evenlyscatterproduct" begin
@@ -268,6 +286,12 @@ end
268286
end
269287

270288
@testset "getindex" begin
289+
290+
@test ParallelUtilities._getindex((),1) == ()
291+
@test ParallelUtilities._getindex((),1,2) == ()
292+
293+
@test ParallelUtilities.childindex((),1) == (1,)
294+
271295
for iters in [(1:10,),(1:10,4:6),(1:10,4:6,1:4),(1:2:10,4:1:6)]
272296
for np=1:ntasks(iters),p=1:np
273297
ps = ProductSplit(iters,np,p)
@@ -430,7 +454,9 @@ end
430454
@testset "worker id" begin
431455
@test pmapsum(x->workerrank(),1:nworkers()) == sum(1:nworkers())
432456
@test pmapsum(x->workerrank(),(1:nworkers(),)) == sum(1:nworkers())
457+
@test pmapsum(x->workerrank(),Iterators.product(1:nworkers())) == sum(1:nworkers())
433458
@test pmapsum(x->workerrank(),(1:nworkers(),1:1)) == sum(1:nworkers())
459+
@test pmapsum(x->workerrank(),Iterators.product(1:nworkers(),1:1)) == sum(1:nworkers())
434460
@test pmapsum(x->myid(),1:nworkers()) == sum(workers())
435461
end
436462

@@ -465,6 +491,10 @@ end
465491
iterable = 1:100
466492
res = pmapsum_elementwise(identity,iterable)
467493
@test res == sum(iterable)
494+
res = pmapsum_elementwise(identity,Iterators.product(iterable))
495+
@test res == sum(iterable)
496+
res = pmapsum_elementwise(identity,(iterable,))
497+
@test res == sum(iterable)
468498

469499
iterable = 1:100
470500
res = pmapsum_elementwise(x->x^2,iterable)
@@ -483,13 +513,17 @@ end
483513
@testset "sum" begin
484514
@test pmapreduce_commutative(x->myid(),sum,1:nworkers()) == sum(workers())
485515
@test pmapreduce_commutative(x->myid(),sum,(1:nworkers(),)) == sum(workers())
516+
@test pmapreduce_commutative(x->myid(),sum,Iterators.product(1:nworkers())) == sum(workers())
486517
@test pmapreduce_commutative(x->myid(),sum,(1:nworkers(),1:1)) == sum(workers())
518+
@test pmapreduce_commutative(x->myid(),sum,Iterators.product(1:nworkers(),1:1)) == sum(workers())
487519
@test pmapreduce_commutative(x->myid(),sum,1:nworkers()) == pmapsum(x->myid(),1:nworkers())
488520
end
489521
@testset "prod" begin
490522
@test pmapreduce_commutative(x->myid(),prod,1:nworkers()) == prod(workers())
491523
@test pmapreduce_commutative(x->myid(),prod,(1:nworkers(),)) == prod(workers())
524+
@test pmapreduce_commutative(x->myid(),prod,Iterators.product(1:nworkers())) == prod(workers())
492525
@test pmapreduce_commutative(x->myid(),prod,(1:nworkers(),1:1)) == prod(workers())
526+
@test pmapreduce_commutative(x->myid(),prod,Iterators.product(1:nworkers(),1:1)) == prod(workers())
493527
end
494528

495529
@testset "errors" begin
@@ -509,6 +543,10 @@ end
509543
@test res == sum(x->x^2,iter)
510544
@test res == pmapsum_elementwise(x->x^2,iter)
511545
@test res == pmapsum(plist->sum(x[1]^2 for x in plist),iter)
546+
res = pmapreduce_commutative_elementwise(x->x^2,sum,(iter,))
547+
@test res == sum(x->x^2,iter)
548+
res = pmapreduce_commutative_elementwise(x->x^2,sum,Iterators.product(iter))
549+
@test res == sum(x->x^2,iter)
512550
end
513551

514552
@testset "errors" begin
@@ -528,7 +566,9 @@ end
528566
@testset "sum" begin
529567
@test pmapreduce(x->myid(),sum,1:nworkers()) == sum(workers())
530568
@test pmapreduce(x->myid(),sum,(1:nworkers(),)) == sum(workers())
569+
@test pmapreduce(x->myid(),sum,Iterators.product(1:nworkers())) == sum(workers())
531570
@test pmapreduce(x->myid(),sum,(1:nworkers(),1:1)) == sum(workers())
571+
@test pmapreduce(x->myid(),sum,Iterators.product(1:nworkers(),1:1)) == sum(workers())
532572
@test pmapreduce(x->myid(),sum,1:nworkers()) == pmapsum(x->myid(),1:nworkers())
533573
end
534574

0 commit comments

Comments
 (0)