@@ -169,25 +169,41 @@ get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_h
169
169
function pmapsum (f:: Function ,iterable,args... ;kwargs... )
170
170
171
171
procs_used = workers_active (iterable)
172
- num_workers = length (procs_used)
173
- hostnames = get_hostnames (procs_used)
174
- nodes = get_nodes (hostnames)
175
- pid_rank0_on_node = [procs_used[findfirst (isequal (node),hostnames)] for node in nodes]
172
+ num_workers = length (procs_used);
173
+ hostnames = get_hostnames (procs_used);
174
+ nodes = get_nodes (hostnames);
175
+ pid_rank0_on_node = [procs_used[findfirst (isequal (node),hostnames)] for node in nodes];
176
+
177
+ nprocs_node = get_nprocs_node (procs_used)
178
+ node_channels = Dict (node=> RemoteChannel (()-> Channel {Any} (nprocs_node[node]),pid_node)
179
+ for (node,pid_node) in zip (nodes,pid_rank0_on_node))
176
180
177
181
# Worker at which final reduction takes place
178
182
p_final = first (pid_rank0_on_node)
179
183
180
- sum_channel = RemoteChannel (()-> Channel {Any} (100 ),p_final)
181
- result = 0
184
+ sum_channel = RemoteChannel (()-> Channel {Any} (length (pid_rank0_on_node) ),p_final)
185
+ result = nothing
182
186
183
187
# Run the function on each processor and compute the sum at each node
184
- @sync for (rank,p) in enumerate (procs_used)
188
+ @sync for (rank,(p,node)) in enumerate (zip ( procs_used,hostnames) )
185
189
@async begin
190
+
186
191
iterable_on_proc = split_across_processors (iterable,num_workers,rank)
187
192
r = @spawnat p f (iterable_on_proc,args... ;kwargs... )
188
- @spawnat p put! (sum_channel,fetch (r))
193
+
194
+ node_remotechannel = node_channels[node]
195
+ np_node = nprocs_node[node]
196
+
197
+ @spawnat p put! (node_remotechannel,fetch (r))
198
+
199
+ if p in pid_rank0_on_node
200
+ s = @spawnat p sum (take! (node_remotechannel) for i= 1 : np_node)
201
+ @spawnat p put! (sum_channel,fetch (s))
202
+ end
203
+
189
204
if p== p_final
190
- result = @fetchfrom p_final sum (take! (sum_channel) for i= 1 : num_workers)
205
+ result = @fetchfrom p_final sum (take! (sum_channel)
206
+ for i= 1 : length (pid_rank0_on_node))
191
207
end
192
208
end
193
209
end
0 commit comments