Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/matrices/distance_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ end

# Again, we'll define the serial version first:
function _distancematrix(x::Array_or_SSSet, metric::Metric, ::Val{false})
d = zeros(eltype(x), length(x), length(x))
d = zeros(eltype(eltype(x)), length(x), length(x))
for j in 2:length(x)
for i in 1:j-1 # all else is zero
@inbounds d[i, j] = evaluate(metric, x[i], x[j])
Expand Down
24 changes: 22 additions & 2 deletions src/matrices/recurrence_matrix_low.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,16 @@ function recurrence_matrix(x::Vector_or_SSSet, y::Vector_or_SSSet, metric::Metri
# multiple threads pushing to the same `Array` (`Array`s are not atomic).
rowvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
colvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
# Channel to manage `Array`s to be used in each iteration
nbuffers = Threads.nthreads()
threadchannel = Channel{Int}(nbuffers) # for rows and columns
for i in 1:nbuffers
put!(threadchannel, i)
end

# This is the same logic as the serial function, but parallelized.
Threads.@threads for j in eachindex(y)
threadn = Threads.threadid()
threadn = take!(threadchannel)
nzcol = 0
for i in eachindex(x)
@inbounds if evaluate(metric, x[i], y[j]) ≤ ( (ε isa Real) ? ε : ε[j] )
Expand All @@ -79,7 +85,11 @@ function recurrence_matrix(x::Vector_or_SSSet, y::Vector_or_SSSet, metric::Metri
end
end
append!(colvals[threadn], fill(j, (nzcol,)))
put!(threadchannel, threadn)
end
close(threadchannel)

# merge into one array
finalrows = vcat(rowvals...) # merge into one array
finalcols = vcat(colvals...) # merge into one array
nzvals = fill(true, (length(finalrows),))
Expand All @@ -93,10 +103,16 @@ function recurrence_matrix(x::Vector_or_SSSet, metric::Metric, ε, ::Val{true})
# multiple threads pushing to the same `Array` (`Array`s are not atomic).
rowvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
colvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
# Channel to manage `Array`s to be used in each iteration
nbuffers = Threads.nthreads()
threadchannel = Channel{Int}(nbuffers) # for rows and columns
for i in 1:nbuffers
put!(threadchannel, i)
end

# This is the same logic as the serial function, but parallelized.
Threads.@threads for k in partition_indices(length(x))
threadn = Threads.threadid()
threadn = take!(threadchannel)
for j in k
nzcol = 0
for i in 1:j
Expand All @@ -107,7 +123,11 @@ function recurrence_matrix(x::Vector_or_SSSet, metric::Metric, ε, ::Val{true})
end
append!(colvals[threadn], fill(j, (nzcol,)))
end
put!(threadchannel, threadn)
end
close(threadchannel)

# merge into one array
finalrows = vcat(rowvals...) # merge into one array
finalcols = vcat(colvals...) # merge into one array
nzvals = fill(true, (length(finalrows),))
Expand Down
Loading