From 583b6e401e6c8b7db14c89c6ab06881b0b4e33e2 Mon Sep 17 00:00:00 2001 From: KristofferC Date: Thu, 20 Jun 2024 17:27:07 +0200 Subject: [PATCH] Implement comprehensive periodic tree improvements and optimizations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix periodic tree implementation with proper boundary validation and mixed dimension support - Add mirror box pruning optimization for KNN searches to skip irrelevant periodic images - Implement comprehensive tests covering mixed periodic/non-periodic dimensions and boundary cases - Add extensive periodic tree benchmarks to benchmark suite for performance tracking - Add detailed PeriodicTree documentation with examples and usage patterns to README The periodic tree now properly handles: - Mixed periodic/non-periodic dimensions (using Inf for non-periodic bounds) - Boundary validation ensuring all data points are within the periodic box - Performance optimization by pruning irrelevant mirror boxes during searches - Comprehensive error handling for invalid box configurations 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- README.md | 70 +++++++- benchmark/benchmarks.jl | 22 +++ src/NearestNeighbors.jl | 18 +- src/ball_tree.jl | 34 ++-- src/brute_tree.jl | 21 ++- src/inrange.jl | 32 ++-- src/kd_tree.jl | 34 ++-- src/knn.jl | 34 ++-- src/periodic_tree.jl | 240 +++++++++++++++++++++++++ src/tree_ops.jl | 47 +++-- src/utilities.jl | 10 +- test/runtests.jl | 1 + test/test_knn.jl | 2 +- test/test_monkey.jl | 3 +- test/test_periodic.jl | 377 ++++++++++++++++++++++++++++++++++++++++ 15 files changed, 839 insertions(+), 106 deletions(-) create mode 100644 src/periodic_tree.jl create mode 100644 test/test_periodic.jl diff --git a/README.md b/README.md index 54081da..c54c062 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,12 @@ ## Creating a Tree -There are currently three types of trees available: +There are currently four types of trees 
available: -* `KDTree`: Recursively splits points into groups using hyper-planes. -* `BallTree`: Recursively splits points into groups bounded by hyper-spheres. -* `BruteTree`: Not actually a tree. It linearly searches all points in a brute force manner. +* `KDTree`: Recursively splits points into groups using hyper-planes. Best for low-dimensional data with axis-aligned metrics. +* `BallTree`: Recursively splits points into groups bounded by hyper-spheres. Suitable for high-dimensional data and arbitrary metrics. +* `BruteTree`: Not actually a tree. It linearly searches all points in a brute force manner. Useful as a baseline or for small datasets. +* `PeriodicTree`: Wraps one of the trees above to handle periodic boundary conditions. Essential for simulations with periodic domains. These trees can be created using the following syntax: @@ -20,7 +21,7 @@ These trees can be created using the following syntax: KDTree(data, metric; leafsize, reorder) BallTree(data, metric; leafsize, reorder) BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for BruteTree - +PeriodicTree(tree, bounds_min, bounds_max) ``` * `data`: The points to build the tree from, either as @@ -29,6 +30,8 @@ BruteTree(data, metric; leafsize, reorder) # leafsize and reorder are unused for * `metric`: The `Metric` (from `Distances.jl`) to use, defaults to `Euclidean`. `KDTree` works with axis-aligned metrics: `Euclidean`, `Chebyshev`, `Minkowski`, and `Cityblock` while for `BallTree` and `BruteTree` other pre-defined `Metric`s can be used as well as custom metrics (that are subtypes of `Metric`). * `leafsize`: Determines the number of points (default 25) at which to stop splitting the tree. There is a trade-off between tree traversal and evaluating the metric for an increasing number of points. * `reorder`: If `true` (default), during tree construction this rearranges points to improve cache locality during querying. This will create a copy of the original data. 
+* `tree`: An existing tree (`KDTree`, `BallTree`, or `BruteTree`) built from your data. +* `bounds_min`, `bounds_max`: Vectors defining the periodic domain boundaries. Use `Inf` in `bounds_max` for non-periodic dimensions. All trees in `NearestNeighbors.jl` are static, meaning points cannot be added or removed after creation. Note that this package is not suitable for very high dimensional points due to high compilation time and inefficient queries on the trees. @@ -42,6 +45,7 @@ data = rand(3, 10^4) kdtree = KDTree(data; leafsize = 25) balltree = BallTree(data, Minkowski(3.5); reorder = false) brutetree = BruteTree(data) +periodictree = PeriodicTree(kdtree, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]) ``` ## k-Nearest Neighbor (kNN) Searches @@ -49,8 +53,8 @@ brutetree = BruteTree(data) A kNN search finds the `k` nearest neighbors to a given point or points. This is done with the methods: ```julia -knn(tree, point[s], k [, skip=always_false]) -> idxs, dists -knn!(idxs, dists, tree, point, k [, skip=always_false]) +knn(tree, point[s], k [, skip=Returns(false)]) -> idxs, dists +knn!(idxs, dists, tree, point, k [, skip=Returns(false)]) ``` * `tree`: The tree instance. @@ -62,7 +66,7 @@ knn!(idxs, dists, tree, point, k [, skip=always_false]) For the single closest neighbor, you can use `nn`: ```julia -nn(tree, point[s] [, skip=always_false]) -> idx, dist +nn(tree, point[s] [, skip=Returns(false)]) -> idx, dist ``` Examples: @@ -169,6 +173,56 @@ inrange!(idxs, balltree, point, r) neighborscount = inrangecount(balltree, point, r) ``` +## Periodic Boundary Conditions + +The `PeriodicTree` provides nearest neighbor searches with periodic boundary conditions. 
+ +### Creating a PeriodicTree + +A `PeriodicTree` wraps an existing tree (`KDTree`, `BallTree`, or `BruteTree`) and handles periodic boundary conditions: + +```julia +PeriodicTree(tree, bounds_min, bounds_max) +``` + +* `tree`: An existing tree built from your data +* `bounds_min`: Vector of minimum bounds for each dimension +* `bounds_max`: Vector of maximum bounds for each dimension (use `Inf` for non-periodic dimensions) + +### Examples + +**Basic periodic boundaries:** +```julia +using NearestNeighbors, StaticArrays + +# Create data in a 2D periodic domain +data = [SVector(0.1, 0.2), SVector(0.8, 0.9), SVector(0.5, 0.5)] +kdtree = KDTree(data) + +# Create periodic tree with bounds [0,1] × [0,1] +ptree = PeriodicTree(kdtree, [0.0, 0.0], [1.0, 1.0]) + +# Query near boundary - finds neighbors through periodic wrapping +query_point = [0.05, 0.15] # Near data[1] = [0.1, 0.2] +neighbor_point = [0.95, 0.85] # Directly near data[2] = [0.8, 0.9]; the query is close to data[2] only via wrapping + +idxs, dists = knn(ptree, query_point, 2) +# Finds data[1] directly and data[2] through periodic wrapping +``` + +**Mixed periodic/non-periodic dimensions:** +```julia +# 2D domain: x-periodic, y-infinite +data = [SVector(1.0, 2.0), SVector(9.0, 8.0)] +kdtree = KDTree(data) +ptree = PeriodicTree(kdtree, [0.0, 0.0], [10.0, Inf]) + +# Query near the x-boundary; distances in x also account for periodic wrapping +query = [0.5, 3.0] +idxs, dists = knn(ptree, query, 1) +# Finds data[1] (direct x-distance 0.5); data[2] is measured at its wrapped x-distance of 1.5 instead of 8.5 +``` + ## Using On-Disk Data Sets By default, trees store a copy of the `data` provided during construction. For data sets larger than available memory, `DataFreeTree` can be used to strip a tree of its data field and re-link it later. diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index b987905..01703ea 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -20,16 +20,38 @@ for n_points in (EXTENSIVE_BENCHMARK ? 
(10^3, 10^5) : 10^5) (BallTree, "ball tree")) tree = tree_type(data; leafsize = leafsize, reorder = reorder) SUITE["build tree"]["$(tree_type) $dim × $n_points, ls = $leafsize"] = @benchmarkable $(tree_type)($data; leafsize = $leafsize, reorder = $reorder) + + # Add periodic tree benchmarks + bounds_min = zeros(dim) + bounds_max = ones(dim) # Unit cube [0,1]^dim + ptree = PeriodicTree(tree, bounds_min, bounds_max) + SUITE["build tree"]["Periodic$(tree_type) $dim × $n_points, ls = $leafsize"] = @benchmarkable PeriodicTree($tree, $bounds_min, $bounds_max) + + # Add mixed periodic/non-periodic benchmarks (only for multi-dimensional cases) + if dim > 1 + bounds_max_mixed = [ones(dim-1); Inf] # Last dimension non-periodic + ptree_mixed = PeriodicTree(tree, bounds_min, bounds_max_mixed) + SUITE["build tree"]["PeriodicMixed$(tree_type) $dim × $n_points, ls = $leafsize"] = @benchmarkable PeriodicTree($tree, $bounds_min, $bounds_max_mixed) + end + for input_size in (1, 1000) input_data = rand(StableRNG(123), dim, input_size) for k in (EXTENSIVE_BENCHMARK ? 
(1, 10) : 10) SUITE["knn"]["$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($tree, $input_data, $k) + SUITE["knn"]["Periodic$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($ptree, $input_data, $k) + if dim > 1 + SUITE["knn"]["PeriodicMixed$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, k = $k"] = @benchmarkable knn($ptree_mixed, $input_data, $k) + end end perc = 0.01 V = π^(dim / 2) / gamma(dim / 2 + 1) * (1 / 2)^dim r = (V * perc * gamma(dim / 2 + 1))^(1/dim) r_formatted = @sprintf("%3.2e", r) SUITE["inrange"]["$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($tree, $input_data, $r) + SUITE["inrange"]["Periodic$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($ptree, $input_data, $r) + if dim > 1 + SUITE["inrange"]["PeriodicMixed$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($ptree_mixed, $input_data, $r) + end end end end diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 6edbc1c..82a856c 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -4,9 +4,8 @@ using Distances import Distances: PreMetric, Metric, UnionMinkowskiMetric, result_type, eval_reduce, eval_end, eval_op, eval_start, evaluate, parameters using StaticArrays -import Base.show -export NNTree, BruteTree, KDTree, BallTree, DataFreeTree +export NNTree, BruteTree, KDTree, BallTree, DataFreeTree, PeriodicTree export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? 
, allpairs, distmat, npairs export injectdata @@ -48,18 +47,21 @@ end get_T(::Type{T}) where {T <: AbstractFloat} = T get_T(::T) where {T} = Float64 -include("evaluation.jl") -include("tree_data.jl") -include("datafreetree.jl") -include("knn.jl") -include("inrange.jl") +get_tree(tree::NNTree) = tree + include("hyperspheres.jl") include("hyperrectangles.jl") +include("evaluation.jl") include("utilities.jl") +include("tree_data.jl") +include("tree_ops.jl") include("brute_tree.jl") include("kd_tree.jl") include("ball_tree.jl") -include("tree_ops.jl") +include("periodic_tree.jl") +include("datafreetree.jl") +include("knn.jl") +include("inrange.jl") for dim in (2, 3) tree = KDTree(rand(dim, 10)) diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 7127952..ca8ba8e 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -147,7 +147,7 @@ function _knn(tree::BallTree, best_idxs::AbstractVector{<:Integer}, best_dists::AbstractVector, skip::F) where {F} - knn_kernel!(tree, 1, point, best_idxs, best_dists, skip) + knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, false) return end @@ -157,9 +157,9 @@ function knn_kernel!(tree::BallTree{V}, point::AbstractArray, best_idxs::AbstractVector{<:Integer}, best_dists::AbstractVector, - skip::F) where {V, F} + skip::F, unique::Bool) where {V, F} if isleaf(tree.tree_data.n_internal_nodes, index) - add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip) + add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip, unique) return end @@ -171,14 +171,14 @@ function knn_kernel!(tree::BallTree{V}, if left_dist <= best_dists[1] || right_dist <= best_dists[1] if left_dist < right_dist - knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique) if right_dist <= best_dists[1] - knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, 
unique) end else - knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getright(index), point, best_idxs, best_dists, skip, unique) if left_dist <= best_dists[1] - knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip) + knn_kernel!(tree, getleft(index), point, best_idxs, best_dists, skip, unique) end end end @@ -188,16 +188,19 @@ end function _inrange(tree::BallTree{V}, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V} + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F) where {V, F} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder + return inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip, false) # Call the recursive range finder end function inrange_kernel!(tree::BallTree, index::Int, point::AbstractVector, query_ball::HyperSphere, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F, + unique::Bool) where {F} if index > length(tree.hyper_spheres) return 0 @@ -215,19 +218,16 @@ function inrange_kernel!(tree::BallTree, # At a leaf node, check all points in the leaf node if isleaf(tree.tree_data.n_internal_nodes, index) r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, query_ball.r) : query_ball.r - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, unique) end - count = 0 - # The query ball encloses the sub tree bounding sphere. Add all points in the # sub tree without checking the distance function. if encloses_fast(dist, tree.metric, sphere, query_ball) - count += addall(tree, index, idx_in_ball) + return addall(tree, index, idx_in_ball, skip, unique) else # Recursively call the left and right sub tree. 
- count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball) - count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball) + return inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip, unique) + + inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip, unique) end - return count end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index bc882c5..9e6fb37 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -45,7 +45,7 @@ function _knn(tree::BruteTree{V}, best_dists::AbstractVector, skip::F) where {V, F} - knn_kernel!(tree, point, best_idxs, best_dists, skip) + knn_kernel!(tree, point, best_idxs, best_dists, skip, false) return end @@ -53,12 +53,17 @@ function knn_kernel!(tree::BruteTree{V}, point::AbstractVector, best_idxs::AbstractVector{<:Integer}, best_dists::AbstractVector, - skip::F) where {V, F} + skip::F, + unique::Bool) where {V, F} for i in 1:length(tree.data) if skip(i) continue end + #if unique && i in best_idxs + # continue + #end + dist_d = evaluate(tree.metric, tree.data[i], point) if dist_d <= best_dists[1] best_dists[1] = dist_d @@ -71,17 +76,23 @@ end function _inrange(tree::BruteTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) - return inrange_kernel!(tree, point, radius, idx_in_ball) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F,) where {F} + return inrange_kernel!(tree, point, radius, idx_in_ball, skip, false) end function inrange_kernel!(tree::BruteTree, point::AbstractVector, r::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}}) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::Function, + unique::Bool) count = 0 for i in 1:length(tree.data) + if skip(i) + continue + end d = evaluate(tree.metric, tree.data[i], point) if d <= r count += 1 diff --git a/src/inrange.jl b/src/inrange.jl index f10675e..345bdc8 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -18,24 +18,28 @@ See also: 
`inrange!`, `inrangecount`. function inrange(tree::NNTree, points::AbstractVector{T}, radius::Number, - sortres=false) where {T <: AbstractVector} + sortres=false, + skip::F = Returns(false)) where {T <: AbstractVector, F} check_input(tree, points) check_radius(radius) idxs = [Vector{Int}() for _ in 1:length(points)] for i in 1:length(points) - inrange_point!(tree, points[i], radius, sortres, idxs[i]) + inrange_point!(tree, points[i], radius, sortres, idxs[i], skip) end return idxs end -function inrange_point!(tree, point, radius, sortres, idx) - count = _inrange(tree, point, radius, idx) +inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F} = _inrange_point!(tree, point, radius, sortres, idx, skip) + +function _inrange_point!(tree, point, radius, sortres, idx, skip::F) where {F} + count = _inrange(tree, point, radius, idx, skip) if idx !== nothing - if tree.reordered + inner_tree = get_tree(tree) + if inner_tree.reordered @inbounds for j in 1:length(idx) - idx[j] = tree.indices[idx[j]] + idx[j] = inner_tree.indices[idx[j]] end end sortres && sort!(idx) @@ -57,11 +61,11 @@ Useful to avoid allocations or specify the element type of the output vector. See also: `inrange`, `inrangecount`. 
""" -function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} +function inrange!(idxs::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip=Returns(false)) where {V, T <: Number} check_input(tree, point) check_radius(radius) length(idxs) == 0 || throw(ArgumentError("idxs must be empty")) - inrange_point!(tree, point, radius, sortres, idxs) + inrange_point!(tree, point, radius, sortres, idxs, skip) return idxs end @@ -74,7 +78,7 @@ function inrange(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, sor inrange_matrix(tree, points, radius, Val(dim), sortres) end -function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres) where {V, T <: Number, dim} +function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Number, ::Val{dim}, sortres, skip::F=Returns(false)) where {V, T <: Number, dim, F} # TODO: DRY with inrange for AbstractVector check_input(tree, points) check_radius(radius) @@ -83,7 +87,7 @@ function inrange_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, radius::Numb for i in 1:n_points point = SVector{dim,T}(ntuple(j -> points[j, i], Val(dim))) - inrange_point!(tree, point, radius, sortres, idxs[i]) + inrange_point!(tree, point, radius, sortres, idxs[i], skip) end return idxs end @@ -101,18 +105,18 @@ Count all the points in the tree which are closer than `radius` to `points`. 
# Returns - `count`: Number of points within the radius (integer for single point, vector for multiple points) """ -function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number) where {V, T <: Number} +function inrangecount(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, skip::F=Returns(false)) where {V, T <: Number, F} check_input(tree, point) check_radius(radius) - return inrange_point!(tree, point, radius, false, nothing) + return inrange_point!(tree, point, radius, false, nothing, skip) end function inrangecount(tree::NNTree, points::AbstractVector{T}, - radius::Number) where {T <: AbstractVector} + radius::Number, skip::F=Returns(false)) where {T <: AbstractVector, F} check_input(tree, points) check_radius(radius) - return inrange_point!.(Ref(tree), points, radius, false, nothing) + return inrange_point!.(Ref(tree), points, radius, false, nothing, skip) end function inrangecount(tree::NNTree{V}, point::AbstractMatrix{T}, radius::Number) where {V, T <: Number} diff --git a/src/kd_tree.jl b/src/kd_tree.jl index b84716f..9dfdb34 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -143,7 +143,7 @@ function _knn(tree::KDTree, best_dists::AbstractVector, skip::F) where {F} init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) - knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, tree.hyper_rec, skip) + knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, tree.hyper_rec, skip, false) @simd for i in eachindex(best_dists) @inbounds best_dists[i] = eval_end(tree.metric, best_dists[i]) end @@ -156,10 +156,11 @@ function knn_kernel!(tree::KDTree{V}, best_dists::AbstractVector, min_dist, hyper_rec::HyperRectangle, - skip::F) where {V, F} + skip::F, + unique::Bool) where {V, F} # At a leaf node. 
Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip) + add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip, unique) return end @@ -181,7 +182,7 @@ function knn_kernel!(tree::KDTree{V}, hyper_rec_close = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) end # Always call closer sub tree - knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip) + knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip, unique) if M isa Chebyshev new_min = get_min_distance_no_end(M, hyper_rec_far, point) @@ -190,7 +191,7 @@ function knn_kernel!(tree::KDTree{V}, end if new_min < best_dists[1] - knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip) + knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip, unique) end return end @@ -199,17 +200,17 @@ function _inrange( tree::KDTree, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[] - ) + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F) where {F} init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) init_max_contribs = get_max_distance_contributions(tree.metric, tree.hyper_rec, point) init_max = tree.metric isa Chebyshev ? 
maximum(init_max_contribs) : sum(init_max_contribs) return inrange_kernel!( tree, 1, point, eval_pow(tree.metric, radius), idx_in_ball, - tree.hyper_rec, init_min, init_max_contribs, init_max - ) + tree.hyper_rec, init_min, init_max_contribs, init_max, skip, false) end + # Explicitly check the distance between leaf node and point while traversing function inrange_kernel!( tree::KDTree, @@ -220,20 +221,21 @@ function inrange_kernel!( hyper_rec::HyperRectangle, min_dist, max_dist_contribs::SVector, - max_dist - ) + max_dist, + skip::F, + unique::Bool) where {F} # Point is outside hyper rectangle, skip the whole sub tree if min_dist > r return 0 end if max_dist < r - return addall(tree, index, idx_in_ball) + return addall(tree, index, idx_in_ball, skip, unique) end # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - return add_points_inrange!(idx_in_ball, tree, index, point, r) + return add_points_inrange!(idx_in_ball, tree, index, point, r, skip, unique) end split_val = tree.split_vals[index] @@ -259,7 +261,7 @@ function inrange_kernel!( old_contrib = max_dist_contribs[split_dim] if split_diff > 0 # Point is to the right - # Close subtree: split_val as new min, far subtree: split_val as new max + # Close subtree: split_val as new min, far subtree: split_val as new max new_contrib_close = get_max_distance_contribution_single(M, point[split_dim], split_val, hyper_rec.maxes[split_dim], split_dim) new_contrib_far = get_max_distance_contribution_single(M, point[split_dim], hyper_rec.mins[split_dim], split_val, split_dim) else @@ -274,7 +276,7 @@ function inrange_kernel!( new_max_dist_close = M isa Chebyshev ? 
maximum(new_max_contribs_close) : max_dist - old_contrib + new_contrib_close # Call closer sub tree - count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, new_max_contribs_close, new_max_dist_close) + count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist, new_max_contribs_close, new_max_dist_close, skip, unique) # Compute new min distance for far subtree new_min = M isa Chebyshev ? get_min_distance_no_end(M, hyper_rec_far, point) : update_new_min(M, min_dist, hyper_rec, p_dim, split_dim, split_val) @@ -284,6 +286,6 @@ function inrange_kernel!( new_max_dist_far = M isa Chebyshev ? maximum(new_max_contribs_far) : max_dist - old_contrib + new_contrib_far # Call further sub tree - count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, new_max_contribs_far, new_max_dist_far) + count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min, new_max_contribs_far, new_max_dist_far, skip, unique) return count end diff --git a/src/knn.jl b/src/knn.jl index 7e13659..b8ec0a2 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -1,5 +1,5 @@ function check_k(tree, k) - if k > length(tree.data) || k < 0 + if k > length(get_tree(tree).data) || k < 0 throw(ArgumentError("k > number of points in tree or < 0")) end end @@ -22,7 +22,7 @@ in the `tree`. See also: `knn!`, `nn`. 
""" -function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F<:Function} +function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: AbstractVector, F<:Function} check_input(tree, points) check_k(tree, k) n_points = length(points) @@ -34,19 +34,23 @@ function knn(tree::NNTree{V}, points::AbstractVector{T}, k::Int, sortres=false, return idxs, dists end -function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} +knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} = + _knn_point!(tree, point, sortres, dist, idx, skip) + +function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} fill!(idx, -1) fill!(dist, typemax(get_T(eltype(V)))) _knn(tree, point, idx, dist, skip) - if skip !== always_false + if skip !== Returns(false) skipped_idxs = findall(==(-1), idx) deleteat!(idx, skipped_idxs) deleteat!(dist, skipped_idxs) end sortres && heap_sort_inplace!(dist, idx) - if tree.reordered + inner_tree = get_tree(tree) + if inner_tree.reordered for j in eachindex(idx) - @inbounds idx[j] = tree.indices[idx[j]] + @inbounds idx[j] = inner_tree.indices[idx[j]] end end return @@ -68,7 +72,7 @@ Useful to avoid allocations or specify the element type of the output vectors. See also: `knn`, `nn`. 
""" -function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function} +function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: Number, F<:Function} check_k(tree, k) length(idxs) == k || throw(ArgumentError("idxs must be of length k")) length(dists) == k || throw(ArgumentError("dists must be of length k")) @@ -76,19 +80,15 @@ function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTr return idxs, dists end -function knn(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function} +function knn(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: Number, F<:Function} idx = Vector{Int}(undef, k) dist = Vector{get_T(eltype(V))}(undef, k) return knn!(idx, dist, tree, point, k, sortres, skip) end -function knn(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function} - dim = size(points, 1) - knn_matrix(tree, points, k, Val(dim), sortres, skip) -end +function knn(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, sortres=false, skip::F=Returns(false)) where {V, T <: Number, F<:Function} + dim = length(V) -# Function barrier -function knn_matrix(tree::NNTree{V}, points::AbstractMatrix{T}, k::Int, ::Val{dim}, sortres=false, skip::F=always_false) where {V, T <: Number, F<:Function, dim} # TODO: DRY with knn for AbstractVector check_input(tree, points) check_k(tree, k) @@ -120,9 +120,9 @@ Performs a lookup of the single nearest neighbor to the `point(s)` from the data See also: `knn`. 
""" -nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) .|> only -nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: AbstractVector, F <: Function} = _nn(tree, points, skip) |> _onlyeach -nn(tree::NNTree{V}, points::AbstractMatrix{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) |> _onlyeach +nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=Returns(false)) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) .|> only +nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=Returns(false)) where {V, T <: AbstractVector, F <: Function} = _nn(tree, points, skip) |> _onlyeach +nn(tree::NNTree{V}, points::AbstractMatrix{T}, skip::F=Returns(false)) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) |> _onlyeach _nn(tree, points, skip) = knn(tree, points, 1, false, skip) diff --git a/src/periodic_tree.jl b/src/periodic_tree.jl new file mode 100644 index 0000000..263f0f9 --- /dev/null +++ b/src/periodic_tree.jl @@ -0,0 +1,240 @@ +""" + PeriodicTree(tree::NNTree, bounds_min, bounds_max) -> PeriodicTree + +Creates a periodic wrapper around an existing nearest neighbor tree (KDTree, BallTree, or BruteTree) +that handles periodic boundary conditions. + +# Arguments +- `tree::NNTree`: The underlying tree structure (KDTree, BallTree, or BruteTree) +- `bounds_min`: Vector of minimum bounds for each dimension +- `bounds_max`: Vector of maximum bounds for each dimension + +# Requirements +- All data points in the tree must be within the specified periodic box bounds +- Box dimensions must be positive and finite (except for non-periodic dimensions) + +# Returns +- `PeriodicTree`: A tree that performs nearest neighbor searches with periodic boundary conditions + +# Algorithm +The periodic tree works by creating "mirror images" of the query point by shifting it by multiples +of the box dimensions. 
For each periodic image, it searches the underlying tree and combines results +while ensuring no duplicates are returned. + +# Performance Notes +- For best performance, ensure search radii are ≤ half the smallest box dimension +- Larger radii will still work correctly but may perform redundant searches +- Dimensions with infinite bounds are treated as non-periodic + +# Examples +```julia +using NearestNeighbors, StaticArrays + +# Create some 2D data +data = [SVector(1.0, 2.0), SVector(3.0, 4.0), SVector(7.0, 8.0)] +bounds_min = [0.0, 0.0] +bounds_max = [10.0, 10.0] + +# Create periodic tree +kdtree = KDTree(data) +ptree = PeriodicTree(kdtree, bounds_min, bounds_max) + +# Search near boundary - finds points through periodic wrapping +query_point = [9.0, 1.0] +idxs, dists = knn(ptree, query_point, 2) +``` +""" +struct PeriodicTree{V<:AbstractVector, M, Tree <: NNTree{V, M}, D} <: NNTree{V,M} + tree::Tree + bbox::HyperRectangle{V} + combos::Vector{SVector{D, Int}} + box_widths::SVector{D, eltype(V)} + + function PeriodicTree(tree::NNTree{V,M}, bounds_min, bounds_max) where {V,M} + dim = length(V) + if length(bounds_min) != dim || length(bounds_max) != dim + throw(ArgumentError("Bounding box dimensions do not match data dimensions")) + end + + # Store finite box widths, use 0.0 for non-periodic dimensions to avoid Inf * 0 = NaN + box_widths = SVector{dim}( + isfinite(bounds_max[i] - bounds_min[i]) ? 
bounds_max[i] - bounds_min[i] : 0.0 + for i in 1:dim + ) + + # Check for valid box dimensions (finite dimensions must be positive) + for i in 1:dim + actual_width = bounds_max[i] - bounds_min[i] + if isfinite(actual_width) && actual_width <= 0 + throw(ArgumentError("Box width in dimension $i must be positive, got $actual_width")) + end + end + + # Validate that all data points are within the periodic box bounds + # This is important for correct periodic behavior + for (idx, point) in enumerate(tree.data) + for i in 1:dim + if point[i] < bounds_min[i] || point[i] > bounds_max[i] + throw(ArgumentError("Data point $idx has coordinate $(point[i]) in dimension $i, which is outside the periodic box bounds [$(bounds_min[i]), $(bounds_max[i])]")) + end + end + end + + # Find periodic dimensions (those with non-zero box widths) + periodic_dims = findall(>(0), box_widths) + n_periodic = length(periodic_dims) + + # Generate combinations only for periodic dimensions + if n_periodic == 0 + # No periodic dimensions - only search original box + combos_reordered = [zero(SVector{dim, Int})] + else + # Generate all combinations of [-1, 0, 1] for periodic dimensions only + periodic_ranges = ntuple(i -> -1:1, Val(n_periodic)) + periodic_combos = collect(Iterators.product(periodic_ranges...)) + + # Convert to full-dimension combo vectors + combos = Vector{SVector{dim, Int}}() + for combo_vals in periodic_combos + full_combo = zeros(Int, dim) + for (i, dim_idx) in enumerate(periodic_dims) + full_combo[dim_idx] = combo_vals[i] + end + push!(combos, SVector{dim, Int}(full_combo)) + end + + # Put the (0, 0, 0, ...) 
combo first to search the original box first + # This is important for performance as the original box often contains the closest points + zero_combo = zero(SVector{dim, Int}) + filtered_combos = filter(x -> x != zero_combo, combos) + combos_reordered = pushfirst!(filtered_combos, zero_combo) + end + + return new{V, M, typeof(tree), dim}( + tree, + HyperRectangle(SVector{dim}(bounds_min), SVector{dim}(bounds_max)), + combos_reordered, + box_widths + ) + end +end + +get_tree(tree::PeriodicTree) = tree.tree + + +function Base.show(io::IO, tree::PeriodicTree{V}) where {V} + println(io, "Periodic Tree: $(typeof(tree.tree))") + + # Show periodic and non-periodic dimensions clearly + periodic_dims = findall(>(0), tree.box_widths) + non_periodic_dims = findall(==(0), tree.box_widths) + + println(io, " Dimensions: ", length(V)) + if !isempty(periodic_dims) + println(io, " Periodic ($(length(periodic_dims))): ", periodic_dims) + for dim in periodic_dims + println(io, " Dim $dim: [$(tree.bbox.mins[dim]), $(tree.bbox.maxes[dim])] (width: $(tree.box_widths[dim]))") + end + end + if !isempty(non_periodic_dims) + println(io, " Non-periodic ($(length(non_periodic_dims))): ", non_periodic_dims) + end + + println(io, " Number of points: ", length(tree.tree.data)) + println(io, " Metric: ", tree.tree.metric) + print(io, " Reordered: ", tree.tree.reordered) +end + +function _knn(tree::PeriodicTree{V,M}, + point::AbstractVector, + best_idxs::AbstractVector{<:Integer}, + best_dists::AbstractVector, + skip::F) where {V, M, F} + + # Search all periodic mirror boxes + # Each combo represents a different "image" of the periodic box + # e.g., (0,0) = original, (1,0) = shifted right by box_width, (-1,1) = shifted left and up + for combo in tree.combos + # Create the shift vector: multiply box dimensions by the combo coefficients + shift_vector = tree.box_widths .* combo + # Create a "mirror image" of the query point in this periodic box + point_shifted = point + shift_vector + + # Calculate 
minimum distance from shifted point to the original bounding box + min_dist_to_canonical = get_min_distance_no_end(tree.tree.metric, tree.bbox, point_shifted) + + # Optimization: Skip mirror boxes that can't improve current results + # If minimum possible distance is >= current k-th nearest distance, skip this mirror box + if eval_end(tree.tree.metric, min_dist_to_canonical) >= best_dists[1] + continue + end + + # Search the underlying tree with the shifted query point + # The 'true' parameter enables uniqueness checking to prevent duplicate results + if tree.tree isa KDTree + knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, min_dist_to_canonical, tree.tree.hyper_rec, skip, true) + elseif tree.tree isa BallTree + knn_kernel!(tree.tree, 1, point_shifted, best_idxs, best_dists, skip, true) + else + @assert tree.tree isa BruteTree + knn_kernel!(tree.tree, point_shifted, best_idxs, best_dists, skip, true) + end + end + + # For KDTree, we need to finalize the distance calculations + # This is because KDTree uses squared distances internally for efficiency + if tree.tree isa KDTree + @simd for i in eachindex(best_dists) + @inbounds best_dists[i] = eval_end(tree.tree.metric, best_dists[i]) + end + end + + # Verify no duplicates were returned (should be guaranteed by unique=true above) + @assert allunique(best_idxs) + return +end + +function _inrange(tree::PeriodicTree{V}, + point::AbstractVector, + radius::Number, + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + skip::F) where {V, F} + + # Search all periodic mirror boxes for points within the given radius + for combo in tree.combos + # Create the shift vector for this mirror box + shift_vector = tree.box_widths .* combo + # Create a "mirror image" of the query point + point_shifted = point + shift_vector + + # Performance optimization: skip mirror boxes that are too far away + # If the closest possible point in the original box is farther than radius, + # then no points in this mirror box can be within 
radius + min_dist_to_bbox = get_min_distance_no_end(tree.tree.metric, tree.bbox, point_shifted) + if eval_end(tree.tree.metric, min_dist_to_bbox) > radius + continue + end + + # Search the underlying tree with the shifted query point + # The 'true' parameter enables uniqueness checking to prevent duplicate results + if tree.tree isa KDTree + # KDTree requires additional distance computation parameters + max_dist_contribs = get_max_distance_contributions(tree.tree.metric, tree.bbox, point_shifted) + max_dist = tree.tree.metric isa Chebyshev ? maximum(max_dist_contribs) : sum(max_dist_contribs) + inrange_kernel!(tree.tree, 1, point_shifted, eval_op(tree.tree.metric, radius, zero(min_dist_to_bbox)), idx_in_ball, + tree.tree.hyper_rec, min_dist_to_bbox, max_dist_contribs, max_dist, skip, true) + elseif tree.tree isa BallTree + # BallTree uses a hypersphere for range queries + ball = HyperSphere(convert(V, point_shifted), convert(eltype(V), radius)) + inrange_kernel!(tree.tree, 1, point_shifted, ball, idx_in_ball, skip, true) + else + @assert tree.tree isa BruteTree + # BruteTree has the simplest interface + inrange_kernel!(tree.tree, point_shifted, radius, idx_in_ball, skip, true) + end + end + + # Verify no duplicates were returned (should be guaranteed by unique=true above) + @assert allunique(idx_in_ball) + return +end diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 39338cf..ddc796a 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -4,7 +4,7 @@ @inline getparent(i::Int) = div(i, 2) @inline isleaf(n_internal_nodes::Int, idx::Int) = idx > n_internal_nodes -function show(io::IO, tree::NNTree{V}) where {V} +function Base.show(io::IO, tree::NNTree{V}) where {V} println(io, typeof(tree)) println(io, " Number of points: ", length(tree.data)) println(io, " Dimensions: ", length(V)) @@ -92,15 +92,25 @@ end # Uses a heap for fast insertion. 
@inline function add_points_knn!(best_dists::AbstractVector, best_idxs::AbstractVector{<:Integer}, tree::NNTree, index::Int, point::AbstractVector, - do_end::Bool, skip::F) where {F} + do_end::Bool, skip::F, unique::Bool) where {F} for z in get_leaf_range(tree.tree_data, index) + if skip(tree.indices[z]) + continue + end idx = tree.reordered ? z : tree.indices[z] dist_d = evaluate_maybe_end(tree.metric, tree.data[idx], point, do_end) - if dist_d <= best_dists[1] - if skip(tree.indices[z]) - continue + if dist_d < best_dists[1] + if unique + idx_existing = findfirst(==(idx), best_idxs) + if idx_existing !== nothing + dist = best_dists[idx_existing] + if dist_d < dist + best_dists[idx_existing] = dist_d + percolate_down!(best_dists, best_idxs, dist_d, idx, idx_existing) + end + continue + end end - best_dists[1] = dist_d best_idxs[1] = idx percolate_down!(best_dists, best_idxs, dist_d, idx) @@ -115,10 +125,17 @@ end # This will probably prevent SIMD and other optimizations so some care is needed # to evaluate if it is worth it. @inline function add_points_inrange!(idx_in_ball::Union{Nothing, AbstractVector{<:Integer}}, tree::NNTree, - index::Int, point::AbstractVector, r::Number) + index::Int, point::AbstractVector, r::Number, skip::Function, + unique::Bool) count = 0 for z in get_leaf_range(tree.tree_data, index) + if skip(tree.indices[z]) + continue + end idx = tree.reordered ? 
z : tree.indices[z] + if unique && idx in idx_in_ball + continue + end if check_in_range(tree.metric, tree.data[idx], point, r) count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) @@ -138,18 +155,24 @@ end # Add all points in this subtree since we have determined # they are all within the desired range -function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}) +function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:Integer}}, skip::Function, unique::Bool) tree_data = tree.tree_data - count = 0 if isleaf(tree_data.n_internal_nodes, index) + count = 0 for z in get_leaf_range(tree_data, index) + if skip(tree.indices[z]) + continue + end idx = tree.reordered ? z : tree.indices[z] + if unique && idx in idx_in_ball + continue + end count += 1 idx_in_ball !== nothing && push!(idx_in_ball, idx) end + return count else - count += addall(tree, getleft(index), idx_in_ball) - count += addall(tree, getright(index), idx_in_ball) + return addall(tree, getleft(index), idx_in_ball, skip, unique) + + addall(tree, getright(index), idx_in_ball, skip, unique) end - return count end diff --git a/src/utilities.jl b/src/utilities.jl index 7ad3700..60848f8 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -59,7 +59,7 @@ end @inbounds for i in length(xs):-1:2 xs[i], xs[1] = xs[1], xs[i] xis[i], xis[1] = xis[1], xis[i] - percolate_down!(xs, xis, xs[1], xis[1], i - 1) + percolate_down!(xs, xis, xs[1], xis[1], 1, i - 1) end return end @@ -69,8 +69,9 @@ end xis::AbstractArray, dist::Number, index::Integer, + offset::Integer=1, len::Integer=length(xs)) - i = 1 + i = offset @inbounds while (l = getleft(i)) <= len r = getright(i) j = ifelse(r > len || (xs[l] > xs[r]), l, r) @@ -87,11 +88,6 @@ end return end -# Default skip function, always false -@inline function always_false(::Int) - false -end - # Instead of ReinterpretArray wrapper, copy an array, interpreting it as a vector of SVectors copy_svec(::Type{T}, data, 
::Val{dim}) where {T, dim} = [SVector{dim,T}(ntuple(i -> data[n+i], Val(dim))) for n in 0:dim:(length(data)-1)]::Vector{SVector{dim,T}} diff --git a/test/runtests.jl b/test/runtests.jl index d0ee227..80a917e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ include("test_knn.jl") include("test_inrange.jl") include("test_monkey.jl") include("test_datafreetree.jl") +include("test_periodic.jl") @testset "views of SVector" begin x = [rand(SVector{3}) for i in 1:20] diff --git a/test/test_knn.jl b/test/test_knn.jl index 3667772..a9e7d2a 100644 --- a/test/test_knn.jl +++ b/test/test_knn.jl @@ -1,6 +1,6 @@ # Does not test leafsize # Does not test different metrics -import Distances.evaluate +using Distances: evaluate @testset "knn" begin @testset "metric" for metric in [metrics; WeightedEuclidean(ones(2))] diff --git a/test/test_monkey.jl b/test/test_monkey.jl index 8d3182a..ee71a87 100644 --- a/test/test_monkey.jl +++ b/test/test_monkey.jl @@ -31,7 +31,8 @@ dim_data = rand(1:5) size_data = rand(100:151) data = rand(T, dim_data, size_data) - tree = TreeType(data, metric; leafsize = rand(1:15)) + leafsize = rand(1:15) + tree = TreeType(data, metric; leafsize) btree = BruteTree(data, metric) k = rand(1:12) p = rand(dim_data) diff --git a/test/test_periodic.jl b/test/test_periodic.jl new file mode 100644 index 0000000..cd35d59 --- /dev/null +++ b/test/test_periodic.jl @@ -0,0 +1,377 @@ +using Test + +using NearestNeighbors, StaticArrays, Distances + +function create_trees(data, bounds_max, reorder) + kdtree = KDTree(data; leafsize=1, reorder) + balltree = BallTree(data; leafsize=1, reorder) + bounds_min = zeros(length(bounds_max)) + + pkdtree = PeriodicTree(kdtree, bounds_min, bounds_max) + pballtree = PeriodicTree(balltree, bounds_min, bounds_max) + btree = BruteTree(data, PeriodicEuclidean(bounds_max)) + return pkdtree, pballtree, btree +end + +function test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, point, r) + idx_btree = 
sort(inrange(btree, point, r)) + idx_pkdtree = sort(inrange(pkdtree, point, r)) + idx_pballtree = sort(inrange(pballtree, point, r)) + @test idx_btree == idx_pkdtree == idx_pballtree +end + +function test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, point, k) + idx_btree, dists_btree = knn(btree, point, k, true) + idx_pkdtree, dists_pkdtree = knn(pkdtree, point, k, true) + idx_pballtree, dists_pballtree = knn(pballtree, point, k, true) + + # The key requirement: distances should be equal (this is the main correctness test) + @test dists_btree ≈ dists_pkdtree ≈ dists_pballtree + + # For indices, with ties, different trees may return different valid indices + # We verify that all distances match the expected k-th nearest distance + max_dist_brute = maximum(dists_btree) + max_dist_kd = maximum(dists_pkdtree) + max_dist_ball = maximum(dists_pballtree) + + # All maximum distances should be approximately equal (ensuring same k-th nearest distance) + @test max_dist_brute ≈ max_dist_kd ≈ max_dist_ball + + return dists_pkdtree +end + +function test_data_bounds_point(data, bounds_max, point) + for reorder = (false, true) + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + for k in 1:length(data) + dists = test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, point, k) + r = maximum(dists) + 0.001 + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, point, r) + end + end +end + +data = SVector{2, Float64}.([(1, 2), (3, 4), (5, 6), (7, 8), (9, 10)]) +bounds_max = (10.0, 10.0) +point = [8.9, 1.9] +test_data_bounds_point(data, bounds_max, point) + +data = SVector{3, Float64}.([(1, 2, 3), (4, 5, 6), (7, 8, 9), (10, 11, 12), (13, 14, 15)]) +bounds_max = (20.0, 20.0, 20.0) +point = [18.0, 19.0, 0.0] +test_data_bounds_point(data, bounds_max, point) + +# Test mixed periodic/non-periodic dimensions +@testset "Mixed periodic/non-periodic dimensions" begin + # Create simpler data that fits within bounds + data = 
SVector{2, Float64}.([(1.0, 2.0), (8.0, 9.0)]) + bounds_min = [0.0, 0.0] + bounds_max = [10.0, Inf] # y-dimension is non-periodic + + kdtree = KDTree(data) + ptree = PeriodicTree(kdtree, bounds_min, bounds_max) + + # Test that 3 combinations are generated (3^1 * 1^1 = 3 combinations) + # For periodic x-dimension: [-1, 0, 1] = 3 combinations + # For non-periodic y-dimension: [0] = 1 combination + # Total: 3 combinations + @test length(ptree.combos) == 3 + + # Verify the combinations are correct + combo_values = [combo[1] for combo in ptree.combos] # x-dimension values + @test 0 in combo_values # Should have original box + @test -1 in combo_values # Should have left periodic image + @test 1 in combo_values # Should have right periodic image + + # All y-dimension values should be 0 (non-periodic) + y_values = [combo[2] for combo in ptree.combos] + @test all(y -> y == 0, y_values) + + # Test actual KNN/inrange functionality for mixed dimensions + # Use PeriodicEuclidean as ground truth (with Inf for non-periodic dimensions) + btree = BruteTree(data, PeriodicEuclidean([10.0, Inf])) + + # Test various query points + test_points = [ + [1.5, 2.5], # Near first data point + [8.5, 8.5], # Near second data point + [0.5, 5.0], # Near left boundary (should wrap) + [9.5, 5.0], # Near right boundary (should wrap) + [5.0, 1.0], # Middle x, near bottom y + [5.0, 10.0] # Middle x, near top y + ] + + for query_point in test_points + # Test KNN + for k in 1:length(data) + idx_btree, dists_btree = knn(btree, query_point, k, true) + idx_ptree, dists_ptree = knn(ptree, query_point, k, true) + + # Distances should match (main correctness test) + @test dists_btree ≈ dists_ptree + end + + # Test inrange + for radius in [1.0, 2.0, 5.0, 10.0] + idx_btree = sort(inrange(btree, query_point, radius)) + idx_ptree = sort(inrange(ptree, query_point, radius)) + @test idx_btree == idx_ptree + end + end +end + +# Test comprehensive mixed periodic/non-periodic scenarios +@testset "Comprehensive mixed 
dimensions" begin + # Test different combinations of periodic/non-periodic dimensions + + # Scenario 1: First dimension periodic, second non-periodic + data1 = SVector{2, Float64}.([(1.0, 2.0), (4.0, 5.0), (7.0, 8.0)]) + bounds_min1 = [0.0, 0.0] + bounds_max1 = [8.0, Inf] + + kdtree1 = KDTree(data1) + ptree1 = PeriodicTree(kdtree1, bounds_min1, bounds_max1) + btree1 = BruteTree(data1, PeriodicEuclidean([8.0, Inf])) + + # Test boundary wrapping behavior + test_points1 = [ + [0.5, 3.0], # Near left boundary + [7.5, 6.0], # Near right boundary + [4.0, 2.0], # Middle + [8.5, 4.0], # Outside right boundary (should wrap to 0.5) + [-0.5, 7.0] # Outside left boundary (should wrap to 7.5) + ] + + for query_point in test_points1 + for k in 1:length(data1) + idx_btree, dists_btree = knn(btree1, query_point, k, true) + idx_ptree, dists_ptree = knn(ptree1, query_point, k, true) + @test dists_btree ≈ dists_ptree + end + + for radius in [1.0, 2.0, 3.0, 5.0] + idx_btree = sort(inrange(btree1, query_point, radius)) + idx_ptree = sort(inrange(ptree1, query_point, radius)) + @test idx_btree == idx_ptree + end + end + + # Scenario 2: First non-periodic, second periodic + data2 = SVector{2, Float64}.([(2.0, 1.0), (5.0, 4.0), (8.0, 7.0)]) + bounds_min2 = [0.0, 0.0] + bounds_max2 = [Inf, 8.0] + + kdtree2 = KDTree(data2) + ptree2 = PeriodicTree(kdtree2, bounds_min2, bounds_max2) + btree2 = BruteTree(data2, PeriodicEuclidean([Inf, 8.0])) + + test_points2 = [ + [3.0, 0.5], # Near bottom boundary + [6.0, 7.5], # Near top boundary + [4.0, 4.0], # Middle + [7.0, 8.5], # Outside top boundary (should wrap to 0.5) + [1.0, -0.5] # Outside bottom boundary (should wrap to 7.5) + ] + + for query_point in test_points2 + for k in 1:length(data2) + idx_btree, dists_btree = knn(btree2, query_point, k, true) + idx_ptree, dists_ptree = knn(ptree2, query_point, k, true) + @test dists_btree ≈ dists_ptree + end + + for radius in [1.0, 2.0, 3.0, 5.0] + idx_btree = sort(inrange(btree2, query_point, radius)) + 
idx_ptree = sort(inrange(ptree2, query_point, radius)) + @test idx_btree == idx_ptree + end + end + + # Scenario 3: 3D with mixed dimensions + data3 = SVector{3, Float64}.([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0), (7.0, 8.0, 9.0)]) + bounds_min3 = [0.0, 0.0, 0.0] + bounds_max3 = [8.0, Inf, 10.0] # x and z periodic, y non-periodic + + kdtree3 = KDTree(data3) + ptree3 = PeriodicTree(kdtree3, bounds_min3, bounds_max3) + btree3 = BruteTree(data3, PeriodicEuclidean([8.0, Inf, 10.0])) + + test_points3 = [ + [1.5, 3.0, 4.0], # Near first data point + [7.5, 6.0, 9.5], # Near boundaries in periodic dimensions + [4.0, 5.0, 5.0], # Middle + [8.5, 7.0, 10.5] # Outside periodic boundaries + ] + + for query_point in test_points3 + for k in 1:min(2, length(data3)) # Test fewer k values for 3D + idx_btree, dists_btree = knn(btree3, query_point, k, true) + idx_ptree, dists_ptree = knn(ptree3, query_point, k, true) + @test dists_btree ≈ dists_ptree + end + + for radius in [2.0, 4.0, 6.0] + idx_btree = sort(inrange(btree3, query_point, radius)) + idx_ptree = sort(inrange(ptree3, query_point, radius)) + @test idx_btree == idx_ptree + end + end +end + +# Test boundary cases and edge conditions +@testset "Boundary cases and edge conditions" begin + # Test near-boundary points that should find neighbors through periodic wrapping + data = SVector{2, Float64}.([(0.5, 1.0), (9.5, 8.0), (5.0, 5.0)]) + bounds_max = [10.0, 10.0] + + # Query point very close to boundary - should find wrapped neighbors + query_point = [0.1, 1.5] # Very close to (0.5, 1.0) and should also find (9.5, 8.0) through wrapping + + for reorder in [false, true] + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + + # Test that periodic tree finds same neighbors as brute tree + test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, query_point, 3) + + # Test with radius that should capture wrapped neighbors + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, query_point, 
2.0) + end + + # Test point exactly at boundary + query_point = [0.0, 5.0] + for reorder in [false, true] + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, query_point, 2) + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, query_point, 1.5) + end + + # Test point exactly at opposite boundary + query_point = [10.0, 5.0] + for reorder in [false, true] + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, query_point, 2) + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, query_point, 1.5) + end +end + +# Test with different data distributions +@testset "Different data distributions" begin + # Dense data near boundaries + data = SVector{2, Float64}.([(0.1, 0.1), (0.2, 0.3), (9.8, 9.9), (9.7, 9.6), (5.0, 5.0)]) + bounds_max = [10.0, 10.0] + + query_points = [ + [0.0, 0.0], # Corner + [10.0, 10.0], # Opposite corner + [0.15, 0.2], # Near dense cluster + [9.75, 9.75] # Near other dense cluster + ] + + for query_point in query_points + for reorder in [false, true] + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + + # Test all k values + for k in 1:length(data) + test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, query_point, k) + end + + # Test multiple radii + for radius in [0.5, 1.0, 2.0, 5.0] + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, query_point, radius) + end + end + end +end + +# Test specific periodic scenarios with known correct answers +@testset "Periodic boundary verification" begin + # Simple case: two points that should be closest through periodic boundary + data = SVector{2, Float64}.([(0.5, 5.0), (9.5, 5.0)]) + bounds_max = [10.0, 10.0] + + # Query at x=0.0 should find (0.5, 5.0) as closest, but (9.5, 5.0) should be very close too via periodicity + query_point = 
[0.0, 5.0] + + for reorder in [false, true] + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + + # Test KNN - both trees should give same results as brute tree + test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, query_point, 2) + + # Test inrange with small radius that should capture periodic neighbor + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, query_point, 1.0) + end + + # Test another scenario: query outside the box should wrap around + query_point = [10.5, 5.0] # Should be equivalent to [0.5, 5.0] due to periodicity + + for reorder in [false, true] + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, query_point, 2) + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, query_point, 2.0) + end +end + +# Test extensive periodic scenarios +@testset "Extensive periodic testing" begin + # Larger dataset with more complex periodic interactions + data = SVector{2, Float64}.([(0.1, 0.1), (2.0, 3.0), (5.0, 5.0), (8.0, 2.0), (9.9, 9.9)]) + bounds_max = [10.0, 10.0] + + # Test many query points systematically + test_points = [ + [0.0, 0.0], # Corner + [5.0, 5.0], # Center + [10.0, 0.0], # Corner + [0.0, 10.0], # Corner + [10.0, 10.0], # Corner + [0.05, 0.05], # Very close to boundary + [9.95, 9.95], # Very close to opposite boundary + [11.0, 1.0], # Outside box (should wrap) + [-1.0, 9.0] # Outside box (should wrap) + ] + + for query_point in test_points + for reorder in [false, true] + pkdtree, pballtree, btree = create_trees(data, bounds_max, reorder) + + # Test KNN for various k values + for k in 1:min(3, length(data)) + test_periodic_euclidean_against_brute_knn(pkdtree, pballtree, btree, query_point, k) + end + + # Test inrange for various radii + for radius in [0.5, 1.0, 2.0, 4.0] + test_periodic_euclidean_against_brute_inrange(pkdtree, pballtree, btree, query_point, radius) + end + end + 
end +end + +# Test data validation +@testset "Data validation" begin + # Test that data outside bounds is rejected + data_good = SVector{2, Float64}.([(1.0, 2.0), (3.0, 4.0)]) + data_bad = SVector{2, Float64}.([(1.0, 2.0), (11.0, 4.0)]) # 11.0 > 10.0 + + kdtree_good = KDTree(data_good) + kdtree_bad = KDTree(data_bad) + + bounds_min = [0.0, 0.0] + bounds_max = [10.0, 10.0] + + # Should work with good data + @test isa(PeriodicTree(kdtree_good, bounds_min, bounds_max), PeriodicTree) + + # Should fail with bad data + @test_throws ArgumentError PeriodicTree(kdtree_bad, bounds_min, bounds_max) + + # Test dimension mismatch + @test_throws ArgumentError PeriodicTree(kdtree_good, [0.0], bounds_max) + @test_throws ArgumentError PeriodicTree(kdtree_good, bounds_min, [10.0]) + + # Test invalid box dimensions + @test_throws ArgumentError PeriodicTree(kdtree_good, [0.0, 0.0], [-1.0, 10.0]) + @test_throws ArgumentError PeriodicTree(kdtree_good, [5.0, 0.0], [3.0, 10.0]) +end