@@ -18,30 +18,30 @@ function worker_rank()
 end
 
 function split_across_processors(num_tasks::Integer,num_procs=nworkers(),proc_id=worker_rank())
-    if num_procs == 1
-        return num_tasks
-    end
+    if num_procs == 1
+        return num_tasks
+    end
 
-    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
+    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
 
-    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0);
-    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
+    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0);
+    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
 
-    return task_start:(task_start+num_tasks_on_proc-1)
+    return task_start:(task_start+num_tasks_on_proc-1)
 end
 
 function split_across_processors(arr₁,num_procs=nworkers(),proc_id=worker_rank())
 
-    @assert(proc_id<=num_procs,"processor rank has to be less than number of workers engaged")
+    @assert(proc_id<=num_procs,"processor rank has to be less than number of workers engaged")
 
-    num_tasks = length(arr₁);
+    num_tasks = length(arr₁);
 
-    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
+    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
 
-    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0);
-    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
+    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0);
+    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
 
-    Iterators.take(Iterators.drop(arr₁,task_start-1),num_tasks_on_proc)
+    Iterators.take(Iterators.drop(arr₁,task_start-1),num_tasks_on_proc)
 end
 
 function split_product_across_processors(arr₁::AbstractVector,arr₂::AbstractVector,
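For reference, the split implemented by `split_across_processors` is an even block decomposition: every worker gets `div(num_tasks,num_procs)` tasks, and the first `mod(num_tasks,num_procs)` workers get one extra. A minimal standalone sketch of the same arithmetic (the `task_range` name is illustrative and not part of the module):

```julia
# Sketch of the block decomposition; mirrors split_across_processors(num_tasks,num_procs,proc_id).
# `task_range` is a hypothetical name used only for this example.
function task_range(num_tasks::Integer, num_procs::Integer, proc_id::Integer)
    per_proc, leftover = divrem(num_tasks, num_procs)
    # the first `leftover` workers receive one extra task each
    n_on_proc  = per_proc + (proc_id <= leftover ? 1 : 0)
    task_start = per_proc*(proc_id - 1) + min(leftover + 1, proc_id)
    return task_start:(task_start + n_on_proc - 1)
end

# 10 tasks over 3 workers -> 1:4, 5:7, 8:10
@assert [task_range(10, 3, p) for p in 1:3] == [1:4, 5:7, 8:10]
```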
@@ -152,6 +152,7 @@ function procid_and_mode_index(arr₁::AbstractVector,arr₂::AbstractVector,
 end
 
 function procid_and_mode_index(iter,val::Tuple,num_procs::Integer)
+
     proc_id_mode = get_processor_id_from_split_array(iter,val,num_procs)
     modes_in_procid_file = split_across_processors(iter,num_procs,proc_id_mode)
     mode_index = get_index_in_split_array(modes_in_procid_file,val)
@@ -160,6 +161,7 @@
 
 function mode_index_in_file(arr₁::AbstractVector,arr₂::AbstractVector,
     (arr₁_value,arr₂_value)::Tuple,num_procs::Integer,proc_id_mode::Integer)
+
     modes_in_procid_file = split_product_across_processors(arr₁,arr₂,num_procs,proc_id_mode)
     mode_index = get_index_in_split_array(modes_in_procid_file,(arr₁_value,arr₂_value))
 end
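The two methods above map a value back to the worker that owns it and to its index within that worker's chunk (`get_processor_id_from_split_array` and `get_index_in_split_array` are defined elsewhere in the file and not shown in this diff). A self-contained illustration of that lookup, using hypothetical names and the same block split as in the earlier sketch:

```julia
# Illustrative only: given a block split, find which chunk owns value k and its local index.
chunks = [1:4, 5:7, 8:10]                 # e.g. 10 tasks split over 3 workers

function owner_and_local_index(chunks, k)
    for (proc_id, r) in enumerate(chunks)
        k in r && return (proc_id, k - first(r) + 1)
    end
    error("value $k not found in any chunk")
end

@assert owner_and_local_index(chunks, 6) == (2, 2)   # worker 2, local position 2
```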
@@ -229,7 +231,11 @@ get_nodes(procs_used::Vector{<:Integer}=workers()) = get_nodes(get_hostnames(procs_used))
 
 function get_nprocs_node(hostnames::Vector{String})
     nodes = get_nodes(hostnames)
-    num_procs_node = Dict(node=>count(x->x==node,hostnames) for node in nodes)
+    get_nprocs_node(hostnames,nodes)
+end
+
+function get_nprocs_node(hostnames::Vector{String},nodes::Vector{String})
+    Dict(node=>count(isequal(node),hostnames) for node in nodes)
 end
 
 get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
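The new two-argument method tallies how many workers report each hostname. Assuming `get_nodes(hostnames)` returns the unique hostnames (as its use here suggests), the result is equivalent to this small standalone example:

```julia
# Illustrative only: one hostname reported per worker -> workers per node
hostnames = ["node01", "node01", "node02", "node01", "node02"]
nodes     = unique(hostnames)                                    # stand-in for get_nodes
Dict(node => count(isequal(node), hostnames) for node in nodes)  # Dict("node01"=>3, "node02"=>2)
```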
@@ -238,40 +244,46 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
 
     procs_used = workers_active(iterable)
 
-    futures = pmap_onebatch_per_worker(f,iterable,args...;kwargs...)
+    num_workers = length(procs_used);
+    hostnames = get_hostnames(procs_used);
+    nodes = get_nodes(hostnames);
+    pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
 
-    function final_sum(futures)
-        s = fetch(first(futures))
-        @sync for f in futures[2:end]
-            @async begin
-                s += fetch(f)
-            end
-        end
-        return s
-    end
-    @fetchfrom first(procs_used) final_sum(futures)
-end
+    nprocs_node = get_nprocs_node(procs_used)
+    node_channels = Dict(node=>RemoteChannel(()->Channel{Any}(nprocs_node[node]),pid_node)
+        for (node,pid_node) in zip(nodes,pid_rank0_on_node))
 
-function pmapsum_timed(f::Function,iterable,args...;kwargs...)
+    # Worker at which final reduction takes place
+    p_final = first(pid_rank0_on_node)
 
-    procs_used = workers_active(iterable)
+    sum_channel = RemoteChannel(()->Channel{Any}(length(pid_rank0_on_node)),p_final)
+    result = nothing
+
+    # Run the function on each processor and compute the sum at each node
+    @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
+        @async begin
+
+            iterable_on_proc = split_across_processors(iterable,num_workers,rank)
+            r = @spawnat p f(iterable_on_proc,args...;kwargs...)
+
+            node_remotechannel = node_channels[node]
+            np_node = nprocs_node[node]
+
+            @spawnat p put!(node_remotechannel,fetch(r))
 
-    futures = pmap_onebatch_per_worker(f,iterable,args...;kwargs...)
+            if p in pid_rank0_on_node
+                s = @spawnat p sum(take!(node_remotechannel) for i=1:np_node)
+                @spawnat p put!(sum_channel,fetch(s))
+            end
 
-    timer = TimerOutput()
-    function final_sum(futures,timer)
-        @timeit timer "fetch" s = fetch(first(futures))
-        @sync for f in futures[2:end]
-            @async begin
-                @timeit timer "fetch" s += fetch(f)
+            if p==p_final
+                result = @fetchfrom p_final sum(take!(sum_channel)
+                            for i=1:length(pid_rank0_on_node))
             end
         end
-        return s,timer
     end
 
-    s,timer = @fetchfrom first(procs_used) final_sum(futures,timer)
-    println(timer)
-    return s
+    return result
 end
 
 function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
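The rewritten `pmapsum` replaces the fetch-and-add loop with a two-stage, channel-based reduction: each worker pushes its partial result into a `RemoteChannel` hosted on the first worker of its node, the node-level sums are then pushed to a channel on `p_final`, and `p_final` adds them up. A stripped-down, single-level sketch of that pattern (a standalone demo, not the module's code; it assumes four freshly added local workers):

```julia
using Distributed
addprocs(4)                        # demo assumes no other workers are present

p_final  = first(workers())
nw       = nworkers()              # 4 in this demo
partials = RemoteChannel(() -> Channel{Int}(nw), p_final)

# stage 1: every worker computes a partial sum and pushes it into the channel
@sync for (rank, p) in enumerate(workers())
    @async begin
        r = @spawnat p sum(25*(rank-1)+1 : 25*rank)   # worker `rank` sums its block of 1:100
        put!(partials, fetch(r))
    end
end

# stage 2: the designated worker drains the channel and adds up the partial sums
total = @fetchfrom p_final sum(take!(partials) for _ in 1:nw)
@assert total == sum(1:100)
```

In the module's version this pattern is applied at two levels, first within each node and then across nodes, so that only one value per node has to cross the network to `p_final`.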
@@ -293,10 +305,4 @@ function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
     return futures
 end
 
-function sum_at_node(futures::Vector{Future},hostnames)
-    myhost = hostnames[worker_rank()]
-    futures_on_myhost = futures[hostnames .== myhost]
-    sum(fetch(f) for f in futures_on_myhost)
-end
-
 end # module