242
242
243
243
get_nprocs_node (procs_used:: Vector{<:Integer} = workers ()) = get_nprocs_node (get_hostnames (procs_used))
244
244
245
- function pmapsum (f:: Function ,iterable,args... ;kwargs... )
245
+ function pmapsum (:: Type{T} , f:: Function ,iterable,args... ;kwargs... ) where {T}
246
246
247
247
procs_used = workers_active (iterable)
248
248
@@ -252,13 +252,13 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
252
252
pid_rank0_on_node = [procs_used[findfirst (isequal (node),hostnames)] for node in nodes];
253
253
254
254
nprocs_node = get_nprocs_node (procs_used)
255
- node_channels = Dict (node=> RemoteChannel (()-> Channel {Any } (nprocs_node[node]),pid_node)
255
+ node_channels = Dict (node=> RemoteChannel (()-> Channel {T } (nprocs_node[node]),pid_node)
256
256
for (node,pid_node) in zip (nodes,pid_rank0_on_node))
257
257
258
258
# Worker at which final reduction takes place
259
259
p_final = first (pid_rank0_on_node)
260
260
261
- sum_channel = RemoteChannel (()-> Channel {Any } (length (pid_rank0_on_node)),p_final)
261
+ sum_channel = RemoteChannel (()-> Channel {T } (length (pid_rank0_on_node)),p_final)
262
262
result = nothing
263
263
264
264
# Run the function on each processor and compute the sum at each node
@@ -288,6 +288,11 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
288
288
return result
289
289
end
290
290
291
+ function pmapsum (f:: Function ,iterable,args... ;kwargs... )
292
+ T = promote_type (Base. return_types (f)... )
293
+ pmapsum (T,f,iterable,args... ;kwargs... )
294
+ end
295
+
291
296
function pmap_onebatch_per_worker (f:: Function ,iterable,args... ;kwargs... )
292
297
293
298
procs_used = workers_active (iterable)
0 commit comments