@@ -157,7 +157,11 @@ get_nodes(procs_used::Vector{<:Integer}=workers()) = get_nodes(get_hostnames(pro

function get_nprocs_node(hostnames::Vector{String})
    nodes = get_nodes(hostnames)
-   num_procs_node = Dict(node=>count(x->x==node,hostnames) for node in nodes)
+   get_nprocs_node(hostnames,nodes)
+end
+
+function get_nprocs_node(hostnames::Vector{String},nodes::Vector{String})
+   Dict(node=>count(isequal(node),hostnames) for node in nodes)
end

get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
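
For illustration only (not part of the commit): the refactored two-argument method simply counts how many workers sit on each node. With hypothetical hostnames, writing out explicitly what get_nodes(hostnames) is expected to return:

    # Hypothetical inputs, for illustration
    hostnames = ["node1", "node1", "node2"]
    nodes = ["node1", "node2"]                 # expected result of get_nodes(hostnames)
    Dict(node => count(isequal(node), hostnames) for node in nodes)
    # Dict("node1" => 2, "node2" => 1)
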
@@ -168,21 +172,48 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
    num_workers = length(procs_used)
    hostnames = get_hostnames(procs_used)
    nodes = get_nodes(hostnames)
-   pid_rank0_on_node = [procs_used[findfirst(x->x==node,hostnames)] for node in nodes]
-
-   futures = pmap_onebatch_per_worker(f,iterable,args...;kwargs...)
+   np_nodes = get_nprocs_node(hostnames)
+   pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes]

-   # Intermediate sum over processors on the same node
-   node_sum_futures = Vector{Future}(undef,length(pid_rank0_on_node))
-   @sync for (ind,p) in enumerate(pid_rank0_on_node)
-       @async node_sum_futures[ind] = @spawnat p sum_at_node(futures,hostnames)
+   function apply_and_stash(f,iterable,args...;channel,kwargs...)
+       result = f(iterable,args...;kwargs...)
+       put!(channel,result)
    end

    # Worker at which final reduction takes place
-   p = first(pid_rank0_on_node)
+   p_final = first(pid_rank0_on_node)
+   sum_channel = RemoteChannel(()->Channel{Any}(10),p_final)
+   master_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
+                           for (node,p) in zip(nodes,pid_rank0_on_node))
+   worker_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
+                           for (node,p) in zip(nodes,pid_rank0_on_node))
+
+   # Run the function on each processor
+   @sync for (rank,p) in enumerate(procs_used)
+       @async begin
+           iterable_on_proc = split_across_processors(iterable,num_workers,rank)
+           node = hostnames[rank]
+           master_channel_node = master_channel_nodes[node]
+           worker_channel_node = worker_channel_nodes[node]
+           np_node = np_nodes[node]
+           if p in pid_rank0_on_node
+               @spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
+                           channel=master_channel_node)
+               @spawnat p sum_channel(worker_channel_node,master_channel_node,np_node)
+           else
+               @spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
+                           channel=worker_channel_node)
+           end
+       end
+       @async begin
+           @spawnat p_final sum_channel(master_channel_nodes,sum_channel,length(nodes))
+       end
+   end

-   # Final sum across all nodes
-   @fetchfrom p sum(fetch(f) for f in node_sum_futures)
+   finalize.(values(worker_channel_nodes))
+   finalize.(values(master_channel_nodes))
+
+   take!(sum_channel)
end

function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothing,kwargs...)
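
An intended-usage sketch of the rewritten pmapsum (illustrative, not part of the commit): each worker applies f to its slice of the iterable, the per-node channels above accumulate the partial sums, and the grand total is taken from sum_channel on p_final. The worker function f below is hypothetical, and the module defining pmapsum is assumed to be loaded on every worker.

    using Distributed
    addprocs(2)                                     # hypothetical worker count
    @everywhere f(chunk) = sum(x^2 for x in chunk)  # each worker sums squares over its chunk
    total = pmapsum(f, 1:100)                       # intended to equal sum(x^2 for x in 1:100) == 338350
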
@@ -203,10 +234,22 @@ function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothi
    return futures
end

-function sum_at_node(futures::Vector{Future},hostnames)
-   myhost = hostnames[worker_rank()]
-   futures_on_myhost = futures[hostnames .== myhost]
-   sum(fetch(f) for f in futures_on_myhost)
+function spawnf(f::Function,iterable)
+
+   futures = Vector{Future}(undef,nworkers())
+   @sync for (rank,p) in enumerate(workers())
+       futures[rank] = @spawnat p f(iterable)
+   end
+   return futures
+end
+
+function sum_channel(worker_channel,master_channel,np_node)
+   @sync for i in 1:np_node-1
+       @async begin
+           s = take!(worker_channel) + take!(master_channel)
+           put!(master_channel,s)
+       end
+   end
end

# ############################################################################
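
A self-contained sketch (not from the commit) of the reduction pattern that sum_channel implements: np_node-1 partial results arrive on the worker channel and are folded, one at a time, into the running total held on the master channel. The real function wraps each step in @async inside @sync so the take! calls can block until results arrive; here the channels are pre-filled with hypothetical values.

    worker_channel = Channel{Any}(10)
    master_channel = Channel{Any}(10)
    put!(master_channel, 1)                             # result from the node's rank-0 process
    foreach(x -> put!(worker_channel, x), (2, 3, 4))    # results from the other processes on the node
    np_node = 4
    for i in 1:np_node-1
        s = take!(worker_channel) + take!(master_channel)
        put!(master_channel, s)
    end
    take!(master_channel)                               # 10 == 1 + 2 + 3 + 4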