Skip to content

Commit 5ca5b4a

Browse files
committed
intermediate sum at nodes
1 parent 8a6df80 commit 5ca5b4a

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

src/ParallelUtilities.jl

Lines changed: 25 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -169,25 +169,41 @@ get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_h
169169
function pmapsum(f::Function,iterable,args...;kwargs...)
170170

171171
procs_used = workers_active(iterable)
172-
num_workers = length(procs_used)
173-
hostnames = get_hostnames(procs_used)
174-
nodes = get_nodes(hostnames)
175-
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes]
172+
num_workers = length(procs_used);
173+
hostnames = get_hostnames(procs_used);
174+
nodes = get_nodes(hostnames);
175+
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
176+
177+
nprocs_node = get_nprocs_node(procs_used)
178+
node_channels = Dict(node=>RemoteChannel(()->Channel{Any}(nprocs_node[node]),pid_node)
179+
for (node,pid_node) in zip(nodes,pid_rank0_on_node))
176180

177181
# Worker at which final reduction takes place
178182
p_final = first(pid_rank0_on_node)
179183

180-
sum_channel = RemoteChannel(()->Channel{Any}(100),p_final)
181-
result = 0
184+
sum_channel = RemoteChannel(()->Channel{Any}(length(pid_rank0_on_node)),p_final)
185+
result = nothing
182186

183187
# Run the function on each processor and compute the sum at each node
184-
@sync for (rank,p) in enumerate(procs_used)
188+
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
185189
@async begin
190+
186191
iterable_on_proc = split_across_processors(iterable,num_workers,rank)
187192
r = @spawnat p f(iterable_on_proc,args...;kwargs...)
188-
@spawnat p put!(sum_channel,fetch(r))
193+
194+
node_remotechannel = node_channels[node]
195+
np_node = nprocs_node[node]
196+
197+
@spawnat p put!(node_remotechannel,fetch(r))
198+
199+
if p in pid_rank0_on_node
200+
s = @spawnat p sum(take!(node_remotechannel) for i=1:np_node)
201+
@spawnat p put!(sum_channel,fetch(s))
202+
end
203+
189204
if p==p_final
190-
result = @fetchfrom p_final sum(take!(sum_channel) for i=1:num_workers)
205+
result = @fetchfrom p_final sum(take!(sum_channel)
206+
for i=1:length(pid_rank0_on_node))
191207
end
192208
end
193209
end

0 commit comments

Comments
 (0)