70 changes: 62 additions & 8 deletions README.md
@@ -8,19 +8,20 @@

## Creating a Tree

There are currently three types of trees available:
There are currently four types of trees available:

* `KDTree`: Recursively splits points into groups using hyper-planes.
* `BallTree`: Recursively splits points into groups bounded by hyper-spheres.
* `BruteTree`: Not actually a tree. It linearly searches all points in a brute force manner.
* `KDTree`: Recursively splits points into groups using hyper-planes. Best for low-dimensional data with axis-aligned metrics.
* `BallTree`: Recursively splits points into groups bounded by hyper-spheres. Suitable for high-dimensional data and arbitrary metrics.
* `BruteTree`: Not actually a tree. It linearly searches all points in a brute force manner. Useful as a baseline or for small datasets.
* `PeriodicTree`: Wraps one of the trees above to handle periodic boundary conditions. Essential for simulations with periodic domains.

These trees can be created using the following syntax:

```julia
KDTree(data, metric; leafsize, reorder)
BallTree(data, metric; leafsize, reorder)
BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for BruteTree

PeriodicTree(tree, bounds_min, bounds_max)
```

* `data`: The points to build the tree from, either as
@@ -29,6 +30,8 @@ BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for
* `metric`: The `Metric` (from `Distances.jl`) to use, defaults to `Euclidean`. `KDTree` works with axis-aligned metrics: `Euclidean`, `Chebyshev`, `Minkowski`, and `Cityblock` while for `BallTree` and `BruteTree` other pre-defined `Metric`s can be used as well as custom metrics (that are subtypes of `Metric`).
* `leafsize`: Determines the number of points (default 25) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points.
* `reorder`: If `true` (default), during tree construction this rearranges points to improve cache locality during querying. This will create a copy of the original data.
* `tree`: An existing tree (`KDTree`, `BallTree`, or `BruteTree`) built from your data.
* `bounds_min`, `bounds_max`: Vectors defining the periodic domain boundaries. Use `Inf` in `bounds_max` for non-periodic dimensions.

All trees in `NearestNeighbors.jl` are static, meaning points cannot be added or removed after creation.
Note that this package is not suitable for very high dimensional points due to high compilation time and inefficient queries on the trees.
@@ -42,15 +45,16 @@ data = rand(3, 10^4)
kdtree = KDTree(data; leafsize = 25)
balltree = BallTree(data, Minkowski(3.5); reorder = false)
brutetree = BruteTree(data)
periodictree = PeriodicTree(kdtree, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
```

## k-Nearest Neighbor (kNN) Searches

A kNN search finds the `k` nearest neighbors to a given point or points. This is done with the methods:

```julia
knn(tree, point[s], k [, skip=always_false]) -> idxs, dists
knn!(idxs, dists, tree, point, k [, skip=always_false])
knn(tree, point[s], k [, skip=Returns(false)]) -> idxs, dists
knn!(idxs, dists, tree, point, k [, skip=Returns(false)])
```

* `tree`: The tree instance.
@@ -62,7 +66,7 @@ knn!(idxs, dists, tree, point, k [, skip=always_false])
For the single closest neighbor, you can use `nn`:

```julia
nn(tree, point[s] [, skip=always_false]) -> idx, dist
nn(tree, point[s] [, skip=Returns(false)]) -> idx, dist
```
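
The `skip` argument is a predicate on point indices: a point is ignored when `skip(i)` returns `true`. The new default, `Returns(false)`, is the Base callable (available since Julia 1.7) that ignores its arguments and always returns `false`, so nothing is skipped. Below is a minimal sketch of how such predicates behave, independent of any tree. The names `never_skip` and `skip_even` are made up for illustration:

```julia
# Default predicate: never skip any point.
never_skip = Returns(false)

# Hypothetical custom predicate: skip points stored at even indices.
skip_even(i) = iseven(i)

# Indices a search would still consider among the first six points.
kept_default = [i for i in 1:6 if !never_skip(i)]
kept_odd = [i for i in 1:6 if !skip_even(i)]
```

A predicate such as `skip_even` would be passed in the `skip` position of the signatures shown above.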

Examples:
@@ -169,6 +173,56 @@ inrange!(idxs, balltree, point, r)
neighborscount = inrangecount(balltree, point, r)
```

## Periodic Boundary Conditions

The `PeriodicTree` provides nearest neighbor searches with periodic boundary conditions.

### Creating a PeriodicTree

A `PeriodicTree` wraps an existing tree (`KDTree`, `BallTree`, or `BruteTree`) and handles periodic boundary conditions:

```julia
PeriodicTree(tree, bounds_min, bounds_max)
```

* `tree`: An existing tree built from your data
* `bounds_min`: Vector of minimum bounds for each dimension
* `bounds_max`: Vector of maximum bounds for each dimension (use `Inf` for non-periodic dimensions)
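
To make "periodic boundary conditions" concrete: along a periodic dimension of length `bounds_max - bounds_min`, two coordinates separated by `d` are also separated by `period - d` going the other way around, and the smaller of the two is the wrapped distance. The sketch below computes a wrapped Euclidean distance directly. It only illustrates the concept and is not the implementation used by `PeriodicTree`; the function name `wrapped_distance` is made up for this example, and it assumes both points lie inside the bounds.

```julia
# Wrapped (minimum-image) Euclidean distance between two points.
# Dimensions with an infinite upper bound are treated as non-periodic.
function wrapped_distance(a, b, bounds_min, bounds_max)
    total = 0.0
    for i in eachindex(a)
        d = abs(a[i] - b[i])
        period = bounds_max[i] - bounds_min[i]
        if isfinite(period)
            d = min(d, period - d)  # going "around" may be shorter
        end
        total += d^2
    end
    return sqrt(total)
end

# x is periodic on [0, 10], y is not periodic.
lo, hi = [0.0, 0.0], [10.0, Inf]
d_near = wrapped_distance([0.5, 3.0], [1.0, 2.0], lo, hi)  # no wrapping needed
d_far = wrapped_distance([0.5, 3.0], [9.0, 8.0], lo, hi)   # x wraps: 1.5 instead of 8.5
```

A brute-force helper like this is mainly useful for checking the results of a periodic query on small inputs.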

### Examples

**Basic periodic boundaries:**
```julia
using NearestNeighbors, StaticArrays

# Create data in a 2D periodic domain
data = [SVector(0.1, 0.2), SVector(0.8, 0.9), SVector(0.5, 0.5)]
kdtree = KDTree(data)

# Create periodic tree with bounds [0,1] × [0,1]
ptree = PeriodicTree(kdtree, [0.0, 0.0], [1.0, 1.0])

# Query near the boundary: neighbors are also found through periodic wrapping
query_point = [0.05, 0.15] # Near data[1] = [0.1, 0.2]

idxs, dists = knn(ptree, query_point, 2)
# Finds data[1] directly and data[2] = [0.8, 0.9] through wrapping
# (wrapped offset of 0.25 in each dimension), ahead of data[3]
```

**Mixed periodic/non-periodic dimensions:**
```julia
# 2D domain: x-periodic, y-infinite
data = [SVector(1.0, 2.0), SVector(9.0, 8.0)]
kdtree = KDTree(data)
ptree = PeriodicTree(kdtree, [0.0, 0.0], [10.0, Inf])

# Query near the lower x-boundary
query = [0.5, 3.0]
idxs, dists = knn(ptree, query, 1)
# Finds data[1], which is closest without wrapping (x-distance 0.5, y-distance 1.0).
# For data[2], the x-distance wraps to 1.5 instead of 8.5; y does not wrap.
```

## Using On-Disk Data Sets

By default, trees store a copy of the `data` provided during construction. For data sets larger than available memory, `DataFreeTree` can be used to strip a tree of its data field and re-link it later.
22 changes: 22 additions & 0 deletions benchmark/benchmarks.jl
@@ -20,16 +20,38 @@ for n_points in (EXTENSIVE_BENCHMARK ? (10^3, 10^5) : 10^5)
(BallTree, "ball tree"))
tree = tree_type(data; leafsize = leafsize, reorder = reorder)
SUITE["build tree"]["$(tree_type) $dim × $n_points, ls = $leafsize"] = @benchmarkable $(tree_type)($data; leafsize = $leafsize, reorder = $reorder)

# Add periodic tree benchmarks
bounds_min = zeros(dim)
bounds_max = ones(dim) # Unit cube [0,1]^dim
ptree = PeriodicTree(tree, bounds_min, bounds_max)
SUITE["build tree"]["Periodic$(tree_type) $dim × $n_points, ls = $leafsize"] = @benchmarkable PeriodicTree($tree, $bounds_min, $bounds_max)

# Add mixed periodic/non-periodic benchmarks (only for multi-dimensional cases)
if dim > 1
bounds_max_mixed = [ones(dim-1); Inf] # Last dimension non-periodic
ptree_mixed = PeriodicTree(tree, bounds_min, bounds_max_mixed)
SUITE["build tree"]["PeriodicMixed$(tree_type) $dim × $n_points, ls = $leafsize"] = @benchmarkable PeriodicTree($tree, $bounds_min, $bounds_max_mixed)
end

for input_size in (1, 1000)
input_data = rand(StableRNG(123), dim, input_size)
for k in (EXTENSIVE_BENCHMARK ? (1, 10) : 10)
SUITE["knn"]["$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($tree, $input_data, $k)
SUITE["knn"]["Periodic$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($ptree, $input_data, $k)
if dim > 1
SUITE["knn"]["PeriodicMixed$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($ptree_mixed, $input_data, $k)
end
end
perc = 0.01
V = π^(dim / 2) / gamma(dim / 2 + 1) * (1 / 2)^dim
r = (V * perc * gamma(dim / 2 + 1))^(1/dim)
r_formatted = @sprintf("%3.2e", r)
SUITE["inrange"]["$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($tree, $input_data, $r)
SUITE["inrange"]["Periodic$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($ptree, $input_data, $r)
if dim > 1
SUITE["inrange"]["PeriodicMixed$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($ptree_mixed, $input_data, $r)
end
end
end
end
18 changes: 10 additions & 8 deletions src/NearestNeighbors.jl
@@ -4,9 +4,8 @@ using Distances
import Distances: PreMetric, Metric, UnionMinkowskiMetric, result_type, eval_reduce, eval_end, eval_op, eval_start, evaluate, parameters

using StaticArrays
import Base.show

export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
export NNTree, BruteTree, KDTree, BallTree, DataFreeTree, PeriodicTree
export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs
export injectdata

@@ -48,18 +47,21 @@ end
get_T(::Type{T}) where {T <: AbstractFloat} = T
get_T(::T) where {T} = Float64

include("evaluation.jl")
include("tree_data.jl")
include("datafreetree.jl")
include("knn.jl")
include("inrange.jl")
get_tree(tree::NNTree) = tree

include("hyperspheres.jl")
include("hyperrectangles.jl")
include("evaluation.jl")
include("utilities.jl")
include("tree_data.jl")
include("tree_ops.jl")
include("brute_tree.jl")
include("kd_tree.jl")
include("ball_tree.jl")
include("tree_ops.jl")
include("periodic_tree.jl")
include("datafreetree.jl")
include("knn.jl")
include("inrange.jl")

for dim in (2, 3)
tree = KDTree(rand(dim, 10))
34 changes: 17 additions & 17 deletions src/ball_tree.jl
@@ -147,7 +147,7 @@ function _knn(tree::BallTree,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {F}
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip)
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, false)
return
end

@@ -157,9 +157,9 @@ function knn_kernel!(tree::BallTree{V},
point::AbstractArray,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {V, F}
skip::F, unique::Bool) where {V, F}
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip)
add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, unique)
return
end

@@ -171,14 +171,14 @@ function knn_kernel!(tree::BallTree{V},

if left_dist <= best_dists[1] || right_dist <= best_dists[1]
if left_dist < right_dist
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique)
if right_dist <= best_dists[1]
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, unique)
end
else
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, unique)
if left_dist <= best_dists[1]
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip)
knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique)
end
end
end
@@ -188,16 +188,19 @@ end
function _inrange(tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V}
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::F) where {V, F}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
return inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip, false) # Call the recursive range finder
end

function inrange_kernel!(tree::BallTree,
index::Int,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::F,
unique::Bool) where {F}

if index > length(tree.hyper_spheres)
return 0
@@ -215,19 +218,16 @@ function inrange_kernel!(tree::BallTree,
# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r
return add_points_inrange!(idx_in_ball, tree, index, point, r)
return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, unique)
end

count = 0

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses_fast(dist, tree.metric, sphere, query_ball)
count += addall(tree, index, idx_in_ball)
return addall(tree, index, idx_in_ball, skip, unique)
else
# Recursively call the left and right sub tree.
count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
return inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip, unique) +
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip, unique)
end
return count
end
21 changes: 16 additions & 5 deletions src/brute_tree.jl
@@ -45,20 +45,25 @@ function _knn(tree::BruteTree{V},
best_dists::AbstractVector,
skip::F) where {V, F}

knn_kernel!(tree, point, best_idxs, best_dists, skip)
knn_kernel!(tree, point, best_idxs, best_dists, skip, false)
return
end

function knn_kernel!(tree::BruteTree{V},
point::AbstractVector,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {V, F}
skip::F,
unique::Bool) where {V, F}
for i in 1:length(tree.data)
if skip(i)
continue
end

#if unique && i in best_idxs
# continue
#end

dist_d = evaluate(tree.metric, tree.data[i], point)
if dist_d <= best_dists[1]
best_dists[1] = dist_d
@@ -71,17 +76,23 @@ end
function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
return inrange_kernel!(tree, point, radius, idx_in_ball)
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::F,) where {F}
return inrange_kernel!(tree, point, radius, idx_in_ball, skip, false)
end


function inrange_kernel!(tree::BruteTree,
point::AbstractVector,
r::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
idx_in_ball::Union{Nothing, Vector{<:Integer}},
skip::F, # parametric like the other kernels, so any callable predicate is accepted
unique::Bool) where {F}
count = 0
for i in 1:length(tree.data)
if skip(i)
continue
end
d = evaluate(tree.metric, tree.data[i], point)
if d <= r
count += 1