diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 6aa4796..c0eb071 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -7,7 +7,7 @@ using StaticArrays import Base.show export NNTree, BruteTree, KDTree, BallTree, DataFreeTree -export knn, nn, inrange # TODOs? , allpairs, distmat, npairs +export knn, nn, inrange, knn_threaded # TODOs? , allpairs, distmat, npairs export injectdata export Euclidean, diff --git a/src/knn.jl b/src/knn.jl index 6fb4bb1..1adfee9 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -18,14 +18,48 @@ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F= check_input(tree, points) check_k(tree, k) n_points = length(points) - dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] - idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] + idxs = [Vector{Int}(undef, k) for _ in 1:n_points] for i in 1:n_points knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip) end return idxs, dists end + +""" + knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances + nn(tree::NNTree, points) -> indices, distances + + Performs a lookup of the `k` nearest neighbours to the `points` from the data + in the `tree`. If `sortres = true` the result is sorted such that the results are + in the order of increasing distance to the point. `skip` is an optional predicate + to determine if a point that would be returned should be skipped based on its + index. + + The keyword argument `n_tasks` determines how batches will be made from the inputs. The + batches are distributed on the available threads, determined by `Threads.nthreads()`. + See `https://docs.julialang.org/en/v1/manual/multi-threading` for help on how to make + Julia aware of available threads. + + Multithreading can significantly slow down other processes on your computer. 
+ To avoid multithreading, set `n_tasks=1` +""" +function knn_threaded(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false; n_tasks::Int = Threads.nthreads()) where {V, T <: AbstractVector, F<:Function} + check_input(tree, points) + check_k(tree, k) + n_points = length(points) + dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] + idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + idxs_batched = _batched_inds(points, n_tasks) + Threads.@threads for inds in idxs_batched + for i in inds + knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip) + end + end + return idxs, dists +end + function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} fill!(idx, -1) fill!(dist, typemax(get_T(eltype(V)))) @@ -58,6 +92,17 @@ function knn(tree::NNTree{V}, point::AbstractMatrix{T}, k::Int, sortres=false, s knn(tree, new_data, k, sortres, skip) end +function knn_threaded(tree::NNTree{V}, point::AbstractMatrix{T}, k::Int, sortres=false, skip::F=always_false; n_tasks::Int = Threads.nthreads()) where {V, T <: Number, F<:Function} + dim = size(point, 1) + npoints = size(point, 2) + if isbitstype(T) + new_data = copy_svec(T, point, Val(dim)) + else + new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints] + end + knn_threaded(tree, new_data, k, sortres, skip; n_tasks) +end + nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) .|> first nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: AbstractVector, F <: Function} = _nn(tree, points, skip) |> _firsteach nn(tree::NNTree{V}, points::AbstractMatrix{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) |> _firsteach diff --git a/src/utilities.jl b/src/utilities.jl index e99affa..2c50d1d 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -95,3 +95,23 @@ end # 
Instead of ReinterpretArray wrapper, copy an array, interpreting it as a vector of SVectors copy_svec(::Type{T}, data, ::Val{dim}) where {T, dim} = [SVector{dim,T}(ntuple(i -> data[n+i], Val(dim))) for n in 0:dim:(length(data)-1)] + + +""" + _batched_inds(v::AbstractVector, n_batches::Int) + +Compute `n_batches` batches of indices into the input vector `v`. +Batch sizes differ by at most one when `length(v)` is not evenly divisible by `n_batches`. +Returns a vector of index ranges, one range per batch. +""" +function _batched_inds(v::AbstractVector, n_batches::Int) + @assert length(v) ≥ n_batches "Trying to make $n_batches batches from $(length(v)) elements. This would result in empty arrays of type `Any`, which is likely to cause problems." + divs, rems = divrem(length(v), n_batches) + batchlengths = fill(divs, n_batches) + batchlengths[end-rems+1:end] .+= 1 + + cumsums = pushfirst!(cumsum(batchlengths), 0) + indices = [cumsums[i]+1:cumsums[i+1] for i in 1:n_batches] + + return indices +end