Commit e1be0a5

merged remotechannels

2 parents: a3a0d38 + 088112d

1 file changed: +51 -45 lines

src/ParallelUtilities.jl (51 additions, 45 deletions)
@@ -18,30 +18,30 @@ function worker_rank()
 end
 
 function split_across_processors(num_tasks::Integer,num_procs=nworkers(),proc_id=worker_rank())
-    if num_procs == 1
-        return num_tasks
-    end
+    if num_procs == 1
+        return num_tasks
+    end
 
-    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
+    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
 
-    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0 );
-    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
+    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0 );
+    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
 
-    return task_start:(task_start+num_tasks_on_proc-1)
+    return task_start:(task_start+num_tasks_on_proc-1)
 end
 
 function split_across_processors(arr₁,num_procs=nworkers(),proc_id=worker_rank())
 
-    @assert(proc_id<=num_procs,"processor rank has to be less than number of workers engaged")
+    @assert(proc_id<=num_procs,"processor rank has to be less than number of workers engaged")
 
-    num_tasks = length(arr₁);
+    num_tasks = length(arr₁);
 
-    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
+    num_tasks_per_process,num_tasks_leftover = div(num_tasks,num_procs),mod(num_tasks,num_procs)
 
-    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0 );
-    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
+    num_tasks_on_proc = num_tasks_per_process + (proc_id <= mod(num_tasks,num_procs) ? 1 : 0 );
+    task_start = num_tasks_per_process*(proc_id-1) + min(num_tasks_leftover+1,proc_id);
 
-    Iterators.take(Iterators.drop(arr₁,task_start-1),num_tasks_on_proc)
+    Iterators.take(Iterators.drop(arr₁,task_start-1),num_tasks_on_proc)
 end
 
 function split_product_across_processors(arr₁::AbstractVector,arr₂::AbstractVector,
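
Note: the splitting arithmetic above hands each worker a contiguous block of tasks, with the leftover tasks going to the lowest ranks. The following is a minimal standalone sketch, not repository code; the helper name split_range is hypothetical and merely reproduces the same formula for checking:

# Hypothetical standalone copy of the block-splitting formula used in split_across_processors.
function split_range(num_tasks, num_procs, proc_id)
    per_proc, leftover = div(num_tasks, num_procs), mod(num_tasks, num_procs)
    n_on_proc = per_proc + (proc_id <= leftover ? 1 : 0)
    task_start = per_proc*(proc_id - 1) + min(leftover + 1, proc_id)
    return task_start:(task_start + n_on_proc - 1)
end

# 10 tasks over 3 workers: ranks 1, 2 and 3 get 1:4, 5:7 and 8:10 respectively.
@assert [split_range(10, 3, p) for p in 1:3] == [1:4, 5:7, 8:10]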
@@ -152,6 +152,7 @@ function procid_and_mode_index(arr₁::AbstractVector,arr₂::AbstractVector,
 end
 
 function procid_and_mode_index(iter,val::Tuple,num_procs::Integer)
+
     proc_id_mode = get_processor_id_from_split_array(iter,val,num_procs)
     modes_in_procid_file = split_across_processors(iter,num_procs,proc_id_mode)
     mode_index = get_index_in_split_array(modes_in_procid_file,val)
@@ -160,6 +161,7 @@ end
 
 function mode_index_in_file(arr₁::AbstractVector,arr₂::AbstractVector,
     (arr₁_value,arr₂_value)::Tuple,num_procs::Integer,proc_id_mode::Integer)
+
     modes_in_procid_file = split_product_across_processors(arr₁,arr₂,num_procs,proc_id_mode)
     mode_index = get_index_in_split_array(modes_in_procid_file,(arr₁_value,arr₂_value))
 end
@@ -229,7 +231,11 @@ get_nodes(procs_used::Vector{<:Integer}=workers()) = get_nodes(get_hostnames(pro
 
 function get_nprocs_node(hostnames::Vector{String})
     nodes = get_nodes(hostnames)
-    num_procs_node = Dict(node=>count(x->x==node,hostnames) for node in nodes)
+    get_nprocs_node(hostnames,nodes)
+end
+
+function get_nprocs_node(hostnames::Vector{String},nodes::Vector{String})
+    Dict(node=>count(isequal(node),hostnames) for node in nodes)
 end
 
 get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
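
Note: the refactor above only splits the per-node worker count into a two-argument method so that an already-computed node list can be reused. A quick standalone check of the counting expression; the hostnames are made up for illustration, and unique(hostnames) stands in for whatever get_nodes returns:

# Hypothetical hostnames; counts how many workers run on each node, as in get_nprocs_node above.
hostnames = ["node01", "node01", "node01", "node02", "node02"]
nodes = unique(hostnames)
nprocs_node = Dict(node => count(isequal(node), hostnames) for node in nodes)
@assert nprocs_node == Dict("node01" => 3, "node02" => 2)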
@@ -238,40 +244,46 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
 
     procs_used = workers_active(iterable)
 
-    futures = pmap_onebatch_per_worker(f,iterable,args...;kwargs...)
+    num_workers = length(procs_used);
+    hostnames = get_hostnames(procs_used);
+    nodes = get_nodes(hostnames);
+    pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
 
-    function final_sum(futures)
-        s = fetch(first(futures))
-        @sync for f in futures[2:end]
-            @async begin
-                s += fetch(f)
-            end
-        end
-        return s
-    end
-    @fetchfrom first(procs_used) final_sum(futures)
-end
+    nprocs_node = get_nprocs_node(procs_used)
+    node_channels = Dict(node=>RemoteChannel(()->Channel{Any}(nprocs_node[node]),pid_node)
+        for (node,pid_node) in zip(nodes,pid_rank0_on_node))
 
-function pmapsum_timed(f::Function,iterable,args...;kwargs...)
+    # Worker at which final reduction takes place
+    p_final = first(pid_rank0_on_node)
 
-    procs_used = workers_active(iterable)
+    sum_channel = RemoteChannel(()->Channel{Any}(length(pid_rank0_on_node)),p_final)
+    result = nothing
+
+    # Run the function on each processor and compute the sum at each node
+    @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
+        @async begin
+
+            iterable_on_proc = split_across_processors(iterable,num_workers,rank)
+            r = @spawnat p f(iterable_on_proc,args...;kwargs...)
+
+            node_remotechannel = node_channels[node]
+            np_node = nprocs_node[node]
+
+            @spawnat p put!(node_remotechannel,fetch(r))
 
-    futures = pmap_onebatch_per_worker(f,iterable,args...;kwargs...)
+            if p in pid_rank0_on_node
+                s = @spawnat p sum(take!(node_remotechannel) for i=1:np_node)
+                @spawnat p put!(sum_channel,fetch(s))
+            end
 
-    timer = TimerOutput()
-    function final_sum(futures,timer)
-        @timeit timer "fetch" s = fetch(first(futures))
-        @sync for f in futures[2:end]
-            @async begin
-                @timeit timer "fetch" s += fetch(f)
+            if p==p_final
+                result = @fetchfrom p_final sum(take!(sum_channel)
+                        for i=1:length(pid_rank0_on_node))
             end
         end
-        return s,timer
     end
 
-    s,timer = @fetchfrom first(procs_used) final_sum(futures,timer)
-    println(timer)
-    return s
+    return result
 end
 
 function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
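
Note: after this change, pmapsum evaluates f on each worker's chunk of the iterable, accumulates the partial results once per node through a RemoteChannel hosted on that node's rank-0 worker, and performs the final reduction on p_final. Buffering each node channel to nprocs_node[node] lets every worker put! its result without blocking before the node sum is taken. A hedged usage sketch follows; the call signature is inferred from the diff, and it assumes pmapsum is exported (otherwise qualify it as ParallelUtilities.pmapsum) and that the package is available on all workers:

using Distributed
addprocs(4)                       # assumption: 4 local workers for the example
@everywhere using ParallelUtilities

# Each worker receives its chunk of 1:1000 and returns a partial sum;
# pmapsum then reduces once per node and finally across nodes.
total = pmapsum(sum, 1:1000)
@assert total == sum(1:1000)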
@@ -293,10 +305,4 @@ function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
     return futures
 end
 
-function sum_at_node(futures::Vector{Future},hostnames)
-    myhost = hostnames[worker_rank()]
-    futures_on_myhost = futures[hostnames .== myhost]
-    sum(fetch(f) for f in futures_on_myhost)
-end
-
 end # module
