Skip to content

Commit 4b74ee7

Browse files
committed
local reduction in pmapsum_commutative
pmapsum_commutative carries out a local reduction on each host, followed by a final reduction across hosts. This is achieved by using a SegmentedSequentialBinaryTree, that is a collection of SequentialBinaryTrees on each host, with an extra tree that connects the hosts where the eventual reduction is carred out. This commit also changes procs_node and nprocs_node to return an OrderedDict instead of a Dict, in order to preserve the sequence of hosts.
1 parent 20925a1 commit 4b74ee7

File tree

8 files changed

+829
-255
lines changed

8 files changed

+829
-255
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
name = "ParallelUtilities"
22
uuid = "fad6cfc8-4f83-11e9-06cc-151124046ad0"
33
authors = ["Jishnu Bhattacharya <[email protected]>"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
7+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
78
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
89
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
910
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1011

1112
[compat]
12-
Reexport = "0.2"
13+
DataStructures = "0.17"
1314
ProgressMeter = "1.2"
15+
Reexport = "0.2"
1416
julia = "1.2"
1517

1618
[extras]

src/ParallelUtilities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ParallelUtilities
22
using ProgressMeter
3-
3+
using DataStructures
44
using Reexport
55
@reexport using Distributed
66

src/errors.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,4 @@ struct TaskNotPresentError{T,U} <: Exception
1818
end
1919
function Base.showerror(io::IO,err::TaskNotPresentError)
2020
print(io,"could not find the task $(err.task) in the list $(err.t)")
21-
end
22-
23-
struct BinaryTreeError <: Exception
24-
n :: Int
25-
end
26-
function Base.showerror(io::IO,err::BinaryTreeError)
27-
print(io,"attempt to construct a binary tree with $(err.n) children")
2821
end

src/mapreduce.jl

Lines changed: 100 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@ end
2525
put!(pipe.selfchannels.out,valT)
2626
end
2727

28-
function indicatemapprogress!(::Nothing) end
29-
function indicatemapprogress!(progress::RemoteChannel)
30-
put!(progress,(true,false))
28+
function indicatemapprogress!(::Nothing,rank) end
29+
function indicatemapprogress!(progress::RemoteChannel,rank)
30+
put!(progress,(true,false,rank))
31+
end
32+
33+
function indicatefailure!(::Nothing,rank) end
34+
function indicatefailure!(progress::RemoteChannel,rank)
35+
put!(progress,(false,false,rank))
3136
end
3237

3338
function mapTreeNode(fmap::Function,iterator,rank,pipe::BranchChannel,
@@ -40,11 +45,11 @@ function mapTreeNode(fmap::Function,iterator,rank,pipe::BranchChannel,
4045
res = fmap(iterator,args...;kwargs...)
4146
maybepvalput!(pipe,rank,res)
4247
put!(pipe.selfchannels.err,false)
48+
indicatemapprogress!(progress,rank)
4349
catch
4450
put!(pipe.selfchannels.err,true)
51+
indicatefailure!(progress,rank)
4552
rethrow()
46-
finally
47-
indicatemapprogress!(progress)
4853
end
4954
end
5055

@@ -59,15 +64,25 @@ struct Unsorted <: Ordering end
5964
function reducedvalue(freduce::Function,rank,
6065
pipe::BranchChannel{Tmap,Tred},::Unsorted) where {Tmap,Tred}
6166

62-
self = take!(pipe.selfchannels.out) :: Tmap
63-
6467
N = nchildren(pipe)
65-
res = if N > 0
68+
if rank > 0
69+
self = take!(pipe.selfchannels.out) :: Tmap
70+
if N > 0
6671
reducechildren = freduce(take!(pipe.childrenchannels.out)::Tred for i=1:N)::Tred
67-
freduce((reducechildren,self)) :: Tred
68-
else
69-
freduce((self,)) :: Tred
72+
res = freduce((reducechildren, self)) :: Tred
73+
elseif N == 0
74+
res = freduce((self,)) :: Tred
75+
end
76+
else
77+
if N > 0
78+
res = freduce(take!(pipe.childrenchannels.out)::Tred for i=1:N)::Tred
79+
elseif N == 0
80+
# N == 0 && rank <= 0
81+
# shouldn't reach this
82+
error("nodes with rank <=0 must have children")
7083
end
84+
end
85+
return res
7186
end
7287

7388
function reducedvalue(freduce::Function,rank,
@@ -95,9 +110,9 @@ function reducedvalue(freduce::Function,rank,
95110
Tred(rank,freduce(value(v) for v in vals))
96111
end
97112

98-
function indicatereduceprogress!(::Nothing) end
99-
function indicatereduceprogress!(progress::RemoteChannel)
100-
put!(progress,(false,true))
113+
function indicatereduceprogress!(::Nothing,rank) end
114+
function indicatereduceprogress!(progress::RemoteChannel,rank)
115+
put!(progress,(false,true,rank))
101116
end
102117

103118
function reduceTreeNode(freduce::Function,rank,pipe::BranchChannel{Tmap,Tred},
@@ -106,8 +121,13 @@ function reduceTreeNode(freduce::Function,rank,pipe::BranchChannel{Tmap,Tred},
106121

107122
# Start by checking if there is any error locally in the map,
108123
# and if there's none then check if there are any errors on the children
109-
anyerr = take!(pipe.selfchannels.err) ||
110-
any(take!(pipe.childrenchannels.err) for i=1:nchildren(pipe))
124+
if rank > 0
125+
anyerr = take!(pipe.selfchannels.err)
126+
else
127+
anyerr = false
128+
end
129+
anyerr = anyerr ||
130+
any(take!(pipe.childrenchannels.err) for i=1:nchildren(pipe))
111131

112132
# Evaluate the reduction only if there's no error
113133
# In either case push the error flag to the parent
@@ -116,15 +136,15 @@ function reduceTreeNode(freduce::Function,rank,pipe::BranchChannel{Tmap,Tred},
116136
res = reducedvalue(freduce,rank,pipe,ifsort) :: Tred
117137
put!(pipe.parentchannels.out,res)
118138
put!(pipe.parentchannels.err,false)
139+
indicatereduceprogress!(progress,rank)
119140
catch e
120141
put!(pipe.parentchannels.err,true)
142+
indicatefailure!(progress,rank)
121143
rethrow()
122-
finally
123-
indicatereduceprogress!(progress)
124144
end
125145
else
126146
put!(pipe.parentchannels.err,true)
127-
indicatereduceprogress!(progress)
147+
indicatefailure!(progress,rank)
128148
end
129149

130150
finalize(pipe)
@@ -147,40 +167,67 @@ function pmapreduceworkers(fmap::Function,freduce::Function,iterators::Tuple,
147167
kwargs...)
148168

149169
num_workers_active = nworkersactive(iterators)
170+
Nmaptotal = num_workers_active
171+
Nreducetotal = length(branches)
172+
extrareducenodes = Nreducetotal - Nmaptotal
173+
174+
Nprogress = Nmaptotal+Nreducetotal
175+
progresschannel = RemoteChannel(()->Channel{Tuple{Bool,Bool,Int}}(
176+
ifelse(showprogress,Nprogress,0)))
177+
progressbar = Progress(Nprogress,1,progressdesc)
150178

151-
nmap,nred = 0,0
152-
progresschannel = RemoteChannel(()->Channel{Tuple{Bool,Bool}}(
153-
ifelse(showprogress,2num_workers_active,0)))
154-
progressbar = Progress(2num_workers_active,1,progressdesc)
155-
156-
# Run the function on each processor and compute the reduction at each node
157179
@sync begin
158-
for (rank,mypipe) in enumerate(branches)
159-
@async begin
160-
p = mypipe.p
180+
181+
for (ind,mypipe) in enumerate(branches)
182+
p = mypipe.p
183+
rank = ind - extrareducenodes
184+
if rank > 0
161185
iterable_on_proc = ProductSplit(iterators,num_workers_active,rank)
162186

163187
@spawnat p mapTreeNode(fmap,iterable_on_proc,rank,mypipe,
164188
ifelse(showprogress,progresschannel,nothing),
165189
args...;kwargs...)
190+
191+
@spawnat p reduceTreeNode(freduce,rank,mypipe,ord,
192+
ifelse(showprogress,progresschannel,nothing))
193+
else
166194
@spawnat p reduceTreeNode(freduce,rank,mypipe,ord,
167195
ifelse(showprogress,progresschannel,nothing))
168196
end
169197
end
170-
198+
171199
if showprogress
172-
for i = 1:2num_workers_active
173-
mapdone,reddone = take!(progresschannel)
174-
nmap += mapdone
175-
nred += reddone
176200

177-
next!(progressbar;showvalues=[(:map,nmap),(:reduce,nred)])
201+
mapdone,reducedone = 0,0
202+
203+
for i = 1:Nprogress
204+
mapflag,redflag,rank = take!(progresschannel)
205+
# both flags are false in case of an error
206+
mapflag || redflag || break
207+
208+
mapdone += mapflag
209+
reducedone += redflag
210+
211+
if mapdone != Nmaptotal && reducedone != Nreducetotal
212+
showvalues = [
213+
(:map,string(mapdone)*"/"*string(Nmaptotal)),
214+
(:reduce,string(reducedone)*"/"*string(Nreducetotal))
215+
]
216+
217+
elseif reducedone != Nreducetotal
218+
showvalues = [
219+
(:reduce,string(reducedone)*"/"*string(Nreducetotal))
220+
]
221+
else
222+
showvalues = []
223+
end
224+
225+
next!(progressbar;showvalues=showvalues)
178226
end
179-
finish!(progressbar)
180227
end
181228
end
182229

183-
return_unless_error(topnode(tree,branches))
230+
return_unless_error(topbranch(tree,branches))
184231
end
185232

186233
"""
@@ -227,14 +274,6 @@ julia> pmapreduce_commutative(x->ones(2), x->hcat(x...), 1:4)
227274
1.0 1.0 1.0 1.0
228275
1.0 1.0 1.0 1.0
229276
230-
julia> pmapreduce_commutative(x->(sleep(myid());ones(2)), x->hcat(x...), 1:4, showprogress=true, progressdesc="Progress : ")
231-
Progress : 100%|████████████████████████████████████████| Time: 0:00:05
232-
map: 4
233-
reduce: 4
234-
2×4 Array{Float64,2}:
235-
1.0 1.0 1.0 1.0
236-
1.0 1.0 1.0 1.0
237-
238277
julia> pmapreduce_commutative(x->ones(2), Vector{Int64}, x->hcat(x...), Matrix{Int64}, 1:4)
239278
2×4 Array{Int64,2}:
240279
1 1 1 1
@@ -243,12 +282,12 @@ julia> pmapreduce_commutative(x->ones(2), Vector{Int64}, x->hcat(x...), Matrix{I
243282
244283
See also: [`pmapreduce_commutative_elementwise`](@ref), [`pmapreduce`](@ref), [`pmapsum`](@ref)
245284
"""
246-
function pmapreduce_commutative(fmap::Function,::Type{Tmap},
247-
freduce::Function,::Type{Tred},iterators::Tuple,args...;
248-
kwargs...) where {Tmap,Tred}
285+
function pmapreduce_commutative(fmap::Function,Tmap::Type,
286+
freduce::Function,Tred::Type,iterators::Tuple,args...;
287+
kwargs...)
249288

250289
tree,branches = createbranchchannels(Tmap,Tred,iterators,
251-
SequentialBinaryTree)
290+
SegmentedSequentialBinaryTree)
252291
pmapreduceworkers(fmap,freduce,iterators,tree,
253292
branches,Unsorted(),args...;kwargs...)
254293
end
@@ -259,8 +298,9 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,
259298
pmapreduce_commutative(fmap,Any,freduce,Any,iterators,args...;kwargs...)
260299
end
261300

262-
function pmapreduce_commutative(fmap::Function,::Type{Tmap},freduce::Function,::Type{Tred},
263-
iterable,args...;kwargs...) where {Tmap,Tred}
301+
function pmapreduce_commutative(fmap::Function,Tmap::Type,
302+
freduce::Function,Tred::Type,
303+
iterable,args...;kwargs...)
264304
pmapreduce_commutative(fmap,Tmap,freduce,Tred,(iterable,),args...;kwargs...)
265305
end
266306

@@ -310,22 +350,16 @@ julia> pmapreduce_commutative_elementwise(x->x^2,prod,1:5)
310350
julia> pmapreduce_commutative_elementwise((x,y)->x+y,sum,(1:2,1:2))
311351
12
312352
313-
julia> pmapreduce_commutative_elementwise(x->(sleep(myid());x^2), prod, 1:5, showprogress=true, progressdesc = "Progress : ")
314-
Progress : 100%|██████████████████████████████████████| Time: 0:00:05
315-
map: 4
316-
reduce: 4
317-
14400
318-
319353
julia> pmapreduce_commutative_elementwise(x->x^2,Int,prod,Float64,1:5)
320354
14400.0
321355
```
322356
323357
See also: [`pmapsum_commutative_elementwise`](@ref), [`pmapreduce_commutative`](@ref)
324358
"""
325-
function pmapreduce_commutative_elementwise(fmap::Function,::Type{Tmap},
326-
freduce::Function,::Type{Tred},iterable,args...;
359+
function pmapreduce_commutative_elementwise(fmap::Function,Tmap::Type,
360+
freduce::Function,Tred::Type,iterable,args...;
327361
showprogress::Bool = false, progressdesc = "Progress in pmapreduce : ",
328-
kwargs...) where {Tmap,Tred}
362+
kwargs...)
329363

330364
pmapreduce_commutative(
331365
plist->freduce((fmap(x...,args...;kwargs...) for x in plist)),
@@ -380,14 +414,6 @@ julia> pmapsum(x->ones(2), 1:4)
380414
4.0
381415
4.0
382416
383-
julia> pmapsum(x->(sleep(myid());ones(2)), 1:4, showprogress=true, progressdesc = "Progress : ")
384-
Progress : 100%|███████████████████████████████| Time: 0:00:05
385-
map: 4
386-
reduce: 4
387-
2-element Array{Float64,1}:
388-
4.0
389-
4.0
390-
391417
julia> pmapsum(x->ones(2), Vector{Int64}, 1:4)
392418
2-element Array{Int64,1}:
393419
4
@@ -396,7 +422,7 @@ julia> pmapsum(x->ones(2), Vector{Int64}, 1:4)
396422
397423
See also: [`pmapreduce`](@ref), [`pmapreduce_commutative`](@ref)
398424
"""
399-
function pmapsum(fmap::Function,::Type{T},iterable,args...;kwargs...) where {T}
425+
function pmapsum(fmap::Function,T::Type,iterable,args...;kwargs...)
400426
pmapreduce_commutative(fmap,T,sum,T,iterable,args...;
401427
progressdesc = "Progress in pmapsum : ",kwargs...)
402428
end
@@ -439,21 +465,15 @@ julia> pmapsum_elementwise(x->x^2,1:200)
439465
julia> pmapsum_elementwise((x,y)-> x+y, (1:5,1:2))
440466
45
441467
442-
julia> pmapsum_elementwise(x->(sleep(myid());x^2), 1:5, showprogress=true, progressdesc = "Progress : ")
443-
Progress : 100%|███████████████████████████████████████| Time: 0:00:05
444-
map: 4
445-
reduce: 4
446-
55
447-
448468
julia> pmapsum_elementwise(x->x^2, Float64, 1:5)
449469
55.0
450470
```
451471
452472
See also: [`pmapreduce_commutative_elementwise`](@ref), [`pmapsum`](@ref)
453473
"""
454-
function pmapsum_elementwise(fmap::Function,::Type{T},iterable,args...;
474+
function pmapsum_elementwise(fmap::Function,T::Type,iterable,args...;
455475
showprogress::Bool = false, progressdesc = "Progress in pmapsum : ",
456-
kwargs...) where {T}
476+
kwargs...)
457477

458478
pmapsum(plist->sum(x->fmap(x...,args...;kwargs...),plist),T,iterable,
459479
showprogress = showprogress, progressdesc = progressdesc)
@@ -510,14 +530,6 @@ julia> pmapreduce(x->ones(2).*myid(), x->hcat(x...), 1:4)
510530
2.0 3.0 4.0 5.0
511531
2.0 3.0 4.0 5.0
512532
513-
julia> pmapreduce(x->(sleep(myid());ones(2).*myid()), x->hcat(x...), 1:4, showprogress=true, progressdesc="Progress : ")
514-
Progress : 100%|██████████████████████████████████████| Time: 0:00:05
515-
map: 4
516-
reduce: 4
517-
2×4 Array{Float64,2}:
518-
2.0 3.0 4.0 5.0
519-
2.0 3.0 4.0 5.0
520-
521533
julia> pmapreduce(x->ones(2).*myid(), Vector{Int64}, x->hcat(x...), Matrix{Int64}, 1:4)
522534
2×4 Array{Int64,2}:
523535
2 3 4 5
@@ -526,8 +538,8 @@ julia> pmapreduce(x->ones(2).*myid(), Vector{Int64}, x->hcat(x...), Matrix{Int64
526538
527539
See also: [`pmapreduce_commutative`](@ref), [`pmapsum`](@ref)
528540
"""
529-
function pmapreduce(fmap::Function,::Type{Tmap},freduce::Function,::Type{Tred},
530-
iterators::Tuple,args...;kwargs...) where {Tmap,Tred}
541+
function pmapreduce(fmap::Function,Tmap::Type,freduce::Function,Tred::Type,
542+
iterators::Tuple,args...;kwargs...)
531543

532544
tree,branches = createbranchchannels(pval{Tmap},pval{Tred},
533545
iterators,OrderedBinaryTree)
@@ -541,8 +553,8 @@ function pmapreduce(fmap::Function,freduce::Function,iterators::Tuple,args...;
541553
pmapreduce(fmap,Any,freduce,Any,iterators,args...;kwargs...)
542554
end
543555

544-
function pmapreduce(fmap::Function,::Type{Tmap},freduce::Function,::Type{Tred},
545-
iterable,args...;kwargs...) where {Tmap,Tred}
556+
function pmapreduce(fmap::Function,Tmap::Type,freduce::Function,Tred::Type,
557+
iterable,args...;kwargs...)
546558

547559
pmapreduce(fmap,Tmap,freduce,Tred,(iterable,),args...;kwargs...)
548560
end

0 commit comments

Comments
 (0)