module ParallelUtilities

using Reexport
@reexport using Distributed

# Rank of the current worker, counting from 1 (assumes contiguous worker ids)
worker_rank() = myid() - minimum(workers()) + 1

# Split a contiguous range of num_tasks tasks across num_procs processes and
# return the range of task indices assigned to the process with rank proc_id
function split_across_processors(num_tasks::Integer, num_procs = nworkers(), proc_id = worker_rank())
    if num_procs == 1
        return 1:num_tasks
    end

    num_tasks_per_process, num_tasks_leftover = div(num_tasks, num_procs), mod(num_tasks, num_procs)

    num_tasks_on_proc = num_tasks_per_process + (proc_id <= num_tasks_leftover ? 1 : 0)
    task_start = num_tasks_per_process*(proc_id - 1) + min(num_tasks_leftover + 1, proc_id)

    return task_start:(task_start + num_tasks_on_proc - 1)
end

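# Illustrative example: 10 tasks split across 4 processes are assigned the ranges
# 1:3, 4:6, 7:8 and 9:10 for ranks 1 through 4, so for instance
#   split_across_processors(10, 4, 2) == 4:6
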
# Split a finite iterable across num_procs processes and return an iterator
# over the elements assigned to the process with rank proc_id
function split_across_processors(arr₁, num_procs = nworkers(), proc_id = worker_rank())

    @assert(proc_id <= num_procs, "processor rank has to be less than or equal to the number of workers engaged")
    if num_procs == 1
        return arr₁
    end

    num_tasks = length(arr₁)

    num_tasks_per_process, num_tasks_leftover = div(num_tasks, num_procs), mod(num_tasks, num_procs)

    num_tasks_on_proc = num_tasks_per_process + (proc_id <= num_tasks_leftover ? 1 : 0)
    task_start = num_tasks_per_process*(proc_id - 1) + min(num_tasks_leftover + 1, proc_id)

    return Iterators.take(Iterators.drop(arr₁, task_start - 1), num_tasks_on_proc)
end

# Split the Cartesian product of two or more iterables across processors
function split_product_across_processors(arr₁, arr₂, num_procs::Integer = nworkers(), proc_id::Integer = worker_rank())

    # arr₁ varies faster than arr₂ in the product
    return split_across_processors(Iterators.product(arr₁, arr₂), num_procs, proc_id)
end

function split_product_across_processors(arrs_tuple, num_procs::Integer = nworkers(), proc_id::Integer = worker_rank())
    return split_across_processors(Iterators.product(arrs_tuple...), num_procs, proc_id)
end

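# Illustrative example: the product of 1:3 and 1:2 has 6 elements with the first
# iterator varying fastest; split across 2 processes, rank 1 receives the first half:
#   collect(split_product_across_processors(1:3, 1:2, 2, 1)) == [(1,1), (2,1), (3,1)]
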
# Return the rank of the processor that the mode (arr₁_value, arr₂_value) is
# assigned to when the product of arr₁ and arr₂ is split across num_procs processors.
# Assumes that arr₁ and arr₂ are sorted.
function get_processor_id_from_split_array(arr₁, arr₂, (arr₁_value, arr₂_value)::Tuple, num_procs)

    if (arr₁_value ∉ arr₁) || (arr₂_value ∉ arr₂)
        return nothing # invalid mode
    end

    num_tasks = length(arr₁)*length(arr₂)

    a1_match_index = searchsortedfirst(arr₁, arr₁_value)
    a2_match_index = searchsortedfirst(arr₂, arr₂_value)

    num_tasks_per_process, num_tasks_leftover = div(num_tasks, num_procs), mod(num_tasks, num_procs)

    proc_id = 1
    num_tasks_on_proc = num_tasks_per_process + (proc_id <= num_tasks_leftover ? 1 : 0)
    total_tasks_till_proc_id = num_tasks_on_proc

    task_no = 0

    # Traverse the product in the same order as the split (arr₁ varies fastest),
    # advancing the processor rank as each processor's block of tasks is exhausted
    for (ind2, a2) in enumerate(arr₂), (ind1, a1) in enumerate(arr₁)

        task_no += 1
        if task_no > total_tasks_till_proc_id
            proc_id += 1
            num_tasks_on_proc = num_tasks_per_process + (proc_id <= num_tasks_leftover ? 1 : 0)
            total_tasks_till_proc_id += num_tasks_on_proc
        end

        if ind2 < a2_match_index
            continue
        end

        if (ind2 == a2_match_index) && (ind1 == a1_match_index)
            break
        end
    end

    return proc_id
end

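# Illustrative example: continuing the split of the product of 1:3 and 1:2 across
# 2 processes shown above, the mode (3,1) lies in the first block, so
#   get_processor_id_from_split_array(1:3, 1:2, (3,1), 2) == 1
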
function get_processor_range_from_split_array(arr₁, arr₂, modes_on_proc, num_procs)

    if isempty(modes_on_proc)
        return 0:-1 # empty range
    end

    tasks_arr = collect(modes_on_proc)
    proc_id_start = get_processor_id_from_split_array(arr₁, arr₂, first(tasks_arr), num_procs)
    proc_id_end = get_processor_id_from_split_array(arr₁, arr₂, last(tasks_arr), num_procs)
    return proc_id_start:proc_id_end
end

function get_index_in_split_array(modes_on_proc, (arr₁_value, arr₂_value))
    if isnothing(modes_on_proc)
        return nothing
    end
    for (ind, (t1, t2)) in enumerate(modes_on_proc)
        if (t1 == arr₁_value) && (t2 == arr₂_value)
            return ind
        end
    end
    nothing
end

function procid_and_mode_index(arr₁, arr₂, (arr₁_value, arr₂_value), num_procs)
    proc_id_mode = get_processor_id_from_split_array(arr₁, arr₂, (arr₁_value, arr₂_value), num_procs)
    modes_in_procid_file = split_product_across_processors(arr₁, arr₂, num_procs, proc_id_mode)
    mode_index = get_index_in_split_array(modes_in_procid_file, (arr₁_value, arr₂_value))
    return proc_id_mode, mode_index
end

function mode_index_in_file(arr₁, arr₂, (arr₁_value, arr₂_value), num_procs, proc_id_mode)
    modes_in_procid_file = split_product_across_processors(arr₁, arr₂, num_procs, proc_id_mode)
    return get_index_in_split_array(modes_in_procid_file, (arr₁_value, arr₂_value))
end

function procid_allmodes(arr₁, arr₂, iter, num_procs = nworkers_active(arr₁, arr₂))
    procid = zeros(Int64, length(iter))
    for (ind, mode) in enumerate(iter)
        procid[ind] = get_processor_id_from_split_array(arr₁, arr₂, mode, num_procs)
    end
    return procid
end

# Workers that will actually be assigned tasks (at most one worker per task)
workers_active(arr) = workers()[1:min(length(arr), nworkers())]

workers_active(arr₁, arr₂) = workers_active(Iterators.product(arr₁, arr₂))

nworkers_active(args...) = length(workers_active(args...))

function minmax_from_split_array(iterable)
    arr₁_min, arr₂_min = first(iterable)
    arr₁_max, arr₂_max = arr₁_min, arr₂_min
    for (arr₁_value, arr₂_value) in iterable
        arr₁_min = min(arr₁_min, arr₁_value)
        arr₁_max = max(arr₁_max, arr₁_value)
        arr₂_min = min(arr₂_min, arr₂_value)
        arr₂_max = max(arr₂_max, arr₂_value)
    end
    return (arr₁_min = arr₁_min, arr₁_max = arr₁_max, arr₂_min = arr₂_min, arr₂_max = arr₂_max)
end

# Hostnames of the given worker processes, fetched asynchronously
function get_hostnames(procs_used = workers())
    hostnames = Vector{String}(undef, length(procs_used))
    @sync for (ind, p) in enumerate(procs_used)
        @async hostnames[ind] = @fetchfrom p Libc.gethostname()
    end
    return hostnames
end

get_nodes(hostnames::Vector{String}) = unique(hostnames)
get_nodes(procs_used::Vector{<:Integer} = workers()) = get_nodes(get_hostnames(procs_used))

# Number of worker processes on each node, as a Dict of hostname => count
function get_nprocs_node(hostnames::Vector{String})
    nodes = get_nodes(hostnames)
    return Dict(node => count(x -> x == node, hostnames) for node in nodes)
end

get_nprocs_node(procs_used::Vector{<:Integer} = workers()) = get_nprocs_node(get_hostnames(procs_used))

# Evaluate f on disjoint subsets of iterable on each active worker, and sum up
# the results, reducing first over workers on the same node and then across nodes
function pmapsum(f::Function, iterable, args...; kwargs...)

    procs_used = workers_active(iterable)
    hostnames = get_hostnames(procs_used)
    nodes = get_nodes(hostnames)
    # The first worker on each node carries out the intermediate reduction
    pid_rank0_on_node = [procs_used[findfirst(x -> x == node, hostnames)] for node in nodes]

    futures = pmap_onebatch_per_worker(f, iterable, args...; kwargs...)

    # Intermediate sum over processors on the same node
    node_sum_futures = Vector{Future}(undef, length(pid_rank0_on_node))
    @sync for (ind, p) in enumerate(pid_rank0_on_node)
        @async node_sum_futures[ind] = @spawnat p sum_at_node(futures, hostnames)
    end

    # Worker at which the final reduction takes place
    p = first(pid_rank0_on_node)

    # Final sum across all nodes
    S = @fetchfrom p sum(fetch(nf) for nf in node_sum_futures)

    return S
end

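# Example (hypothetical usage, assuming workers have been added with addprocs and
# this module has been loaded on all of them): summing i*j over all pairs (i, j),
#   pmapsum(modes -> sum(m[1]*m[2] for m in modes), Iterators.product(1:10, 1:10))
# returns 3025, with each worker summing only the modes assigned to it
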
# Evaluate f once on each active worker, passing it that worker's share of iterable,
# and return the resulting Futures without fetching them
function pmap_onebatch_per_worker(f::Function, iterable, args...; num_workers = nothing, kwargs...)

    procs_used = workers_active(iterable)
    if !isnothing(num_workers) && num_workers <= length(procs_used)
        procs_used = procs_used[1:num_workers]
    end
    num_workers = length(procs_used)

    futures = Vector{Future}(undef, num_workers)
    @sync for (rank, p) in enumerate(procs_used)
        @async begin
            iterable_on_proc = split_across_processors(iterable, num_workers, rank)
            futures[rank] = @spawnat p f(iterable_on_proc, args...; kwargs...)
        end
    end
    return futures
end

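# Example (hypothetical usage): launch one batch per worker and collect the results,
#   futures = pmap_onebatch_per_worker(collect, 1:100)
#   results = fetch.(futures)    # one Vector per worker, covering 1:100 between them
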
# Sum the results of the futures that were computed on the same node as this worker
function sum_at_node(futures::Vector{Future}, hostnames)
    myhost = hostnames[worker_rank()]
    futures_on_myhost = futures[hostnames .== myhost]
    return sum(fetch(f) for f in futures_on_myhost)
end

#############################################################################

export split_across_processors, split_product_across_processors,
get_processor_id_from_split_array,
procid_allmodes, mode_index_in_file,
get_processor_range_from_split_array, workers_active, worker_rank,
get_index_in_split_array, procid_and_mode_index, minmax_from_split_array,
pmapsum, sum_at_node, pmap_onebatch_per_worker,
get_nodes, get_hostnames, get_nprocs_node

end # module