@@ -237,24 +237,15 @@ get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
 function pmapsum(f::Function,iterable,args...;kwargs...)
 
     procs_used = workers_active(iterable)
-    num_workers = length(procs_used)
-    hostnames = get_hostnames(procs_used)
-    nodes = get_nodes(hostnames)
-    pid_rank0_on_node = [procs_used[findfirst(x->x==node,hostnames)] for node in nodes]
 
     futures = pmap_onebatch_per_worker(f,iterable,args...;kwargs...)
 
-    # Intermediate sum over processors on the same node
-    node_sum_futures = Vector{Future}(undef,length(pid_rank0_on_node))
-    @sync for (ind,p) in enumerate(pid_rank0_on_node)
-        @async node_sum_futures[ind] = @spawnat p sum_at_node(futures,hostnames)
+    # Final sum across all nodes
+    # sum(fetch(f) for f in futures)
+    @fetchfrom first(procs_used) @distributed (+) for f in futures
+        fetch(f)
     end
 
-    # Worker at which final reduction takes place
-    p = first(pid_rank0_on_node)
-
-    # Final sum across all nodes
-    @fetchfrom p sum(fetch(f) for f in node_sum_futures)
 end
 
 function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothing,kwargs...)
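For context on the change above: the old two-stage reduction (per-node partial sums via sum_at_node, then a cross-node sum) is collapsed into a single @distributed (+) loop over the futures, executed on the first participating worker via @fetchfrom. The self-contained sketch below illustrates that reduction idiom using only the standard Distributed library; the worker count and the square helper are hypothetical stand-ins for this package's workers_active/pmap_onebatch_per_worker machinery, not part of it.

    using Distributed
    addprocs(4)  # assumed worker count for this sketch

    # Hypothetical per-worker task; the real code spawns one batch of `f` per worker
    @everywhere square(x) = x^2

    # One future per worker, each holding a partial result
    futures = [@spawnat p square(i) for (i, p) in enumerate(workers())]

    # Reduce on a single worker: each loop iteration fetches one future,
    # and the partial results are combined with (+)
    total = @fetchfrom first(workers()) @distributed (+) for f in futures
        fetch(f)
    end

    println(total)  # 1 + 4 + 9 + 16 = 30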