@@ -172,48 +172,27 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
     num_workers = length(procs_used)
     hostnames = get_hostnames(procs_used)
     nodes = get_nodes(hostnames)
-    np_nodes = get_nprocs_node(hostnames)
     pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes]
 
-    function apply_and_stash(f,iterable,args...;channel,kwargs...)
-        result = f(iterable,args...;kwargs...)
-        put!(channel,result)
-    end
-
     # Worker at which final reduction takes place
     p_final = first(pid_rank0_on_node)
-    sum_channel = RemoteChannel(()->Channel{Any}(10),p_final)
-    master_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
-                    for (node,p) in zip(nodes,pid_rank0_on_node))
-    worker_channel_nodes = Dict(node=>RemoteChannel(()->Channel{Any}(10),p)
-                    for (node,p) in zip(nodes,pid_rank0_on_node))
 
-    # Run the function on each processor
+    sum_channel = RemoteChannel(()->Channel{Any}(100),p_final)
+    result = 0
+
+    # Run the function on each processor and compute the sum at each node
     @sync for (rank,p) in enumerate(procs_used)
         @async begin
             iterable_on_proc = split_across_processors(iterable,num_workers,rank)
-            node = hostnames[rank]
-            master_channel_node = master_channel_nodes[node]
-            worker_channel_node = worker_channel_nodes[node]
-            np_node = np_nodes[node]
-            if p in pid_rank0_on_node
-                @spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
-                            channel=master_channel_node)
-                @spawnat p sum_channel(worker_channel_node,master_channel_node,np_node)
-            else
-                @spawnat p apply_and_stash(f,iterable_on_proc,args...;kwargs...,
-                            channel=worker_channel_node)
+            r = @spawnat p f(iterable_on_proc,args...;kwargs...)
+            @spawnat p put!(sum_channel,fetch(r))
+            if p==p_final
+                result = @fetchfrom p_final sum(take!(sum_channel) for i=1:num_workers)
             end
         end
-        @async begin
-            @spawnat p_final sum_channel(master_channel_nodes,sum_channel,length(nodes))
-        end
     end
 
-    finalize.(values(worker_channel_nodes))
-    finalize.(values(master_channel_nodes))
-
-    take!(sum_channel)
+    return result
 end
 
 function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothing,kwargs...)
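The rewritten pmapsum drops the per-node two-stage reduction and instead collects every worker's partial result through a single RemoteChannel owned by p_final. Below is a minimal standalone sketch of that pattern (not part of this commit); the names channel_sum and partial_sum are illustrative, and it drains the channel after the @sync block instead of on the final loop iteration as the committed code does.

using Distributed
addprocs(4)    # illustrative worker count

# hypothetical per-chunk function whose return values will be summed
@everywhere partial_sum(r) = sum(x^2 for x in r)

function channel_sum(chunks)
    pids = workers()[1:length(chunks)]    # assumes at least one worker per chunk
    p_final = first(pids)
    # buffered channel owned by p_final, one slot per partial result
    ch = RemoteChannel(() -> Channel{Any}(length(chunks)), p_final)
    @sync for (rank, p) in enumerate(pids)
        @async begin
            r = @spawnat p partial_sum(chunks[rank])   # compute the partial sum remotely
            @spawnat p put!(ch, fetch(r))              # stash it in the shared channel
        end
    end
    # once @sync returns, every partial result is in the channel; add them up on p_final
    @fetchfrom p_final sum(take!(ch) for _ in 1:length(chunks))
end

channel_sum([1:250, 251:500, 501:750, 751:1000])  # == sum(x^2 for x in 1:1000)

The committed version folds the final drain into the @sync loop (on the p_final iteration), so the reduction runs on the same process that owns the channel without an extra round of data transfer.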
@@ -234,24 +213,6 @@ function pmap_onebatch_per_worker(f::Function,iterable,args...;num_workers=nothi
     return futures
 end
 
-function spawnf(f::Function,iterable)
-
-    futures = Vector{Future}(undef,nworkers())
-    @sync for (rank,p) in enumerate(workers())
-        futures[rank] = @spawnat p f(iterable)
-    end
-    return futures
-end
-
-function sum_channel(worker_channel,master_channel,np_node)
-    @sync for i in 1:np_node-1
-        @async begin
-            s = take!(worker_channel) + take!(master_channel)
-            put!(master_channel,s)
-        end
-    end
-end
-
 #############################################################################
 
 export split_across_processors,split_product_across_processors,
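For context, a hypothetical usage sketch of pmapsum as it stands after this commit: f is called once per worker on that worker's chunk of the iterable, and pmapsum adds up the returned values. This assumes workers have been added and the module defining pmapsum is loaded on all processes.

using Distributed
addprocs(2)
# each worker computes the sum of squares over its chunk; pmapsum sums the partials
@everywhere f(chunk) = sum(x^2 for x in chunk)
total = pmapsum(f, 1:1000)   # expected to equal sum(x^2 for x in 1:1000)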