@@ -7,7 +7,7 @@ export split_across_processors,split_product_across_processors,
7
7
get_processor_id_from_split_array,procid_allmodes,mode_index_in_file,
8
8
get_processor_range_from_split_array,workers_active,nworkers_active,worker_rank,
9
9
get_index_in_split_array,procid_and_mode_index,extrema_from_split_array,
10
- pmapsum,pmapsum_timed,sum_at_node ,pmap_onebatch_per_worker,moderanges_common_lastarray,
10
+ pmapsum,pmapreduce ,pmap_onebatch_per_worker,moderanges_common_lastarray,
11
11
get_nodes,get_hostnames,get_nprocs_node
12
12
13
13
function worker_rank ()
242
242
243
243
"""
    get_nprocs_node(procs_used::Vector{<:Integer} = workers())

Look up the hostnames of the processes in `procs_used` with `get_hostnames`
and forward the result to `get_nprocs_node`.
"""
function get_nprocs_node(procs_used::Vector{<:Integer} = workers())
    hostnames = get_hostnames(procs_used)
    return get_nprocs_node(hostnames)
end
244
244
245
+ # This function does not sort the values, so it might be faster
245
246
function pmapsum_remotechannel (:: Type{T} ,f:: Function ,iterable,args... ;kwargs... ) where {T}
246
247
247
248
procs_used = workers_active (iterable)
@@ -290,8 +291,73 @@ function pmapsum_remotechannel(::Type{T},f::Function,iterable,args...;kwargs...)
290
291
return result :: T
291
292
end
292
293
294
"""
    pval{T}

Store a value together with a processor id.

# Fields
- `p::Int`: the processor id stored with the value.
- `parent::T`: the wrapped value.
"""
struct pval{T}
    p::Int
    parent::T
end
299
+
300
"""
    pmapreduce_remotechannel(::Type{T}, fmap::Function, freduce::Function,
        iterable, args...; kwargs...) where {T}

Evaluate `fmap` on the part of `iterable` assigned to each active worker and
combine the results with `freduce` in two stages: first on one worker per node,
then on a single worker across all nodes. The final value is asserted to be a
`T` before it is returned.

`fmap` is called as `fmap(iterable_on_proc, args...; kwargs...)`. `freduce` is
called with a single iterator over the values to combine (like `sum`), not as
a binary operator. In both stages the values are ordered by the processor id
stored alongside them before being handed to `freduce`.
"""
function pmapreduce_remotechannel(::Type{T}, fmap::Function, freduce::Function,
    iterable, args...; kwargs...) where {T}

    procs_used = workers_active(iterable)

    num_workers = length(procs_used)
    hostnames = get_hostnames(procs_used)
    nodes = get_nodes(hostnames)
    # For every node, the first entry of procs_used whose hostname matches it;
    # that worker performs the reduction for the node
    pid_rank0_on_node = [procs_used[findfirst(isequal(node), hostnames)] for node in nodes]

    nprocs_node = get_nprocs_node(procs_used)

    # The channels transport `pval`-wrapped values, so their element type has
    # to be `pval` rather than `T`: `put!` converts to the channel's element
    # type, and a `pval` can not be converted to an arbitrary `T`.
    node_channels = Dict(
        node => RemoteChannel(() -> Channel{pval}(nprocs_node[node]), pid_node)
        for (node, pid_node) in zip(nodes, pid_rank0_on_node))

    # Worker at which final reduction takes place
    p_final = first(pid_rank0_on_node)

    reduce_channel = RemoteChannel(() -> Channel{pval}(length(pid_rank0_on_node)), p_final)
    result = nothing

    # Run the function on each processor and reduce the values at each node
    @sync for (rank, (p, node)) in enumerate(zip(procs_used, hostnames))
        @async begin

            node_remotechannel = node_channels[node]
            np_node = nprocs_node[node]

            iterable_on_proc = split_across_processors(iterable, num_workers, rank)
            # NOTE(review): the future returned here is discarded, so an
            # exception thrown by fmap on the worker is not observed by this task.
            @spawnat p put!(node_remotechannel,
                pval(p, fmap(iterable_on_proc, args...; kwargs...)))

            # Stage 1: one worker per node collects and reduces the node's values
            @async if p in pid_rank0_on_node
                f = @spawnat p begin
                    vals = [take!(node_remotechannel) for i = 1:np_node]
                    # Values arrive in completion order; fix the order by processor id
                    sort!(vals, by = x -> x.p)
                    put!(reduce_channel, pval(p, freduce(v.parent for v in vals)))
                end
                wait(f)
                @spawnat p finalize(node_remotechannel)
            end

            # Stage 2: a single worker reduces the per-node values
            @async if p == p_final
                result = @fetchfrom p_final begin
                    vals = [take!(reduce_channel) for i = 1:length(pid_rank0_on_node)]
                    sort!(vals, by = x -> x.p)
                    freduce(v.parent for v in vals)
                end
                @spawnat p finalize(reduce_channel)
            end
        end
    end

    return result::T
end
354
+
293
355
"""
    pmapsum_remotechannel(f::Function, iterable, args...; kwargs...)

Call the typed `pmapsum_remotechannel` method with `Any` as the type argument,
passing every other argument through unchanged.
"""
pmapsum_remotechannel(f::Function, iterable, args...; kwargs...) =
    pmapsum_remotechannel(Any, f, iterable, args...; kwargs...)
358
+
359
"""
    pmapreduce_remotechannel(fmap::Function, freduce::Function, iterable, args...; kwargs...)

Call the typed `pmapreduce_remotechannel` method with `Any` as the type
argument, passing every other argument through unchanged.
"""
pmapreduce_remotechannel(fmap::Function, freduce::Function, iterable, args...; kwargs...) =
    pmapreduce_remotechannel(Any, fmap, freduce, iterable, args...; kwargs...)
296
362
297
363
function pmapsum_distributedfor (f:: Function ,iterable,args... ;kwargs... )
@@ -302,7 +368,16 @@ function pmapsum_distributedfor(f::Function,iterable,args...;kwargs...)
302
368
end
303
369
end
304
370
371
"""
    pmapreduce_distributedfor(fmap::Function, freduce::Function, iterable, args...; kwargs...)

Map-reduce over `iterable` with a `@distributed` loop. Each iteration applies
`fmap` to one part of `iterable` (obtained from `split_across_processors`), and
`freduce` is passed to `@distributed` as the reducer of the per-iteration values.
"""
function pmapreduce_distributedfor(fmap::Function, freduce::Function, iterable, args...; kwargs...)
    # NOTE(review): the loop covers 1:nworkers() while the split uses
    # nworkers_active(iterable) parts — confirm the behaviour when they differ.
    reduced = @distributed freduce for rank in 1:nworkers()
        nactive = nworkers_active(iterable)
        fmap(split_across_processors(iterable, nactive, rank), args...; kwargs...)
    end
    return reduced
end
378
+
305
379
"""
    pmapsum(args...; kwargs...)

Forward all positional and keyword arguments to `pmapsum_remotechannel`.
"""
function pmapsum(args...; kwargs...)
    return pmapsum_remotechannel(args...; kwargs...)
end
380
"""
    pmapreduce(args...; kwargs...)

Forward all positional and keyword arguments to `pmapreduce_remotechannel`.
"""
function pmapreduce(args...; kwargs...)
    return pmapreduce_remotechannel(args...; kwargs...)
end
306
381
307
382
function pmap_onebatch_per_worker (f:: Function ,iterable,args... ;kwargs... )
308
383
0 commit comments