Skip to content

Commit 11cdd2a

Browse files
committed
separate functions
1 parent 20930de commit 11cdd2a

File tree

1 file changed

+43
-89
lines changed

1 file changed

+43
-89
lines changed

src/ParallelUtilities.jl

Lines changed: 43 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -242,114 +242,68 @@ end
242242

243243
get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
244244

245-
# function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
245+
function pmapsum_remotechannel(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
246246

247-
# procs_used = workers_active(iterable)
248-
249-
# num_workers = length(procs_used);
250-
# hostnames = get_hostnames(procs_used);
251-
# nodes = get_nodes(hostnames);
252-
# pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
253-
254-
# nprocs_node = get_nprocs_node(procs_used)
255-
# node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
256-
# for (node,pid_node) in zip(nodes,pid_rank0_on_node))
257-
258-
# # Worker at which final reduction takes place
259-
# p_final = first(pid_rank0_on_node)
260-
261-
# sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
262-
# result = nothing
263-
264-
# # Run the function on each processor and compute the sum at each node
265-
# @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
266-
# @async begin
267-
268-
# node_remotechannel = node_channels[node]
269-
# np_node = nprocs_node[node]
270-
271-
# iterable_on_proc = split_across_processors(iterable,num_workers,rank)
272-
# @spawnat p put!(node_remotechannel,
273-
# f(iterable_on_proc,args...;kwargs...))
274-
275-
# @async if p in pid_rank0_on_node
276-
# f = @spawnat p put!(sum_channel,
277-
# sum(take!(node_remotechannel) for i=1:np_node))
278-
# wait(f)
279-
# @spawnat p finalize(node_remotechannel)
280-
# end
281-
282-
# @async if p==p_final
283-
# result = @fetchfrom p_final sum(take!(sum_channel)
284-
# for i=1:length(pid_rank0_on_node))
285-
# @spawnat p finalize(sum_channel)
286-
# end
287-
# end
288-
# end
289-
290-
# return result :: T
291-
# end
292-
293-
# function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T<:AbstractArray}
294-
295-
# # Use ArrayChannels
296-
297-
# procs_used = workers_active(iterable)
247+
procs_used = workers_active(iterable)
298248

299-
# num_workers = length(procs_used);
300-
# hostnames = get_hostnames(procs_used);
301-
# nodes = get_nodes(hostnames);
302-
# pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
249+
num_workers = length(procs_used);
250+
hostnames = get_hostnames(procs_used);
251+
nodes = get_nodes(hostnames);
252+
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
303253

304-
# nprocs_node = get_nprocs_node(procs_used)
305-
# node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
306-
# for (node,pid_node) in zip(nodes,pid_rank0_on_node))
254+
nprocs_node = get_nprocs_node(procs_used)
255+
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
256+
for (node,pid_node) in zip(nodes,pid_rank0_on_node))
307257

308-
# # Worker at which final reduction takes place
309-
# p_final = first(pid_rank0_on_node)
258+
# Worker at which final reduction takes place
259+
p_final = first(pid_rank0_on_node)
310260

311-
# sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
312-
# result = nothing
261+
sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
262+
result = nothing
313263

314-
# # Run the function on each processor and compute the sum at each node
315-
# @sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
316-
# @async begin
264+
# Run the function on each processor and compute the sum at each node
265+
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
266+
@async begin
317267

318-
# iterable_on_proc = split_across_processors(iterable,num_workers,rank)
319-
# r = @spawnat p f(iterable_on_proc,args...;kwargs...)
320-
321-
# node_remotechannel = node_channels[node]
322-
# np_node = nprocs_node[node]
268+
node_remotechannel = node_channels[node]
269+
np_node = nprocs_node[node]
323270

324-
# @spawnat p put!(node_remotechannel,fetch(r))
325-
326-
# if p in pid_rank0_on_node
327-
# s = @spawnat p sum(take!(node_remotechannel) for i=1:np_node)
328-
# @spawnat p put!(sum_channel,fetch(s))
329-
# end
271+
iterable_on_proc = split_across_processors(iterable,num_workers,rank)
272+
@spawnat p put!(node_remotechannel,
273+
f(iterable_on_proc,args...;kwargs...))
274+
275+
@async if p in pid_rank0_on_node
276+
f = @spawnat p put!(sum_channel,
277+
sum(take!(node_remotechannel) for i=1:np_node))
278+
wait(f)
279+
@spawnat p finalize(node_remotechannel)
280+
end
330281

331-
# if p==p_final
332-
# result = @fetchfrom p_final sum(take!(sum_channel)
333-
# for i=1:length(pid_rank0_on_node))
334-
# end
335-
# end
336-
# end
282+
@async if p==p_final
283+
result = @fetchfrom p_final sum(take!(sum_channel)
284+
for i=1:length(pid_rank0_on_node))
285+
@spawnat p finalize(sum_channel)
286+
end
287+
end
288+
end
337289

338-
# return result :: T
339-
# end
290+
return result :: T
291+
end
340292

341-
# function pmapsum(f::Function,iterable,args...;kwargs...)
342-
# pmapsum(Any,f,iterable,args...;kwargs...)
343-
# end
293+
function pmapsum_remotechannel(f::Function,iterable,args...;kwargs...)
294+
pmapsum(Any,f,iterable,args...;kwargs...)
295+
end
344296

345-
function pmapsum(f::Function,iterable,args...;kwargs...)
297+
function pmapsum_distributedfor(f::Function,iterable,args...;kwargs...)
346298
@distributed (+) for i in 1:nworkers()
347299
np = nworkers_active(iterable)
348300
iter_proc = split_across_processors(iterable,np,i)
349301
f(iter_proc,args...;kwargs...)
350302
end
351303
end
352304

305+
pmapsum(f,args...;kwargs...) = pmapsum_remotechannel(f,args...;kwargs...)
306+
353307
function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
354308

355309
procs_used = workers_active(iterable)

0 commit comments

Comments
 (0)