Skip to content

Commit 7e30f50

Browse files
committed
added tests for pmapsum and pmapreduce and others
1 parent e1a0e21 commit 7e30f50

File tree

3 files changed

+208
-40
lines changed

3 files changed

+208
-40
lines changed

src/ParallelUtilities.jl

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ whichproc,newprocrange,whichproc,procidrange,
88
indexinsplitproduct,procid_and_index,
99
extremadims,extrema_commonlastdim,
1010
workersactive,nworkersactive,workerrank,
11-
pmapsum,pmapreduce,pmap_onebatchperworker,
12-
getnodes,gethostnames,getnprocs_node
11+
nodenames,gethostnames,nprocs_node,
12+
pmapsum,pmapreduce,pmap_onebatchperworker
1313

1414
# The fundamental iterator that behaves like an Iterators.ProductIterator
1515

@@ -66,6 +66,9 @@ function _cumprod(n::Int,tl::Tuple)
6666
(n,_cumprod(n*first(tl),Base.tail(tl))...)
6767
end
6868

69+
# Total number of tasks described by a tuple of iterators:
# the product of the lengths of the individual iterators
@inline function ntasks(tl::Tuple)
    lengths = map(length, tl)
    return prod(lengths)
end
70+
# The task count of a ProductSplit is computed from the iterators stored in it
@inline function ntasks(ps::ProductSplit)
    return ntasks(ps.iterators)
end
71+
6972
function ProductSplit(iterators::NTuple{N,Q},np::Int,p::Int) where {N,Q<:AbstractRange}
7073
T = NTuple{N,eltype(Q)}
7174
len = Base.Iterators._prod_size(iterators)
@@ -354,14 +357,6 @@ end
354357

355358
###################################################################################################
356359

357-
function workerrank()
358-
rank = 1
359-
if myid() in workers()
360-
rank = myid()-minimum(workers())+1
361-
end
362-
return rank
363-
end
364-
365360
# Convenience method: treat `num_tasks` as the single range 1:num_tasks and
# forward to the method that accepts a tuple of iterators
function evenlyscatterproduct(num_tasks::Integer,np=nworkers(),procid=workerrank())
    iterators = (1:num_tasks,)
    return evenlyscatterproduct(iterators,np,procid)
end
@@ -449,46 +444,63 @@ function procid_and_index(iterators::Tuple,val::Tuple,np::Integer)
449444
return procid,index
450445
end
451446

452-
workersactive(arr) = workers()[1:min(length(arr),nworkers())]
453-
454-
workersactive(arrs...) = workersactive(Iterators.product(arrs...))
447+
"""
    workerrank()

Return the rank of the current process among `workers()`, counting the
worker ids in increasing order and starting from 1.

If the current process is not listed in `workers()`, return `1`.
"""
function workerrank()
    me = myid()
    ws = workers()
    me in ws || return 1
    # Count the worker ids that do not exceed ours rather than computing
    # `myid() - minimum(workers()) + 1`: that arithmetic is only a valid rank
    # when the ids are contiguous, which need not hold (e.g. after `rmprocs`
    # followed by `addprocs`), and it could then exceed `nworkers()`.
    return count(w -> w <= me, ws)
end
455454

456-
nworkersactive(args...) = length(workersactive(args...))
455+
# Number of workers used for the tasks in `iterators`:
# the smaller of the total task count and `nworkers()`
@inline function nworkersactive(iterators::Tuple)
    return min(ntasks(iterators), nworkers())
end
460+
# For a ProductSplit, use the iterators stored in it
@inline function nworkersactive(ps::ProductSplit)
    return nworkersactive(ps.iterators)
end
461+
# Iterators passed as separate arguments are collected into a tuple
# and handed to the tuple method
@inline function nworkersactive(args...)
    return nworkersactive(args)
end
462+
# Ids of the workers used for the tasks in `iterators`:
# the leading `nworkersactive(iterators)` entries of `workers()`
@inline function workersactive(iterators::Tuple)
    nactive = nworkersactive(iterators)
    return workers()[1:nactive]
end
463+
# For a ProductSplit, use the iterators stored in it
@inline function workersactive(ps::ProductSplit)
    return workersactive(ps.iterators)
end
464+
# Iterators passed as separate arguments are collected into a tuple
# and handed to the tuple method
@inline function workersactive(args...)
    return workersactive(args)
end
457465

458-
function gethostnames(procs_used=workers())
466+
# Ask every process in `procs_used` for the name of the host it runs on.
# One asynchronous task is started per process and all of them are awaited;
# the result is ordered like `procs_used`.
function gethostnames(procs_used = workers())
    hostnames = Vector{String}(undef,length(procs_used))
    @sync begin
        for (slot,pid) in enumerate(procs_used)
            @async begin
                hostnames[slot] = remotecall_fetch(Libc.gethostname,pid)
            end
        end
    end
    return hostnames
end
465473

466-
getnodes(hostnames::Vector{String}) = unique(hostnames)
467-
getnodes(procs_used::Vector{<:Integer}=workers()) = getnodes(gethostnames(procs_used))
474+
# Distinct host names, in order of first appearance in `hostnames`
function nodenames(hostnames::Vector{String})
    return unique(hostnames)
end
475+
# Distinct names of the hosts on which the processes in `procs_used` run
function nodenames(procs_used::Vector{<:Integer} = workers())
    hostnames = gethostnames(procs_used)
    return nodenames(hostnames)
end
468476

469-
function getnprocs_node(hostnames::Vector{String})
470-
nodes = getnodes(hostnames)
471-
getnprocs_node(hostnames,nodes)
477+
# Number of processes per node, with the list of nodes derived from `hostnames`
nprocs_node(hostnames::Vector{String}) = nprocs_node(hostnames,nodenames(hostnames))
473481

474-
function getnprocs_node(hostnames::Vector{String},nodes::Vector{String})
482+
"""
    nprocs_node(hostnames::Vector{String}, nodes::Vector{String})

Return a `Dict{String,Int}` that maps each entry of `nodes` to the number of
times it occurs in `hostnames`.

A node that does not occur in `hostnames` is mapped to `0`.
"""
function nprocs_node(hostnames::Vector{String},nodes::Vector{String})
    # The key and value types are stated explicitly so that empty input still
    # produces a Dict{String,Int} rather than an untyped Dict{Any,Any}
    return Dict{String,Int}(node=>count(isequal(node),hostnames) for node in nodes)
end
477485

478-
getnprocs_node(procs_used::Vector{<:Integer}=workers()) = getnprocs_node(gethostnames(procs_used))
486+
# Number of processes per node for the processes listed in `procs_used`
function nprocs_node(procs_used::Vector{<:Integer} = workers())
    hostnames = gethostnames(procs_used)
    return nprocs_node(hostnames)
end
487+
488+
############################################################################################
489+
# pmapsum and pmapreduce
490+
############################################################################################
479491

480492
# This function does not sort the values, so it might be faster
481-
function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
493+
function pmapsum(::Type{T},f::Function,iterators::Tuple,args...;kwargs...) where {T}
482494

483-
procs_used = workersactive(iterable)
495+
procs_used = workersactive(iterators)
484496

485497
num_workers = length(procs_used);
486498
hostnames = gethostnames(procs_used);
487-
nodes = getnodes(hostnames);
499+
nodes = nodenames(hostnames);
488500
procid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
489501

490-
nprocs_node = getnprocs_node(procs_used)
491-
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),procid_node)
502+
nprocs_node_dict = nprocs_node(procs_used)
503+
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node_dict[node]),procid_node)
492504
for (node,procid_node) in zip(nodes,procid_rank0_on_node))
493505

494506
# Worker at which final reduction takes place
@@ -502,9 +514,9 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
502514
@async begin
503515

504516
node_remotechannel = node_channels[node]
505-
np_node = nprocs_node[node]
517+
np_node = nprocs_node_dict[node]
506518

507-
iterable_on_proc = evenlyscatterproduct(iterable,num_workers,rank)
519+
iterable_on_proc = evenlyscatterproduct(iterators,num_workers,rank)
508520
@spawnat p put!(node_remotechannel,
509521
f(iterable_on_proc,args...;kwargs...))
510522

@@ -523,31 +535,39 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
523535
end
524536
end
525537

526-
return result :: T
538+
return result
527539
end
528540

529-
function pmapsum(f::Function,iterable,args...;kwargs...)
541+
# Tuple of iterators without an explicit result type:
# forward to the typed method with `Any` as the type argument
pmapsum(f::Function,iterable::Tuple,args...;kwargs...) =
    pmapsum(Any,f,iterable,args...;kwargs...)
532544

545+
# Product iterator: unwrap the tuple of iterators it holds and
# forward to the typed method with `Any` as the type argument
function pmapsum(f::Function,itp::Iterators.ProductIterator,args...;kwargs...)
    iterators = itp.iterators
    return pmapsum(Any,f,iterators,args...;kwargs...)
end
548+
549+
# Any other iterable: wrap it in a 1-tuple and
# forward to the typed method with `Any` as the type argument
function pmapsum(f::Function,iterable,args...;kwargs...)
    iterators = (iterable,)
    return pmapsum(Any,f,iterators,args...;kwargs...)
end
552+
533553
# Store the processor id with the value
534554
struct pval{T}
    # integer id stored alongside the value
    p::Int
    # the wrapped value
    parent::T
end
538558

539559
function pmapreduce(::Type{T},fmap::Function,freduce::Function,
540-
iterable,args...;kwargs...) where {T}
560+
iterable::Tuple,args...;kwargs...) where {T}
541561

542562
procs_used = workersactive(iterable)
543563

544564
num_workers = length(procs_used);
545565
hostnames = gethostnames(procs_used);
546-
nodes = getnodes(hostnames);
566+
nodes = nodenames(hostnames);
547567
procid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
548568

549-
nprocs_node = getnprocs_node(procs_used)
550-
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),procid_node)
569+
nprocs_node_dict = nprocs_node(procs_used)
570+
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node_dict[node]),procid_node)
551571
for (node,procid_node) in zip(nodes,procid_rank0_on_node))
552572

553573
# Worker at which final reduction takes place
@@ -561,7 +581,7 @@ function pmapreduce(::Type{T},fmap::Function,freduce::Function,
561581
@async begin
562582

563583
node_remotechannel = node_channels[node]
564-
np_node = nprocs_node[node]
584+
np_node = nprocs_node_dict[node]
565585

566586
iterable_on_proc = evenlyscatterproduct(iterable,num_workers,rank)
567587
@spawnat p put!(node_remotechannel,
@@ -588,15 +608,27 @@ function pmapreduce(::Type{T},fmap::Function,freduce::Function,
588608
end
589609
end
590610

591-
return result :: T
611+
return result
592612
end
593613

594-
function pmapreduce(fmap::Function,freduce::Function,iterable,args...;kwargs...)
614+
# Tuple of iterators without an explicit result type:
# forward to the typed method with `Any` as the type argument
pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwargs...) =
    pmapreduce(Any,fmap,freduce,iterable,args...;kwargs...)
597617

598-
function pmap_onebatchperworker(f::Function,iterable,args...;kwargs...)
618+
# Product iterator: unwrap the tuple of iterators it holds and
# forward to the typed method with `Any` as the type argument
function pmapreduce(fmap::Function,freduce::Function,
        itp::Iterators.ProductIterator,args...;kwargs...)
    iterators = itp.iterators
    return pmapreduce(Any,fmap,freduce,iterators,args...;kwargs...)
end
599622

623+
# Any other iterable: wrap it in a 1-tuple and
# forward to the typed method with `Any` as the type argument
function pmapreduce(fmap::Function,freduce::Function,iterable,args...;kwargs...)
    iterators = (iterable,)
    return pmapreduce(Any,fmap,freduce,iterators,args...;kwargs...)
end
626+
627+
############################################################################################
628+
# pmap in batches without reduction
629+
############################################################################################
630+
631+
function pmap_onebatchperworker(f::Function,iterable::Tuple,args...;kwargs...)
600632
procs_used = workersactive(iterable)
601633
num_workers = get(kwargs,:num_workers,length(procs_used))
602634
if num_workers<length(procs_used)
@@ -614,4 +646,12 @@ function pmap_onebatchperworker(f::Function,iterable,args...;kwargs...)
614646
return futures
615647
end
616648

649+
# Product iterator: unwrap the tuple of iterators it holds and
# forward to the method that accepts a tuple
function pmap_onebatchperworker(f::Function,itp::Iterators.ProductIterator,args...;kwargs...)
    iterators = itp.iterators
    return pmap_onebatchperworker(f,iterators,args...;kwargs...)
end
652+
653+
# Any other iterable: wrap it in a 1-tuple and
# forward to the method that accepts a tuple
function pmap_onebatchperworker(f::Function,iterable,args...;kwargs...)
    iterators = (iterable,)
    return pmap_onebatchperworker(f,iterators,args...;kwargs...)
end
656+
617657
end # module

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[deps]
22
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
3+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
34
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
using ParallelUtilities,Test,Distributed
22

3+
addprocs(2)
4+
@everywhere begin
5+
using Pkg
6+
Pkg.activate(".")
7+
using ParallelUtilities
8+
end
9+
310
@testset "ProductSplit" begin
411

512
function split_across_processors_iterators(arr₁::Base.Iterators.ProductIterator,num_procs,proc_id)
@@ -22,9 +29,12 @@ using ParallelUtilities,Test,Distributed
2229
@testset "Constructor" begin
2330

2431
function checkPSconstructor(iters,npmax=10)
32+
ntasks_total = prod(map(length,iters))
2533
for np = 1:npmax, p = 1:np
2634
ps = ProductSplit(iters,np,p)
2735
@test collect(ps) == collect(split_product_across_processors_iterators(iters,np,p))
36+
@test ParallelUtilities.ntasks(ps) == ntasks_total
37+
@test ParallelUtilities.ntasks(ps.iterators) == ntasks_total
2838
end
2939

3040
@test_throws ParallelUtilities.ProcessorNumberError ProductSplit(iters,npmax,npmax+1)
@@ -48,6 +58,23 @@ using ParallelUtilities,Test,Distributed
4858
iters = (10:-1:10,6:-2:0)
4959
@test_throws ParallelUtilities.DecreasingIteratorError ProductSplit(iters,3,2)
5060
end
61+
62+
@testset "empty" begin
63+
iters = (1:1,)
64+
ps = ProductSplit(iters,10,2)
65+
@test isempty(ps)
66+
@test length(ps) == 0
67+
end
68+
69+
@testset "first and last ind" begin
70+
iters = (1:10,)
71+
ps = ProductSplit(iters,2,1)
72+
@test ps.firstind == 1
73+
@test ps.lastind == div(length(iters[1]),2)
74+
ps = ProductSplit(iters,2,2)
75+
@test ps.firstind == div(length(iters[1]),2) + 1
76+
@test ps.lastind == length(iters[1])
77+
end
5178
end
5279

5380
@testset "extrema" begin
@@ -179,4 +206,104 @@ end
179206
@test a == b
180207
@test a <= b
181208
end
182-
end
209+
end
210+
211+
@testset "utilities" begin
    @testset "workerrank" begin
        # Each worker must report its own id, and a rank equal to its
        # position in workers()
        for (rank,workerid) in enumerate(workers())
            @test @fetchfrom workerid myid() == workerid
            @test @fetchfrom workerid workerrank() == rank
        end
    end

    @testset "workers active" begin
        @test nworkersactive((1:1,)) == 1
        @test nworkersactive((1:2,)) == min(2,nworkers())
        @test nworkersactive((1:1,1:2)) == min(2,nworkers())
        @test nworkersactive(1:2) == min(2,nworkers())
        @test nworkersactive(1:1,1:2) == min(2,nworkers())
        # More tasks than workers: every worker is active. The expected value
        # is nworkers() rather than a literal, so the check does not depend on
        # how many processes were added at the top of the file.
        @test nworkersactive((1:nworkers()+1,)) == nworkers()
        @test nworkersactive(1:nworkers()+1) == nworkers()
        @test workersactive((1:1,)) == workers()[1:1]
        @test workersactive(1:1) == workers()[1:1]
        @test workersactive(1:1,1:1) == workers()[1:1]
        @test workersactive((1:2,)) == workers()[1:min(2,nworkers())]
        @test workersactive((1:1,1:2)) == workers()[1:min(2,nworkers())]
        @test workersactive(1:1,1:2) == workers()[1:min(2,nworkers())]
        @test workersactive((1:nworkers()+1,)) == workers()
        @test workersactive(1:nworkers()+1) == workers()

        ps = ProductSplit((1:10,),nworkers(),1)
        @test nworkersactive(ps) == min(10,nworkers())
    end

    @testset "hostnames" begin
        hostnames = gethostnames()
        nodes = unique(hostnames)
        @test hostnames == [@fetchfrom p Libc.gethostname() for p in workers()]
        @test nodenames() == nodes
        @test nodenames(hostnames) == nodes
        # Count the workers per host instead of assuming that all of them
        # share a single host
        npnodes = Dict(node=>count(isequal(node),hostnames) for node in nodes)
        @test nprocs_node(hostnames,nodes) == npnodes
        @test nprocs_node(hostnames) == npnodes
        @test nprocs_node() == npnodes
    end
end
252+
253+
@testset "pmap and reduce" begin
254+
255+
@testset "pmapsum" begin
256+
257+
@testset "worker id" begin
258+
@test pmapsum(x->workerrank(),1:nworkers()) == sum(1:nworkers())
259+
@test pmapsum(x->myid(),1:nworkers()) == sum(workers())
260+
end
261+
262+
@testset "one iterator" begin
263+
rng = 1:100
264+
@test pmapsum(x->sum(y[1] for y in x),rng) == sum(rng)
265+
@test pmapsum(x->sum(y[1] for y in x),Iterators.product(rng)) == sum(rng)
266+
@test pmapsum(x->sum(y[1] for y in x),(rng,)) == sum(rng)
267+
end
268+
269+
@testset "array" begin
270+
@test pmapsum(x->ones(2),1:nworkers()) == ones(2).*nworkers()
271+
end
272+
273+
@testset "stepped iterator" begin
274+
rng = 1:5:100
275+
@test pmapsum(x->sum(y[1] for y in x),rng) == sum(rng)
276+
end
277+
278+
@testset "two iterators" begin
279+
iters = (1:100,1:2)
280+
@test pmapsum(x->sum(y[1] for y in x),iters) == sum(iters[1])*length(iters[2])
281+
end
282+
end
283+
284+
@testset "pmapreduce" begin
285+
@testset "sum" begin
286+
@test pmapreduce(x->myid(),sum,1:nworkers()) == sum(workers())
287+
@test pmapreduce(x->myid(),sum,1:nworkers()) == pmapsum(x->myid(),1:nworkers())
288+
end
289+
290+
@testset "concatenation" begin
291+
@test pmapreduce(x->ones(2),x->vcat(x...),1:nworkers()) == ones(2*nworkers())
292+
@test pmapreduce(x->ones(2),x->hcat(x...),1:nworkers()) == ones(2,nworkers())
293+
end
294+
295+
@testset "sorting" begin
296+
@test pmapreduce(x->ones(2)*workerrank(),x->vcat(x...),1:nworkers()) ==
297+
vcat((ones(2).*i for i=1:nworkers())...)
298+
end
299+
300+
@testset "worker id" begin
301+
@test pmapreduce(x->workerrank(),x->vcat(x...),1:nworkers()) == collect(1:nworkers())
302+
@test pmapreduce(x->myid(),x->vcat(x...),1:nworkers()) == workers()
303+
end
304+
end
305+
end
306+
307+
308+
309+
rmprocs(workers())

0 commit comments

Comments
 (0)