Skip to content

Commit 4718f05

Browse files
committed
test1
1 parent 4b17b8f commit 4718f05

File tree

1 file changed

+58
-15
lines changed

1 file changed

+58
-15
lines changed

src/ParallelUtilities.jl

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,11 @@ get_nodes(procs_used::Vector{<:Integer}=workers()) = get_nodes(get_hostnames(pro
157157

158158
"""
    get_nprocs_node(hostnames::Vector{String})

Derive the node list with `get_nodes(hostnames)` and delegate to the
two-argument method `get_nprocs_node(hostnames, nodes)`, returning its result.
"""
function get_nprocs_node(hostnames::Vector{String})
    nodes = get_nodes(hostnames)
    # The per-node counting lives in the two-argument method; the former
    # inline `num_procs_node` computation duplicated it and was never used.
    return get_nprocs_node(hostnames, nodes)
end
162+
163+
"""
    get_nprocs_node(hostnames::Vector{String}, nodes::Vector{String})

Return a `Dict` that maps every entry of `nodes` to the number of entries of
`hostnames` that are `isequal` to it.
"""
function get_nprocs_node(hostnames::Vector{String},nodes::Vector{String})
    procs_per_node = Dict{String,Int}()
    for name in nodes
        # One pass over `hostnames` per node name.
        procs_per_node[name] = count(isequal(name), hostnames)
    end
    return procs_per_node
end
162166

163167
"""
    get_nprocs_node(procs_used::Vector{<:Integer}=workers())

Look up the hostnames of `procs_used` via `get_hostnames` and forward them to
the hostname-based `get_nprocs_node` method. Defaults to `workers()`.
"""
function get_nprocs_node(procs_used::Vector{<:Integer}=workers())
    hostnames = get_hostnames(procs_used)
    return get_nprocs_node(hostnames)
end
@@ -168,21 +172,48 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
168172
num_workers = length(procs_used)
169173
hostnames = get_hostnames(procs_used)
170174
nodes = get_nodes(hostnames)
171-
pid_rank0_on_node = [procs_used[findfirst(x->x==node,hostnames)] for node in nodes]
172-
173-
futures = pmap_onebatch_per_worker(f,iterable,args...;kwargs...)
175+
np_nodes = get_nprocs_node(hostnames)
176+
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes]
174177

175-
# Intermediate sum over processors on the same node
176-
node_sum_futures = Vector{Future}(undef,length(pid_rank0_on_node))
177-
@sync for (ind,p) in enumerate(pid_rank0_on_node)
178-
@async node_sum_futures[ind] = @spawnat p sum_at_node(futures,hostnames)
178+
# Call `f(iterable, args...; kwargs...)` and `put!` the value it returns onto
# `channel`. The required keyword `channel` is consumed here and is not
# forwarded to `f`. The return value is whatever `put!` returns.
function apply_and_stash(f,iterable,args...;channel,kwargs...)
    put!(channel, f(iterable, args...; kwargs...))
end
180182

181183
# Worker at which final reduction takes place
182-
p = first(pid_rank0_on_node)
184+
p_final = first(pid_rank0_on_node)
185+
sum_channel = RemoteChannel(()->Channel{Any}(10),p_final)
186+
master_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
187+
for (node,p) in zip(nodes,pid_rank0_on_node))
188+
worker_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
189+
for (node,p) in zip(nodes,pid_rank0_on_node))
190+
191+
# Run the function on each processor
192+
@sync for (rank,p) in enumerate(procs_used)
193+
@async begin
194+
iterable_on_proc = split_across_processors(iterable,num_workers,rank)
195+
node = hostnames[rank]
196+
master_channel_node = master_channel_nodes[node]
197+
worker_channel_node = worker_channel_nodes[node]
198+
np_node = np_nodes[node]
199+
if p in pid_rank0_on_node
200+
@spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
201+
channel=master_channel_node)
202+
@spawnat p sum_channel(worker_channel_node,master_channel_node,np_node)
203+
else
204+
@spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
205+
channel=worker_channel_node)
206+
end
207+
end
208+
@async begin
209+
@spawnat p_final sum_channel(master_channel_nodes,sum_channel,length(nodes))
210+
end
211+
end
183212

184-
# Final sum across all nodes
185-
@fetchfrom p sum(fetch(f) for f in node_sum_futures)
213+
finalize.(values(worker_channel_nodes))
214+
finalize.(values(master_channel_nodes))
215+
216+
take!(sum_channel)
186217
end
187218

188219
function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothing,kwargs...)
@@ -203,10 +234,22 @@ function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothi
203234
return futures
204235
end
205236

206-
function sum_at_node(futures::Vector{Future},hostnames)
207-
myhost = hostnames[worker_rank()]
208-
futures_on_myhost = futures[hostnames .== myhost]
209-
sum(fetch(f) for f in futures_on_myhost)
237+
"""
    spawnf(f::Function, iterable)

Launch `f(iterable)` on every process listed by `workers()` with `@spawnat`
and return the resulting `Vector{Future}`, one entry per worker in the order
given by `workers()`. The enclosing `@sync` block is left only after the
spawned calls have been waited on.
"""
function spawnf(f::Function,iterable)
    pids = workers()
    futures = Vector{Future}(undef, length(pids))
    @sync for idx in eachindex(pids)
        futures[idx] = @spawnat pids[idx] f(iterable)
    end
    return futures
end
245+
246+
"""
    sum_channel(worker_channel, master_channel, np_node)

Fold `np_node - 1` values from `worker_channel` into `master_channel`.

Starts `np_node - 1` tasks. Each task takes one value from `worker_channel`,
then one value from `master_channel`, and puts their sum back onto
`master_channel`. The function returns `nothing` once every task has finished.
Nothing is taken or put when `np_node <= 1`.
"""
function sum_channel(worker_channel,master_channel,np_node)
    n_contributions = np_node - 1
    @sync for _ in 1:n_contributions
        @async begin
            # Take order matters: worker value first, then the running total.
            from_worker = take!(worker_channel)
            running_total = take!(master_channel)
            put!(master_channel, from_worker + running_total)
        end
    end
    return nothing
end
211254

212255
#############################################################################

0 commit comments

Comments
 (0)