
Commit 8a6df80

only one sum channel
1 parent 4718f05 commit 8a6df80

File tree

1 file changed: +9 -48 lines changed

src/ParallelUtilities.jl

Lines changed: 9 additions & 48 deletions
@@ -172,48 +172,27 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
     num_workers = length(procs_used)
     hostnames = get_hostnames(procs_used)
     nodes = get_nodes(hostnames)
-    np_nodes = get_nprocs_node(hostnames)
     pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes]

-    function apply_and_stash(f,iterable,args...;channel,kwargs...)
-        result = f(iterable,args...;kwargs...)
-        put!(channel,result)
-    end
-
     # Worker at which final reduction takes place
     p_final = first(pid_rank0_on_node)
-    sum_channel = RemoteChannel(()->Channel{Any}(10),p_final)
-    master_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
-                        for (node,p) in zip(nodes,pid_rank0_on_node))
-    worker_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
-                        for (node,p) in zip(nodes,pid_rank0_on_node))

-    # Run the function on each processor
+    sum_channel = RemoteChannel(()->Channel{Any}(100),p_final)
+    result = 0
+
+    # Run the function on each processor and compute the sum at each node
     @sync for (rank,p) in enumerate(procs_used)
         @async begin
             iterable_on_proc = split_across_processors(iterable,num_workers,rank)
-            node = hostnames[rank]
-            master_channel_node = master_channel_nodes[node]
-            worker_channel_node = worker_channel_nodes[node]
-            np_node = np_nodes[node]
-            if p in pid_rank0_on_node
-                @spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
-                            channel=master_channel_node)
-                @spawnat p sum_channel(worker_channel_node,master_channel_node,np_node)
-            else
-                @spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
-                            channel=worker_channel_node)
+            r = @spawnat p f(iterable_on_proc,args...;kwargs...)
+            @spawnat p put!(sum_channel,fetch(r))
+            if p==p_final
+                result = @fetchfrom p_final sum(take!(sum_channel) for i=1:num_workers)
             end
         end
-        @async begin
-            @spawnat p_final sum_channel(master_channel_nodes,sum_channel,length(nodes))
-        end
     end

-    finalize.(values(worker_channel_nodes))
-    finalize.(values(master_channel_nodes))
-
-    take!(sum_channel)
+    return result
 end

 function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothing,kwargs...)
@@ -234,24 +213,6 @@ function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothi
     return futures
 end

-function spawnf(f::Function,iterable)
-
-    futures = Vector{Future}(undef,nworkers())
-    @sync for (rank,p) in enumerate(workers())
-        futures[rank] = @spawnat p f(iterable)
-    end
-    return futures
-end
-
-function sum_channel(worker_channel,master_channel,np_node)
-    @sync for i in 1:np_node-1
-        @async begin
-            s = take!(worker_channel) + take!(master_channel)
-            put!(master_channel,s)
-        end
-    end
-end
-
 #############################################################################

 export split_across_processors,split_product_across_processors,

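For reference, the reduction pattern after this commit can be sketched outside the package using only the Distributed standard library. This is a minimal, self-contained approximation, not the package's own API: partial_sum, channel_sum, and the chunking via Iterators.partition are illustrative stand-ins for f, pmapsum, and split_across_processors.

using Distributed
nprocs() == 1 && addprocs(4)

@everywhere partial_sum(xs) = sum(xs)   # stand-in for the mapped function f

function channel_sum(itr)
    wrk = workers()
    nw = length(wrk)
    p_final = first(wrk)                # worker hosting the single sum channel
    # One remote channel on p_final, large enough to buffer every partial result
    sum_channel = RemoteChannel(() -> Channel{Any}(nw), p_final)
    # Split the data into one contiguous chunk per worker
    chunks = collect(Iterators.partition(collect(itr), cld(length(itr), nw)))
    @sync for (chunk, p) in zip(chunks, wrk)
        @async begin
            r = @spawnat p partial_sum(chunk)       # evaluate on the worker
            @spawnat p put!(sum_channel, fetch(r))  # stash the partial result
        end
    end
    # Each worker has put exactly one value; sum them on p_final
    return @fetchfrom p_final sum(take!(sum_channel) for _ in 1:length(chunks))
end

channel_sum(1:100)   # expected to return 5050

Compared with the per-node master/worker channels removed in this commit, a single channel keeps all the bookkeeping on one process, and the final sum simply pulls one value per worker.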
0 commit comments
