Skip to content

Commit ab9fcaf

Browse files
authored
Avoid the zeros(nthreads())[threadid()] buffering pattern (#293)
* Avoid the `zeros(nthreads())[threadid()]` buffering pattern * Fix partitioning
1 parent 3fea924 commit ab9fcaf

File tree

3 files changed

+38
-30
lines changed

3 files changed

+38
-30
lines changed

src/Parallel/centrality/betweenness.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,25 +98,26 @@ function threaded_betweenness_centrality(
9898
k = length(vs)
9999
isdir = is_directed(g)
100100

101-
local_betweenness = [zeros(n_v) for i in 1:nthreads()]
102101
vs_active = findall((x) -> degree(g, x) > 0, vs) # 0 might be 1?
102+
k_active = length(vs_active)
103+
d, r = divrem(k_active, Threads.nthreads())
104+
ntasks = d == 0 ? r : Threads.nthreads()
105+
local_betweenness = [zeros(n_v) for _ in 1:ntasks]
106+
task_size = cld(k_active, ntasks)
103107

104-
Base.Threads.@threads for s in vs_active
105-
state = Graphs.dijkstra_shortest_paths(
106-
g, s, distmx; allpaths=true, trackvertices=true
107-
)
108-
if endpoints
109-
Graphs._accumulate_endpoints!(
110-
local_betweenness[Base.Threads.threadid()], state, g, s
111-
)
112-
else
113-
Graphs._accumulate_basic!(
114-
local_betweenness[Base.Threads.threadid()], state, g, s
108+
@sync for (t, task_range) in enumerate(Iterators.partition(1:k_active, task_size))
109+
Threads.@spawn for s in @view(vs_active[task_range])
110+
state = Graphs.dijkstra_shortest_paths(
111+
g, s, distmx; allpaths=true, trackvertices=true
115112
)
113+
if endpoints
114+
Graphs._accumulate_endpoints!(local_betweenness[t], state, g, s)
115+
else
116+
Graphs._accumulate_basic!(local_betweenness[t], state, g, s)
117+
end
116118
end
117119
end
118120
betweenness = reduce(+, local_betweenness)
119-
120121
Graphs._rescale!(betweenness, n_v, normalize, isdir, k)
121122

122123
return betweenness

src/Parallel/centrality/stress.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,17 @@ function threaded_stress_centrality(g::AbstractGraph, vs=vertices(g))::Vector{In
4444
isdir = is_directed(g)
4545

4646
# Parallel reduction
47-
local_stress = [zeros(Int, n_v) for _ in 1:nthreads()]
47+
d, r = divrem(k, Threads.nthreads())
48+
ntasks = d == 0 ? r : Threads.nthreads()
49+
local_stress = [zeros(Int, n_v) for _ in 1:ntasks]
50+
task_size = cld(k, ntasks)
4851

49-
Base.Threads.@threads for s in vs
50-
if degree(g, s) > 0 # this might be 1?
51-
state = Graphs.dijkstra_shortest_paths(g, s; allpaths=true, trackvertices=true)
52-
Graphs._stress_accumulate_basic!(
53-
local_stress[Base.Threads.threadid()], state, g, s
54-
)
52+
@sync for (t, task_range) in enumerate(Iterators.partition(1:k, task_size))
53+
Threads.@spawn for s in @view(vs[task_range])
54+
if degree(g, s) > 0 # this might be 1?
55+
state = Graphs.dijkstra_shortest_paths(g, s; allpaths=true, trackvertices=true)
56+
Graphs._stress_accumulate_basic!(local_stress[t], state, g, s)
57+
end
5558
end
5659
end
5760
return reduce(+, local_stress)

src/Parallel/utils.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,24 @@ Multi-threaded implementation of [`generate_reduce`](@ref).
4040
function threaded_generate_reduce(
4141
g::AbstractGraph{T}, gen_func::Function, comp::Comp, reps::Integer
4242
) where {T<:Integer,Comp}
43-
n_t = Base.Threads.nthreads()
44-
is_undef = ones(Bool, n_t)
45-
min_set = [Vector{T}() for _ in 1:n_t]
46-
Base.Threads.@threads for _ in 1:reps
47-
t = Base.Threads.threadid()
48-
next_set = gen_func(g)
49-
if is_undef[t] || comp(next_set, min_set[t])
50-
min_set[t] = next_set
51-
is_undef[t] = false
43+
d, r = divrem(reps, Threads.nthreads())
44+
ntasks = d == 0 ? r : Threads.nthreads()
45+
min_set = [Vector{T}() for _ in 1:ntasks]
46+
is_undef = ones(Bool, ntasks)
47+
task_size = cld(reps, ntasks)
48+
49+
@sync for (t, task_range) in enumerate(Iterators.partition(1:reps, task_size))
50+
Threads.@spawn for _ in task_range
51+
next_set = gen_func(g)
52+
if is_undef[t] || comp(next_set, min_set[t])
53+
min_set[t] = next_set
54+
is_undef[t] = false
55+
end
5256
end
5357
end
5458

5559
min_ind = 0
56-
for i in filter((j) -> !is_undef[j], 1:n_t)
60+
for i in filter((j) -> !is_undef[j], 1:ntasks)
5761
if min_ind == 0 || comp(min_set[i], min_set[min_ind])
5862
min_ind = i
5963
end

0 commit comments

Comments
 (0)