Skip to content

Commit f308d03

Browse files
committed
remotechannel type
1 parent bd680fc commit f308d03

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/ParallelUtilities.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ end
242242

243243
get_nprocs_node(procs_used::Vector{<:Integer}=workers()) = get_nprocs_node(get_hostnames(procs_used))
244244

245-
function pmapsum(f::Function,iterable,args...;kwargs...)
245+
function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
246246

247247
procs_used = workers_active(iterable)
248248

@@ -252,13 +252,13 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
252252
pid_rank0_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
253253

254254
nprocs_node = get_nprocs_node(procs_used)
255-
node_channels = Dict(node=>RemoteChannel(()->Channel{Any}(nprocs_node[node]),pid_node)
255+
node_channels = Dict(node=>RemoteChannel(()->Channel{T}(nprocs_node[node]),pid_node)
256256
for (node,pid_node) in zip(nodes,pid_rank0_on_node))
257257

258258
# Worker at which final reduction takes place
259259
p_final = first(pid_rank0_on_node)
260260

261-
sum_channel = RemoteChannel(()->Channel{Any}(length(pid_rank0_on_node)),p_final)
261+
sum_channel = RemoteChannel(()->Channel{T}(length(pid_rank0_on_node)),p_final)
262262
result = nothing
263263

264264
# Run the function on each processor and compute the sum at each node
@@ -288,6 +288,11 @@ function pmapsum(f::Function,iterable,args...;kwargs...)
288288
return result
289289
end
290290

291+
function pmapsum(f::Function,iterable,args...;kwargs...)
292+
T = promote_type(Base.return_types(f)...)
293+
pmapsum(T,f,iterable,args...;kwargs...)
294+
end
295+
291296
function pmap_onebatch_per_worker(f::Function,iterable,args...;kwargs...)
292297

293298
procs_used = workers_active(iterable)

0 commit comments

Comments
 (0)