Skip to content

Commit 4cee9bd

Browse files
committed
finalize remotechannel
1 parent f308d03 commit 4cee9bd

File tree

1 file changed

+52
-3
lines changed

1 file changed

+52
-3
lines changed

src/ParallelUtilities.jl

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,56 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
261261
sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
262262
result = nothing
263263

264+
# Run the function on each processor and compute the sum at each node
265+
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
266+
@async begin
267+
268+
node_remotechannel = node_channels[node]
269+
np_node = nprocs_node[node]
270+
271+
iterable_on_proc = split_across_processors(iterable,num_workers,rank)
272+
@spawnat p put!(node_remotechannel,
273+
f(iterable_on_proc,args...;kwargs...))
274+
275+
@async if p in pid_rank0_on_node
276+
f = @spawnat p put!(sum_channel,
277+
sum(take!(node_remotechannel) for i=1:np_node))
278+
wait(f)
279+
@spawnat p finalize(node_remotechannel)
280+
end
281+
282+
@async if p==p_final
283+
result = @fetchfrom p_final sum(take!(sum_channel)
284+
for i=1:length(pid_rank0_on_node))
285+
@spawnat p finalize(sum_channel)
286+
end
287+
end
288+
end
289+
290+
return result :: T
291+
end
292+
293+
function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T<:AbstractArray}
294+
295+
# Use ArrayChannels
296+
297+
procs_used = workers_active(iterable)
298+
299+
num_workers = length(procs_used);
300+
hostnames = get_hostnames(procs_used);
301+
nodes = get_nodes(hostnames);
302+
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
303+
304+
nprocs_node = get_nprocs_node(procs_used)
305+
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
306+
for (node,pid_node) in zip(nodes,pid_rank0_on_node))
307+
308+
# Worker at which final reduction takes place
309+
p_final = first(pid_rank0_on_node)
310+
311+
sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
312+
result = nothing
313+
264314
# Run the function on each processor and compute the sum at each node
265315
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
266316
@async begin
@@ -285,12 +335,11 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
285335
end
286336
end
287337

288-
return result
338+
return result :: T
289339
end
290340

291341
# Method of pmapsum that takes no explicit result type.
# This span is a rendered diff: the lines under the `292-`/`293-` gutters are
# the removed implementation, which derived T from
# promote_type(Base.return_types(f)...) and forwarded to pmapsum(T, ...);
# the line under the `342+` gutter is its replacement, which forwards to the
# typed method with `Any` as the type argument.
function pmapsum(f::Function,iterable,args...;kwargs...)
292-
T = promote_type(Base.return_types(f)...)
293-
pmapsum(T,f,iterable,args...;kwargs...)
342+
pmapsum(Any,f,iterable,args...;kwargs...)
294343
end
295344

296345
function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)

0 commit comments

Comments
 (0)