@@ -261,6 +261,56 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
    sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
    result = nothing

+   # Run the function on each processor and compute the sum at each node
+   @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
+       @async begin
+
+           node_remotechannel = node_channels[node]
+           np_node = nprocs_node[node]
+
+           iterable_on_proc = split_across_processors(iterable,num_workers,rank)
+           @spawnat p put!(node_remotechannel,
+               f(iterable_on_proc,args...;kwargs...))
+
+           @async if p in pid_rank0_on_node
+               node_sum_task = @spawnat p put!(sum_channel,
+                   sum(take!(node_remotechannel) for i=1:np_node))
+               wait(node_sum_task)
+               @spawnat p finalize(node_remotechannel)
+           end
+
+           @async if p == p_final
+               result = @fetchfrom p_final sum(take!(sum_channel)
+                   for i=1:length(pid_rank0_on_node))
+               @spawnat p finalize(sum_channel)
+           end
+       end
+   end
+
+   return result::T
+end
+
+function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T<:AbstractArray}
+
+   # Use ArrayChannels
+
+   procs_used = workers_active(iterable)
+
+   num_workers = length(procs_used);
+   hostnames = get_hostnames(procs_used);
+   nodes = get_nodes(hostnames);
+   pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
+
+   nprocs_node = get_nprocs_node(procs_used)
+   node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
+       for (node,pid_node) in zip(nodes,pid_rank0_on_node))
+
+   # Worker at which final reduction takes place
+   p_final = first(pid_rank0_on_node)
+
+   sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
+   result = nothing
+
    # Run the function on each processor and compute the sum at each node
    @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
        @async begin
@@ -285,12 +335,11 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
        end
    end

-   return result
+   return result::T
end

function pmapsum(f::Function,iterable,args...;kwargs...)
-   T = promote_type(Base.return_types(f)...)
-   pmapsum(T,f,iterable,args...;kwargs...)
+   pmapsum(Any,f,iterable,args...;kwargs...)
end

function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
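
A minimal usage sketch, not part of the commit: it assumes the package defining pmapsum (shown under the hypothetical module name ParallelUtils) is loaded on all workers, that the iterable is one that split_across_processors can partition, and that f maps its chunk to a summable value. Passing an array element type dispatches to the new T<:AbstractArray method, while the untyped call now forwards to pmapsum(Any, ...).

using Distributed
addprocs(2)                                  # assumes a Distributed session with worker processes

@everywhere using ParallelUtils              # hypothetical module name for this package

# Each worker maps its chunk of the iterable to a small Vector{Float64};
# pmapsum then adds these vectors elementwise across workers and nodes.
@everywhere chunk_stats(xs) = Float64[sum(xs), length(xs)]

stats     = pmapsum(Vector{Float64}, chunk_stats, 1:10_000)  # typed: Channel{Vector{Float64}} buffers
stats_any = pmapsum(chunk_stats, 1:10_000)                   # untyped: forwards to pmapsum(Any, ...)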