Skip to content

Commit 318f0aa

Browse files
heliosdrmDatseris
andauthored
Fix erroneous use of Threads.nthreads and Threads.threadid (#171)
* try Channel storage with threaded `recurrence_matrix` * add missing line to put! data back into Channel. * restrict versions of StateSpaceSets * fix `_distancematrix` according to StateSpaceSets v2.4 * change channels to keep vector indices instead of vectors * change vcat+splat by reduce+vcat * Update Project.toml --------- Co-authored-by: George Datseris <[email protected]>
1 parent 576b119 commit 318f0aa

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecurrenceAnalysis"
22
uuid = "639c3291-70d9-5ea2-8c5b-839eba1ee399"
33
repo = "https://github.com/JuliaDynamics/RecurrenceAnalysis.jl.git"
4-
version = "2.1.1"
4+
version = "2.1.2"
55

66
[deps]
77
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"

src/matrices/distance_matrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ end
9797

9898
# Again, we'll define the serial version first:
9999
function _distancematrix(x::Array_or_SSSet, metric::Metric, ::Val{false})
100-
d = zeros(eltype(x), length(x), length(x))
100+
d = zeros(eltype(eltype(x)), length(x), length(x))
101101
for j in 2:length(x)
102102
for i in 1:j-1 # all else is zero
103103
@inbounds d[i, j] = evaluate(metric, x[i], x[j])

src/matrices/recurrence_matrix_low.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,16 @@ function recurrence_matrix(x::Vector_or_SSSet, y::Vector_or_SSSet, metric::Metri
6767
# multiple threads pushing to the same `Array` (`Array`s are not atomic).
6868
rowvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
6969
colvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
70+
# Channel to manage `Array`s to be used in each iteration
71+
nbuffers = Threads.nthreads()
72+
threadchannel = Channel{Int}(nbuffers) # for rows and columns
73+
for i in 1:nbuffers
74+
put!(threadchannel, i)
75+
end
7076

7177
# This is the same logic as the serial function, but parallelized.
7278
Threads.@threads for j in eachindex(y)
73-
threadn = Threads.threadid()
79+
threadn = take!(threadchannel)
7480
nzcol = 0
7581
for i in eachindex(x)
7682
@inbounds if evaluate(metric, x[i], y[j]) ( (ε isa Real) ? ε : ε[j] )
@@ -79,9 +85,12 @@ function recurrence_matrix(x::Vector_or_SSSet, y::Vector_or_SSSet, metric::Metri
7985
end
8086
end
8187
append!(colvals[threadn], fill(j, (nzcol,)))
88+
put!(threadchannel, threadn)
8289
end
83-
finalrows = vcat(rowvals...) # merge into one array
84-
finalcols = vcat(colvals...) # merge into one array
90+
close(threadchannel)
91+
92+
finalrows = reduce(vcat, rowvals) # merge into one array
93+
finalcols = reduce(vcat, colvals) # merge into one array
8594
nzvals = fill(true, (length(finalrows),))
8695
return sparse(finalrows, finalcols, nzvals, length(x), length(y))
8796
end
@@ -93,10 +102,16 @@ function recurrence_matrix(x::Vector_or_SSSet, metric::Metric, ε, ::Val{true})
93102
# multiple threads pushing to the same `Array` (`Array`s are not atomic).
94103
rowvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
95104
colvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
105+
# Channel to manage `Array`s to be used in each iteration
106+
nbuffers = Threads.nthreads()
107+
threadchannel = Channel{Int}(nbuffers) # for rows and columns
108+
for i in 1:nbuffers
109+
put!(threadchannel, i)
110+
end
96111

97112
# This is the same logic as the serial function, but parallelized.
98113
Threads.@threads for k in partition_indices(length(x))
99-
threadn = Threads.threadid()
114+
threadn = take!(threadchannel)
100115
for j in k
101116
nzcol = 0
102117
for i in 1:j
@@ -107,9 +122,12 @@ function recurrence_matrix(x::Vector_or_SSSet, metric::Metric, ε, ::Val{true})
107122
end
108123
append!(colvals[threadn], fill(j, (nzcol,)))
109124
end
125+
put!(threadchannel, threadn)
110126
end
111-
finalrows = vcat(rowvals...) # merge into one array
112-
finalcols = vcat(colvals...) # merge into one array
127+
close(threadchannel)
128+
129+
finalrows = reduce(vcat, rowvals) # merge into one array
130+
finalcols = reduce(vcat, colvals) # merge into one array
113131
nzvals = fill(true, (length(finalrows),))
114132
return Symmetric(sparse(finalrows, finalcols, nzvals, length(x), length(x)), :U)
115133
end

0 commit comments

Comments
 (0)