diff --git a/Project.toml b/Project.toml index d0071f5..fa6f2a9 100644 --- a/Project.toml +++ b/Project.toml @@ -12,10 +12,12 @@ StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0" julia = "1.6" [extras] +GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["LinearAlgebra", "Mmap", "Tensors", "Test"] +test = ["GeometryBasics", "LinearAlgebra", "Mmap", "StableRNGs", "Tensors", "Test"] diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 28d998d..fe8d132 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -56,10 +56,13 @@ include("inrange.jl") include("hyperspheres.jl") include("hyperrectangles.jl") include("utilities.jl") +include("tree_ops.jl") +export root, treeindex, eachtreeindex, leafpoints, leaf_points_indices, region, isleaf, isroot, skip_regions, children, parent, nextsibling, prevsibling, points + include("brute_tree.jl") include("kd_tree.jl") include("ball_tree.jl") -include("tree_ops.jl") + for dim in (2, 3) tree = KDTree(rand(dim, 10)) diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 5b10e17..9de57d7 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -140,6 +140,24 @@ function _knn(tree::BallTree, return end +@inline function region(T::BallTree) + if length(T.hyper_spheres) == 0 + return _infinite_hypersphere(eltype(T.hyper_spheres)) + else + return T.hyper_spheres[1] + end +end +@inline function _split_regions(tree::BallTree, ::HyperSphere, index::Int) + # tree = tr[] + r1 = tree.hyper_spheres[getleft(index)] + r2 = tree.hyper_spheres[getright(index)] + return r1, r2 +end +@inline function _parent_region(tree::BallTree, ::HyperSphere, index::Int) + # tree = tr[] + parent = getparent(index) + return tree.hyper_spheres[parent] +end function knn_kernel!(tree::BallTree{V}, index::Int, @@ -179,20 +197,17 @@ function _inrange(tree::BallTree{V}, radius::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder + return inrange_kernel!(tree, root(tree), point, ball, idx_in_ball) # Call the recursive range finders end -function inrange_kernel!(tree::BallTree, - index::Int, +function inrange_kernel!(tree::BallTree, + node::NNTreeNode, point::AbstractVector, query_ball::HyperSphere, idx_in_ball::Union{Nothing, Vector{<:Integer}}) - if index > length(tree.hyper_spheres) - return 0 - end - - sphere = tree.hyper_spheres[index] + sphere = region(node) + # tree = NearestNeighbors.tree(node) # give fully specified function name to avoid # If the query ball in the bounding sphere for the current sub tree # do not intersect we can disrecard the whole subtree @@ -201,8 +216,8 @@ function inrange_kernel!(tree::BallTree, end # At a leaf node, check all points in the leaf node - if isleaf(tree.tree_data.n_internal_nodes, index) - return add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true) + if isleaf(tree, node) + return add_points_inrange!(idx_in_ball, tree, treeindex(node), point, query_ball.r, true) end count = 0 @@ -210,11 +225,12 @@ function inrange_kernel!(tree::BallTree, # The query ball encloses the sub tree bounding sphere. Add all points in the # sub tree without checking the distance function. if encloses(tree.metric, sphere, query_ball) - count += addall(tree, index, idx_in_ball) + count += addall(tree, node, idx_in_ball) else # Recursively call the left and right sub tree. - count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball) - count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball) + left, right = children(tree, node) + count += inrange_kernel!(tree, left, point, query_ball, idx_in_ball) + count += inrange_kernel!(tree, right, point, query_ball, idx_in_ball) end return count end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index ed5ce9a..439feeb 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -58,6 +58,13 @@ function knn_kernel!(tree::BruteTree{V}, end end +# Custom implementation for BruteTree +isleaf(_::BruteTree, _::NNTreeNode) = true +leafpoints(tree::BruteTree, _::NNTreeNode) = tree.data +leaf_points_indices(tree::BruteTree, _::NNTreeNode) = eachindex(tree.data) +eachtreeindex(_::BruteTree) = 1:0 # empty list... +region(tree::BruteTree) = compute_bbox(tree.data) + function _inrange(tree::BruteTree, point::AbstractVector, radius::Number, diff --git a/src/hyperspheres.jl b/src/hyperspheres.jl index 7d73300..807d4be 100644 --- a/src/hyperspheres.jl +++ b/src/hyperspheres.jl @@ -8,6 +8,13 @@ end HyperSphere(center::SVector{N,T1}, r) where {N, T1} = HyperSphere(center, convert(T1, r)) HyperSphere(center::AbstractVector, r) = HyperSphere(SVector{length(center)}(center), r) +function _infinite_hypersphere(::Type{HyperSphere{N,T}}) where {N, T} + return HyperSphere{N,T}( + ntuple(i->zero(T), Val(N)), + convert(T, Inf) + ) +end + @inline function intersects(m::Metric, s1::HyperSphere{N}, s2::HyperSphere{N}) where {N} diff --git a/src/kd_tree.jl b/src/kd_tree.jl index a6326ff..d8e430e 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -5,6 +5,7 @@ struct KDTree{V <: AbstractVector, M <: MinkowskiMetric, T, TH} <: NNTree{V,M} metric::M split_vals::Vector{T} split_dims::Vector{UInt16} + split_minmax::Vector{Tuple{T,T}} tree_data::TreeData reordered::Bool end @@ -30,6 +31,7 @@ function KDTree(data::AbstractVector{V}, indices = collect(1:n_p) split_vals = Vector{eltype(V)}(undef, tree_data.n_internal_nodes) split_dims = Vector{UInt16}(undef, tree_data.n_internal_nodes) + split_minmax = Vector{Tuple{eltype(V),eltype(V)}}(undef, tree_data.n_internal_nodes) if reorder indices_reordered = Vector{Int}(undef, n_p) @@ -56,7 +58,7 @@ function KDTree(data::AbstractVector{V}, hyper_rec = compute_bbox(data) # Call the recursive KDTree builder - build_KDTree(1, data, data_reordered, hyper_rec, split_vals, split_dims, indices, indices_reordered, + build_KDTree(1, data, data_reordered, hyper_rec, split_vals, split_dims, split_minmax, indices, indices_reordered, 1:length(data), tree_data, reorder) if reorder data = data_reordered @@ -71,7 +73,7 @@ function KDTree(data::AbstractVector{V}, end end - KDTree(storedata ? data : similar(data, 0), hyper_rec, indices, metric, split_vals, split_dims, tree_data, reorder) + KDTree(storedata ? data : similar(data, 0), hyper_rec, indices, metric, split_vals, split_dims, split_minmax, tree_data, reorder) end function KDTree(data::AbstractVecOrMat{T}, @@ -97,6 +99,7 @@ function build_KDTree(index::Int, hyper_rec::HyperRectangle, split_vals::Vector{T}, split_dims::Vector{UInt16}, + split_minmax::Vector{Tuple{T,T}}, indices::Vector{Int}, indices_reordered::Vector{Int}, range, @@ -129,18 +132,21 @@ function build_KDTree(index::Int, split_vals[index] = split_val split_dims[index] = split_dim + split_minmax[index] = (hyper_rec.mins[split_dim], hyper_rec.maxes[split_dim]) # Call the left sub tree with an updated hyper rectangle new_maxes = @inbounds setindex(hyper_rec.maxes, split_val, split_dim) hyper_rec_left = HyperRectangle(hyper_rec.mins, new_maxes) build_KDTree(getleft(index), data, data_reordered, hyper_rec_left, split_vals, split_dims, - indices, indices_reordered, first(range):mid_idx - 1, tree_data, reorder) + split_minmax, indices, indices_reordered, + first(range):mid_idx - 1, tree_data, reorder) # Call the right sub tree with an updated hyper rectangle new_mins = @inbounds setindex(hyper_rec.mins, split_val, split_dim) hyper_rec_right = HyperRectangle(new_mins, hyper_rec.maxes) build_KDTree(getright(index), data, data_reordered, hyper_rec_right, split_vals, split_dims, - indices, indices_reordered, mid_idx:last(range), tree_data, reorder) + split_minmax, indices, indices_reordered, mid_idx:last(range), + tree_data, reorder) end @@ -204,17 +210,48 @@ function knn_kernel!(tree::KDTree{V}, return end +@inline function region(T::KDTree) + return T.hyper_rec +end + +@inline function _split_regions(T::KDTree, R::HyperRectangle, index::Int) + # T = tr[] + split_val = T.split_vals[index] + split_dim = T.split_dims[index] + + r1 = HyperRectangle(R.mins, @inbounds setindex(R.maxes, split_val, split_dim)) + r2 = HyperRectangle(@inbounds(setindex(R.mins, split_val, split_dim)), R.maxes) + return r1, r2 +end + +@inline function _parent_region(T::KDTree, R::HyperRectangle, index::Int) + # T = tr[] + parent = getparent(index) + split_dim = T.split_dims[parent] + dimmin,dimmax = T.split_minmax[parent] + if getleft(parent) == index + r = HyperRectangle( + R.mins, @inbounds setindex(R.maxes, dimmax, split_dim) + ) + else + r = HyperRectangle( + @inbounds(setindex(R.mins, dimmin, split_dim)), R.maxes + ) + end + return r +end + function _inrange(tree::KDTree, point::AbstractVector, radius::Number, idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[]) init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point) return inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(init_min)), idx_in_ball, - tree.hyper_rec, init_min) + tree.hyper_rec, init_min) end # Explicitly check the distance between leaf node and point while traversing -function inrange_kernel!(tree::KDTree, +function inrange_kernel!(tree::KDTree, index::Int, point::AbstractVector, r::Number, @@ -270,3 +307,57 @@ function inrange_kernel!(tree::KDTree, count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min) return count end + + +# Explicitly check the distance between leaf node and point while traversing +function inrange_kernel!(node::NNTreeNode, + point::AbstractVector, + r::Number, + idx_in_ball::Union{Nothing, Vector{<:Integer}}, + min_dist) + # Point is outside hyper rectangle, skip the whole sub tree + if min_dist > r + return 0 + end + + # At a leaf node. Go through all points in node and add those in range + if isleaf(tree, node) + return add_points_inrange!(idx_in_ball, tree, node.index, point, r, false) + end + + left, right = children(tree, node) + M = tree.metric + index = treeindex(node) + + split_val = tree.split_vals[index] + split_dim = tree.split_dims[index] + p_dim = point[split_dim] + split_diff = p_dim - split_val + + count = 0 + + if split_diff > 0 # Point is to the right of the split value + close = right + far = left + ddiff = max(zero(p_dim - hi), p_dim - hi) + else # Point is to the left of the split value + close = left + far = right + ddiff = max(zero(lo - p_dim), lo - p_dim) + end + # Call closer sub tree + count += inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist) + + # TODO: We could potentially also keep track of the max distance + # between the point and the hyper rectangle and add the whole sub tree + # in case of the max distance being <= r similarly to the BallTree inrange method. + # It would be interesting to benchmark this on some different data sets. + + # Call further sub tree with the new min distance + split_diff_pow = eval_pow(M, split_diff) + ddiff_pow = eval_pow(M, ddiff) + diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) + new_min = eval_reduce(M, min_dist, diff_tot) + count += inrange_kernel!(tree, far, point, r, idx_in_ball, new_min) + return count +end diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 51203cd..93e4abd 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -12,12 +12,323 @@ function show(io::IO, tree::NNTree{V}) where {V} print(io, " Reordered: ", tree.reordered) end +struct NNTreeNode{R} + index::Int + region::R +end + +# Show the info associated with the node. +function show(io::IO, node::NNTreeNode) + #println(io, typeof(tree(node))) + print(io, "NNTreeNode: treeindex:", (treeindex(node)), + " region: ", region(node)) + #println(io, " region: ", region(node)) +end + +""" + treeindex(node) + +This returns the index of the given node. The indices of nodes are an +implementation specific feature but are externally useful to +associate metadata with nodes within the search tree. +The range of indices is given by `eachindex(node)`. +Nodes can be outside the range if they are leaf nodes. + +## Example +```julia +function walktree(tree, node) + println("Node index: ", treeindex(node), " and isleaf:", isleaf(tree, node) ) + if !isleaf(tree, node) + walktree.(tree, children(node)) + end +end +using StableRNGs, GeometryBasics, NearestNeighbors +T = KDTree(rand(StableRNG(1), Point2f, 25)) +println("eachtreeindex: ", eachtreeindex(root(T))) +walktree(tree, root(T)) +``` + +## See Also +[`eachindex`](@ref) +""" +@inline treeindex(node::NNTreeNode) = node.index + +""" + eachtreeindex(node) + +Get the full range of indices associated with the nodes of the search +tree, this only depends on the tree the node is associated with, so all +nodes of that tree will return the same thing. The index range only +corresponds to the internal nodes of the tree. + +## See Also +[`treeindex`](@ref) +""" +@inline eachtreeindex(tree::NNTree) = 1:tree.tree_data.n_internal_nodes + +""" + isleaf(tree, node) + +Return true if the node is a leaf node of a tree. +""" +@inline isleaf(tree::NNTree, node::NNTreeNode) = isleaf(tree.tree_data.n_internal_nodes, treeindex(node)) + +""" + isroot(tree, node) + +Return true if the node is a root node of the tree. +""" +@inline function isroot(_, node) + return node.index == 1 +end + +""" + region(node) + +Return the region of space associated with a node in the tree. +""" +@inline region(node::NNTreeNode) = node.region + +""" + children(tree, node) + +Return the children of a given node in the tree. +This throws an BoundsError if the node is a leaf. +""" +@inline function children(tree::NNTree, node::NNTreeNode) + if isleaf(tree, node) + throw(ArgumentError("Cannot call children on leaf nodes")) + end + i = treeindex(node) + r1, r2 = _split_regions(tree, region(node), i) + i1, i2 = getleft(i), getright(i) + return ( + NNTreeNode(i1, r1), + NNTreeNode(i2, r2) + ) +end + +@inline function parent(tree::NNTree, node::NNTreeNode) + i = treeindex(node) + p = getparent(i) + if p == 0 + throw(ArgumentError("Cannot call parent on the root node")) + end + r = _parent_region(tree, region(node), i) + return ( + NNTreeNode(p, r) + ) +end + +@inline function nextsibling(tree, node::NNTreeNode) + if isroot(tree, node) + return nothing + else + p = parent(tree, node) + l, r = children(tree, p) + if node == l + return r + else + return nothing + end + end +end + +@inline function prevsibling(tree, node::NNTreeNode) + if isroot(tree, node) + return nothing + else + p = parent(tree, node) + l, r = children(tree, p) + if node == r + return l + else + return nothing + end + end +end + +""" +This function enables one to disable region computation by providing +a nothing type for that. it'll just omit the region computation +entirely. + +You can enable it with +skip_regions(root) +""" +_split_regions(_::NNTree, r::Nothing, _) = nothing, nothing +_parent_region(_::NNTree, r::Nothing, _) = nothing + +""" + skip_regions(node) + +Sometimes all you need to navigate the nearest neighbor tree is +the tree structure itself and not the regions associated with each +node. In some cases, computing the regions can be expensive. So +this call sets regions to `nothing` which propagates throughout +the tree and simply elides the region computations. + +## Example +```julia +using BenchmarkTools, StableRNGs, GeometryBasics +function count_points(tree, node) + count = 0 + if NearestNeighbors.isleaf(tree, node) + count += length(NearestNeighbors.points_indices(tree, node)) + else + left, right = NearestNeighbors.children(tree, node) + count += count_points(tree, left) + count += count_points(tree, right) + end + return count + end +end +pts = rand(StableRNG(1), Point2f, 1_000_000) +T = KDTree(pts) +@btime count_points(T, root(T)) +@btime count_points(T, skip_regions(root(T)) +``` +""" +@inline skip_regions(node::NNTreeNode) = NNTreeNode(treeindex(node), nothing) + + +""" + root(T::NNTree) + +Return the root node of the nearest neighbor search tree. +""" +function root(T::NNTree) + return NNTreeNode(1, region(T)) +end + +function _points(tree_data, data, index, indices, reordered) + if reordered + return (data[idx] for idx in get_leaf_range(tree_data, index)) + else + return (data[indices[idx]] for idx in get_leaf_range(tree_data, index)) + end +end + +""" + points(tree, node) + +Create an iterator for all the points contained within the + node of the nearest neighbor tree. +""" +function points(T::NNTree, N::NNTreeNode) + return PointsIterator(T, skip_regions(N)) +end + +struct PointsIterator{Tree <: NNTree, N <: NNTreeNode} + tree::Tree + start::N +end + +function _find_first_leaf(tree, node) + while true + if isleaf(tree, node) + return node + else + node = children(tree, node)[1] # get the left child + end + end +end + +function _find_next_leaf(tree, node, stopindex) + if node.index == stopindex + return nothing + end + next = nextsibling(tree, node) + while next === nothing + if isroot(tree, node) + return nothing # this is the end condition... + end + node = parent(tree, node) + if node.index == stopindex + return nothing + end + + next = nextsibling(tree, node) + end + # now we need to find the leaf from this node... + leaf = _find_first_leaf(tree, next) + return leaf +end + +function _next_leaf_and_leaf_iterate(tree, node, stopindex) + while true + node = _find_next_leaf(tree, node, stopindex) + if node === nothing + return nothing + end + leafrange = get_leaf_range(tree.tree_data, treeindex(node)) + next = iterate(leafrange) + if next !== nothing + return next, leafrange, node + end + end +end + +function _iterate(leafrangenext, leafrange, node, it) + if leafrangenext === nothing + next = _next_leaf_and_leaf_iterate(it.tree, node, treeindex(it.start)) + + if next === nothing + # if next is still nothing, then we are an empty tree + return nothing + end + + leafrangenext, leafrange, node = next + end + + idx, leafrangestate = leafrangenext + reordered = it.tree.reordered + if reordered + pt = it.tree.data[idx] + else + pt = it.tree.data[it.tree.indices[idx]] + end + + return pt, (leafrangestate, leafrange, node) +end + +import Base.iterate, Base.IteratorSize, Base.eltype + +IteratorSize(_::PointsIterator) = Base.SizeUnknown() +eltype(it::PointsIterator) = eltype(it.tree.data) + +function iterate(it::PointsIterator) + node = _find_first_leaf(it.tree, it.start) + leafrange = get_leaf_range(it.tree.tree_data, treeindex(node)) + leafrangenext = iterate(leafrange) + + return _iterate(leafrangenext, leafrange, node, it) +end +function iterate(it::PointsIterator, state) + leafrangestate, leafrange, node = state + leafrangenext = iterate(leafrange, leafrangestate) + return _iterate(leafrangenext, leafrange, node, it) +end + + +function leafpoints(T::NNTree, node::NNTreeNode) + # T = tree(node) + # redirect to possibly specialize + return _points(T.tree_data, T.data, treeindex(node), T.indices, T.reordered) +end + +function leaf_points_indices(T::NNTree, node::NNTreeNode) + # T = tree(node) + tree_data = T.tree_data + indices = T.indices + return (indices[idx] for idx in get_leaf_range(tree_data, treeindex(node))) +end + # We split the tree such that one of the sub trees has exactly 2^p points # and such that the left sub tree always has more points. # This means that we can deterministally (with just some comparisons) # find if we are at a leaf node and how many function find_split(low, leafsize, n_p) - # The number of leafs node left in the tree, # use `ceil` to count a partially filled node as 1. n_leafs = ceil(Int, n_p / leafsize) @@ -145,3 +456,24 @@ function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:I end return count end + +# Add all points in this subtree since we have determined +# they are all within the desired range +function addall(tree::NNTree, node::NNTreeNode, idx_in_ball::Union{Nothing, Vector{<:Integer}}) + # tree = NearestNeighbors.tree(node) + tree_data = tree.tree_data + count = 0 + index = node.index + if isleaf(tree, node) + for z in get_leaf_range(tree_data, index) + idx = tree.reordered ? z : tree.indices[z] + count += 1 + idx_in_ball !== nothing && push!(idx_in_ball, idx) + end + else + left, right = children(tree, node) + count += addall(tree, left, idx_in_ball) + count += addall(tree, right, idx_in_ball) + end + return count +end diff --git a/test/runtests.jl b/test/runtests.jl index 584a271..0e27079 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,7 @@ const fullmetrics = [metrics; Hamming(); CustomMetric1(); CustomMetric2()] const trees = [KDTree, BallTree] const trees_with_brute = [BruteTree; trees] +include("test_walk.jl") include("test_knn.jl") include("test_inrange.jl") include("test_monkey.jl") diff --git a/test/test_walk.jl b/test/test_walk.jl new file mode 100644 index 0000000..16a75c7 --- /dev/null +++ b/test/test_walk.jl @@ -0,0 +1,103 @@ +using StableRNGs, GeometryBasics + +@testset "tree type" for TreeType in trees_with_brute + @testset "type" for T in (Float32, Float64) + @testset "tree walking" begin + allpts = rand(StableRNG(1), Point2{T}, 1000) + tree = TreeType(allpts) + + function _find_leaf(node) + if isleaf(tree, node) + return node + else + left, _ = children(tree, node) + return _find_leaf(left) + end + end + + leafnode = _find_leaf(root(tree)) + # test that children throws an error on a leaf node + @test_throws ArgumentError children(tree, leafnode) + + function _find_in_node!(node, indices, pts) + if isleaf(tree, node) + for (point, index) in zip(leafpoints(tree, node), leaf_points_indices(tree, node)) + @test point == allpts[index] + end + + for point in leafpoints(tree, node) + push!(pts, point) + end + for index in leaf_points_indices(tree, node) + push!(indices, index) + end + else + left, right = children(tree, node) + _find_in_node!( left, indices, pts) + _find_in_node!(right, indices, pts) + end + end + function find_all_points(root, allpts) + # walk to find all points + indices = Set{Int64}() + pts = Set{Point2{T}}() + _find_in_node!(root, indices, pts) + @test indices == Set(eachindex(allpts)) + @test pts == Set(allpts) + end + + find_all_points(root(tree), allpts) + find_all_points(skip_regions(root(tree)), allpts) + end + @testset "region containment" begin + allpts = rand(StableRNG(2), Point2{T}, 1000) + tree = TreeType(allpts) + + function _contains(pt, region::NearestNeighbors.HyperRectangle) + for i in eachindex(pt) + if pt[i] < region.mins[i] || pt[i] > region.maxes[i] + return false + end + end + return true + end + function _contains(pt, region::NearestNeighbors.HyperSphere) + center = region.center + radius = region.r + dist = norm(pt - center) + return dist <= radius + end + + function _contains(subregion::NearestNeighbors.HyperSphere, region::NearestNeighbors.HyperSphere) + # check that the center of the subregion is in the region + return _contains(subregion.center, region) + end + + function _contains(subregion::NearestNeighbors.HyperRectangle, region::NearestNeighbors.HyperRectangle) + return _contains(subregion.mins, region) && + _contains(subregion.maxes, region) + end + + function check_containment(node) + r = region(node) + if isleaf(tree, node) + for point in leafpoints(tree, node) + @test _contains(point, r) + end + else + # double check all the points are within the region + for point in points(tree, node) + @test _contains(point, r) + end + left, right = children(tree, node) + # check + @test _contains(region(left), r) + @test _contains(region(right), r) + end + end + + check_containment(root(tree)) + end + end +end +