Skip to content

Commit d028237

Browse files
committed
added pmapreduce
1 parent a6cbcd3 commit d028237

File tree

1 file changed

+77
-2
lines changed

1 file changed

+77
-2
lines changed

src/ParallelUtilities.jl

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export split_across_processors,split_product_across_processors,
77
get_processor_id_from_split_array,procid_allmodes,mode_index_in_file,
88
get_processor_range_from_split_array,workers_active,nworkers_active,worker_rank,
99
get_index_in_split_array,procid_and_mode_index,extrema_from_split_array,
10-
pmapsum,pmapsum_timed,sum_at_node,pmap_onebatch_per_worker,moderanges_common_lastarray,
10+
pmapsum,pmapreduce,pmap_onebatch_per_worker,moderanges_common_lastarray,
1111
get_nodes,get_hostnames,get_nprocs_node
1212

1313
function worker_rank()
@@ -242,6 +242,7 @@ end
242242

243243
get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
244244

245+
# This function does not sort the values, so it might be faster
245246
function pmapsum_remotechannel(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
246247

247248
procs_used = workers_active(iterable)
@@ -290,8 +291,73 @@ function pmapsum_remotechannel(::Type{T},f::Function,iterable,args...;kwargs...)
290291
return result :: T
291292
end
292293

294+
# Store the processor id with the value
295+
struct pval{T}
296+
p :: Int
297+
parent :: T
298+
end
299+
300+
function pmapreduce_remotechannel(::Type{T},fmap::Function,freduce::Function,
301+
iterable,args...;kwargs...) where {T}
302+
303+
procs_used = workers_active(iterable)
304+
305+
num_workers = length(procs_used);
306+
hostnames = get_hostnames(procs_used);
307+
nodes = get_nodes(hostnames);
308+
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
309+
310+
nprocs_node = get_nprocs_node(procs_used)
311+
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
312+
for (node,pid_node) in zip(nodes,pid_rank0_on_node))
313+
314+
# Worker at which final reduction takes place
315+
p_final = first(pid_rank0_on_node)
316+
317+
reduce_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
318+
result = nothing
319+
320+
# Run the function on each processor and compute the sum at each node
321+
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
322+
@async begin
323+
324+
node_remotechannel = node_channels[node]
325+
np_node = nprocs_node[node]
326+
327+
iterable_on_proc = split_across_processors(iterable,num_workers,rank)
328+
@spawnat p put!(node_remotechannel,
329+
pval(p,fmap(iterable_on_proc,args...;kwargs...)))
330+
331+
@async if p in pid_rank0_on_node
332+
f = @spawnat p begin
333+
vals = [take!(node_remotechannel) for i=1:np_node ]
334+
sort!(vals,by=x->x.p)
335+
put!(reduce_channel,pval(p,freduce(v.parent for v in vals)) )
336+
end
337+
wait(f)
338+
@spawnat p finalize(node_remotechannel)
339+
end
340+
341+
@async if p==p_final
342+
result = @fetchfrom p_final begin
343+
vals = [take!(reduce_channel) for i=1:length(pid_rank0_on_node)]
344+
sort!(vals,by=x->x.p)
345+
freduce(v.parent for v in vals)
346+
end
347+
@spawnat p finalize(reduce_channel)
348+
end
349+
end
350+
end
351+
352+
return result :: T
353+
end
354+
293355
function pmapsum_remotechannel(f::Function,iterable,args...;kwargs...)
294-
pmapsum(Any,f,iterable,args...;kwargs...)
356+
pmapsum_remotechannel(Any,f,iterable,args...;kwargs...)
357+
end
358+
359+
function pmapreduce_remotechannel(fmap::Function,freduce::Function,iterable,args...;kwargs...)
360+
pmapreduce_remotechannel(Any,fmap,freduce,iterable,args...;kwargs...)
295361
end
296362

297363
function pmapsum_distributedfor(f::Function,iterable,args...;kwargs...)
@@ -302,7 +368,16 @@ function pmapsum_distributedfor(f::Function,iterable,args...;kwargs...)
302368
end
303369
end
304370

371+
function pmapreduce_distributedfor(fmap::Function,freduce::Function,iterable,args...;kwargs...)
372+
@distributed freduce for i in 1:nworkers()
373+
np = nworkers_active(iterable)
374+
iter_proc = split_across_processors(iterable,np,i)
375+
fmap(iter_proc,args...;kwargs...)
376+
end
377+
end
378+
305379
pmapsum(args...;kwargs...) = pmapsum_remotechannel(args...;kwargs...)
380+
pmapreduce(args...;kwargs...) = pmapreduce_remotechannel(args...;kwargs...)
306381

307382
function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
308383

0 commit comments

Comments
 (0)