Skip to content

Commit ec4e77e

Browse files
committed
distributed for in pmapsum
1 parent 4cee9bd commit ec4e77e

File tree

1 file changed

+80
-72
lines changed

1 file changed

+80
-72
lines changed

src/ParallelUtilities.jl

Lines changed: 80 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -242,104 +242,112 @@ end
242242

243243
get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
244244

245-
function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
245+
# function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
246246

247-
procs_used = workers_active(iterable)
247+
# procs_used = workers_active(iterable)
248248

249-
num_workers = length(procs_used);
250-
hostnames = get_hostnames(procs_used);
251-
nodes = get_nodes(hostnames);
252-
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
249+
# num_workers = length(procs_used);
250+
# hostnames = get_hostnames(procs_used);
251+
# nodes = get_nodes(hostnames);
252+
# pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
253253

254-
nprocs_node = get_nprocs_node(procs_used)
255-
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
256-
for (node,pid_node) in zip(nodes,pid_rank0_on_node))
254+
# nprocs_node = get_nprocs_node(procs_used)
255+
# node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
256+
# for (node,pid_node) in zip(nodes,pid_rank0_on_node))
257257

258-
# Worker at which final reduction takes place
259-
p_final = first(pid_rank0_on_node)
258+
# # Worker at which final reduction takes place
259+
# p_final = first(pid_rank0_on_node)
260260

261-
sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
262-
result = nothing
261+
# sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
262+
# result = nothing
263263

264-
# Run the function on each processor and compute the sum at each node
265-
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
266-
@async begin
264+
# # Run the function on each processor and compute the sum at each node
265+
# @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
266+
# @async begin
267267

268-
node_remotechannel = node_channels[node]
269-
np_node = nprocs_node[node]
268+
# node_remotechannel = node_channels[node]
269+
# np_node = nprocs_node[node]
270270

271-
iterable_on_proc = split_across_processors(iterable,num_workers,rank)
272-
@spawnat p put!(node_remotechannel,
273-
f(iterable_on_proc,args...;kwargs...))
274-
275-
@async if p in pid_rank0_on_node
276-
f = @spawnat p put!(sum_channel,
277-
sum(take!(node_remotechannel) for i=1:np_node))
278-
wait(f)
279-
@spawnat p finalize(node_remotechannel)
280-
end
271+
# iterable_on_proc = split_across_processors(iterable,num_workers,rank)
272+
# @spawnat p put!(node_remotechannel,
273+
# f(iterable_on_proc,args...;kwargs...))
281274

282-
@async if p==p_final
283-
result = @fetchfrom p_final sum(take!(sum_channel)
284-
for i=1:length(pid_rank0_on_node))
285-
@spawnat p finalize(sum_channel)
286-
end
287-
end
288-
end
275+
# @async if p in pid_rank0_on_node
276+
# f = @spawnat p put!(sum_channel,
277+
# sum(take!(node_remotechannel) for i=1:np_node))
278+
# wait(f)
279+
# @spawnat p finalize(node_remotechannel)
280+
# end
289281

290-
return result :: T
291-
end
282+
# @async if p==p_final
283+
# result = @fetchfrom p_final sum(take!(sum_channel)
284+
# for i=1:length(pid_rank0_on_node))
285+
# @spawnat p finalize(sum_channel)
286+
# end
287+
# end
288+
# end
292289

293-
function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T<:AbstractArray}
290+
# return result :: T
291+
# end
294292

295-
# Use ArrayChannels
293+
# function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T<:AbstractArray}
296294

297-
procs_used = workers_active(iterable)
295+
# # Use ArrayChannels
298296

299-
num_workers = length(procs_used);
300-
hostnames = get_hostnames(procs_used);
301-
nodes = get_nodes(hostnames);
302-
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
297+
# procs_used = workers_active(iterable)
303298

304-
nprocs_node = get_nprocs_node(procs_used)
305-
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
306-
for (node,pid_node) in zip(nodes,pid_rank0_on_node))
299+
# num_workers = length(procs_used);
300+
# hostnames = get_hostnames(procs_used);
301+
# nodes = get_nodes(hostnames);
302+
# pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
307303

308-
# Worker at which final reduction takes place
309-
p_final = first(pid_rank0_on_node)
304+
# nprocs_node = get_nprocs_node(procs_used)
305+
# node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
306+
# for (node,pid_node) in zip(nodes,pid_rank0_on_node))
310307

311-
sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
312-
result = nothing
308+
# # Worker at which final reduction takes place
309+
# p_final = first(pid_rank0_on_node)
313310

314-
# Run the function on each processor and compute the sum at each node
315-
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
316-
@async begin
311+
# sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
312+
# result = nothing
313+
314+
# # Run the function on each processor and compute the sum at each node
315+
# @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
316+
# @async begin
317317

318-
iterable_on_proc = split_across_processors(iterable,num_workers,rank)
319-
r = @spawnat p f(iterable_on_proc,args...;kwargs...)
318+
# iterable_on_proc = split_across_processors(iterable,num_workers,rank)
319+
# r = @spawnat p f(iterable_on_proc,args...;kwargs...)
320320

321-
node_remotechannel = node_channels[node]
322-
np_node = nprocs_node[node]
321+
# node_remotechannel = node_channels[node]
322+
# np_node = nprocs_node[node]
323323

324-
@spawnat p put!(node_remotechannel,fetch(r))
324+
# @spawnat p put!(node_remotechannel,fetch(r))
325325

326-
if p in pid_rank0_on_node
327-
s = @spawnat p sum(take!(node_remotechannel) for i=1:np_node)
328-
@spawnat p put!(sum_channel,fetch(s))
329-
end
326+
# if p in pid_rank0_on_node
327+
# s = @spawnat p sum(take!(node_remotechannel) for i=1:np_node)
328+
# @spawnat p put!(sum_channel,fetch(s))
329+
# end
330330

331-
if p==p_final
332-
result = @fetchfrom p_final sum(take!(sum_channel)
333-
for i=1:length(pid_rank0_on_node))
334-
end
335-
end
336-
end
331+
# if p==p_final
332+
# result = @fetchfrom p_final sum(take!(sum_channel)
333+
# for i=1:length(pid_rank0_on_node))
334+
# end
335+
# end
336+
# end
337337

338-
return result :: T
339-
end
338+
# return result :: T
339+
# end
340+
341+
# function pmapsum(f::Function,iterable,args...;kwargs...)
342+
# pmapsum(Any,f,iterable,args...;kwargs...)
343+
# end
340344

341345
function pmapsum(f::Function,iterable,args...;kwargs...)
342-
pmapsum(Any,f,iterable,args...;kwargs...)
346+
@distributed (+) for i in 1:nworkers()
347+
np = workers_active(iterable)
348+
iter_proc = split_across_processors(iterable,np,i)
349+
f(iter_proc,args...;kwargs...)
350+
end
343351
end
344352

345353
function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)

0 commit comments

Comments
 (0)