Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions src/matrices/recurrence_matrix_low.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,74 +62,74 @@ end
# Core function
function recurrence_matrix(x::Vector_or_SSSet, y::Vector_or_SSSet, metric::Metric, ε, ::Val{true})
@assert ε isa Real || length(ε) == length(y)
# We create a Channel for `Array`s of `Array`s, for each thread to have its
# We create an `Array` of `Array`s, for each thread to have its
# own array to push to. This avoids race conditions with
# multiple threads pushing to the same `Array` (`Array`s are not atomic).
rowvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
colvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
# Channel to manage `Array`s to be used in each iteration
nbuffers = Threads.nthreads()
threadchannel = Channel{NTuple{2, Vector{Int}}}(nbuffers) # for rows and columns
foreach(1:nbuffers) do _
put!(threadchannel, (Int[], Int[]))
threadchannel = Channel{Int}(nbuffers) # for rows and columns
for i in 1:nbuffers
put!(threadchannel, i)
end

# This is the same logic as the serial function, but parallelized.
Threads.@threads for j in eachindex(y)
rowvals, colvals = take!(threadchannel)
threadn = take!(threadchannel)
nzcol = 0
for i in eachindex(x)
@inbounds if evaluate(metric, x[i], y[j]) ≤ ( (ε isa Real) ? ε : ε[j] )
push!(rowvals, i) # push to the thread-specific row array
push!(rowvals[threadn], i) # push to the thread-specific row array
nzcol += 1
end
end
append!(colvals, fill(j, (nzcol,)))
put!(threadchannel, (rowvals, colvals))
append!(colvals[threadn], fill(j, (nzcol,)))
put!(threadchannel, threadn)
end
close(threadchannel)

# merge into one array
finalrows = Int[]
finalcols = Int[]
foreach(1:nbuffers) do _
rowvals, colvals = take!(threadchannel)
append!(finalrows, rowvals)
append!(finalcols, colvals)
end
finalrows = vcat(rowvals...) # merge into one array
finalcols = vcat(colvals...) # merge into one array
nzvals = fill(true, (length(finalrows),))
return sparse(finalrows, finalcols, nzvals, length(x), length(y))
end

function recurrence_matrix(x::Vector_or_SSSet, metric::Metric, ε, ::Val{true})
@assert ε isa Real || length(ε) == length(x)
# We create a Channel for `Array`s of `Array`s, for each thread to have its
# We create an `Array` of `Array`s, for each thread to have its
# own array to push to. This avoids race conditions with
# multiple threads pushing to the same `Array` (`Array`s are not atomic).
rowvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
colvals = [Vector{Int}() for _ in 1:Threads.nthreads()]
# Channel to manage `Array`s to be used in each iteration
nbuffers = Threads.nthreads()
threadchannel = Channel{NTuple{2, Vector{Int}}}(nbuffers) # for rows and columns
foreach(1:nbuffers) do _
put!(threadchannel, (Int[], Int[]))
threadchannel = Channel{Int}(nbuffers) # for rows and columns
for i in 1:nbuffers
put!(threadchannel, i)
end

# This is the same logic as the serial function, but parallelized.
Threads.@threads for k in partition_indices(length(x))
rowvals, colvals = take!(threadchannel)
threadn = take!(threadchannel)
for j in k
nzcol = 0
for i in 1:j
@inbounds if evaluate(metric, x[i], x[j]) ≤ ( (ε isa Real) ? ε : ε[j] )
push!(rowvals, i) # push to the thread-specific row array
push!(rowvals[threadn], i) # push to the thread-specific row array
nzcol += 1
end
end
append!(colvals, fill(j, (nzcol,)))
append!(colvals[threadn], fill(j, (nzcol,)))
end
put!(threadchannel, (rowvals, colvals))
put!(threadchannel, threadn)
end
close(threadchannel)

# merge into one array
finalrows = Int[]
finalcols = Int[]
foreach(1:nbuffers) do _
rowvals, colvals = take!(threadchannel)
append!(finalrows, rowvals)
append!(finalcols, colvals)
end
finalrows = vcat(rowvals...) # merge into one array
finalcols = vcat(colvals...) # merge into one array
nzvals = fill(true, (length(finalrows),))
return Symmetric(sparse(finalrows, finalcols, nzvals, length(x), length(x)), :U)
end
Loading