Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0"
julia = "1.6"

[extras]
GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "Mmap", "Tensors", "Test"]
test = ["GeometryBasics", "LinearAlgebra", "Mmap", "StableRNGs", "Tensors", "Test"]
7 changes: 6 additions & 1 deletion src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distmat, npairs
export injectdata

import Base.eachindex

export Euclidean,
Cityblock,
Minkowski,
Expand Down Expand Up @@ -56,10 +58,13 @@ include("inrange.jl")
include("hyperspheres.jl")
include("hyperrectangles.jl")
include("utilities.jl")
include("tree_ops.jl")
export root, treeindex, leafpoints, leaf_points_indices, region, isleaf, skip_regions, children

include("brute_tree.jl")
include("kd_tree.jl")
include("ball_tree.jl")
include("tree_ops.jl")


for dim in (2, 3)
tree = KDTree(rand(dim, 10))
Expand Down
35 changes: 22 additions & 13 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ function _knn(tree::BallTree,
return
end

@inline function region(T::BallTree)
if length(T.hyper_spheres) == 0
return _infinite_hypersphere(eltype(T.hyper_spheres))
else
return T.hyper_spheres[1]
end
end
@inline function _split_regions(tree::BallTree, ::HyperSphere, index::Int)
r1 = tree.hyper_spheres[getleft(index)]
r2 = tree.hyper_spheres[getright(index)]
return r1, r2
end

function knn_kernel!(tree::BallTree{V},
index::Int,
Expand Down Expand Up @@ -179,20 +191,16 @@ function _inrange(tree::BallTree{V},
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
return inrange_kernel!(root(tree), point, ball, idx_in_ball) # Call the recursive range finders
end

function inrange_kernel!(tree::BallTree,
index::Int,
function inrange_kernel!(node::NNTreeNode,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Union{Nothing, Vector{<:Integer}})

if index > length(tree.hyper_spheres)
return 0
end

sphere = tree.hyper_spheres[index]
sphere = region(node)
tree = node.tree

# If the query ball in the bounding sphere for the current sub tree
# do not intersect we can disrecard the whole subtree
Expand All @@ -201,20 +209,21 @@ function inrange_kernel!(tree::BallTree,
end

# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
return add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true)
if isleaf(node)
return add_points_inrange!(idx_in_ball, tree, treeindex(node), point, query_ball.r, true)
end

count = 0

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses(tree.metric, sphere, query_ball)
count += addall(tree, index, idx_in_ball)
count += addall(node, idx_in_ball)
else
# Recursively call the left and right sub tree.
count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
left, right = children(node)
count += inrange_kernel!( left, point, query_ball, idx_in_ball)
count += inrange_kernel!(right, point, query_ball, idx_in_ball)
end
return count
end
7 changes: 7 additions & 0 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ function knn_kernel!(tree::BruteTree{V},
end
end

# Custom implementation for BruteTree
isleaf(node::NNTreeNode{T,R}) where {T <: BruteTree, R} = true
leafpoints(node::NNTreeNode{T,R}) where {T <: BruteTree, R} = tree(node).data
leaf_points_indices(node::NNTreeNode{T,R}) where {T <: BruteTree, R} = eachindex(tree(node).data)
eachindex(node::NNTreeNode{T,R}) where {T <: BruteTree, R} = 1:0 # empty list...
region(tree::BruteTree) = compute_bbox(tree.data)

function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number,
Expand Down
7 changes: 7 additions & 0 deletions src/hyperspheres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ end
HyperSphere(center::SVector{N,T1}, r) where {N, T1} = HyperSphere(center, convert(T1, r))
HyperSphere(center::AbstractVector, r) = HyperSphere(SVector{length(center)}(center), r)

function _infinite_hypersphere(::Type{HyperSphere{N,T}}) where {N, T}
return HyperSphere{N,T}(
ntuple(i->zero(T), Val(N)),
convert(T, Inf)
)
end

@inline function intersects(m::Metric,
s1::HyperSphere{N},
s2::HyperSphere{N}) where {N}
Expand Down
13 changes: 13 additions & 0 deletions src/kd_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,19 @@ function knn_kernel!(tree::KDTree{V},
return
end

@inline function region(T::KDTree)
return T.hyper_rec
end

@inline function _split_regions(T::KDTree, R::HyperRectangle, index::Int)
split_val = T.split_vals[index]
split_dim = T.split_dims[index]

r1 = HyperRectangle(R.mins, @inbounds setindex(R.maxes, split_val, split_dim))
r2 = HyperRectangle(@inbounds(setindex(R.mins, split_val, split_dim)), R.maxes)
return r1, r2
end

function _inrange(tree::KDTree,
point::AbstractVector,
radius::Number,
Expand Down
190 changes: 190 additions & 0 deletions src/tree_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,175 @@ function show(io::IO, tree::NNTree{V}) where {V}
print(io, " Reordered: ", tree.reordered)
end

struct NNTreeNode{T <: NNTree, R}
index::Int
tree::T
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a general "philosophy" that storing something big (a full KD Tree) in something that is conceptually small (a tree node) is often a mistake.

As you traverse the tree you will create all these nodes that will all contain the same tree. What do you think about dropping the tree field and instead require a user to provide the tree a an argument to the traverse functions?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question and good rationale. My own experience has been that Julia is very good at optimizing the codes when the types are immutable, so I doubt it is really creating different copies if you use it in a function.

My argument for the current organization is that node ids are tied to the tree and so this makes it so that you don't have an additional argument hanging out everywhere..., it makes it easy and simple to write codes that do the right thing and get the answer right. But as I said, I hadn't considered your particular perspective here.

Is there a test we could do to resolve if this is an issue? (i.e to convince me that your perspective is correct, or for me to convince you it isn't a problem to store the tree and the compiler really is smart enough?)

Maybe, vectors of nodes would be bad for including the tree? But we do we ever actually need them?

Another argument for keeping it linked is that the AbstractTrees interface is 'node' oriented, so you define children, parent, etc. on a node level; which would require keeping the tree as part of the struct.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay -- you are right -- this does make a big difference. I took a trivial walk the list and count up the sizes of the leaves code that is just going to benchmark the traversal... (total number of points 100k) By storing the NNTree variable it takes ~131 μs. If I just do it by raw calls with node ids and passing the tree as a parameter to the function, it takes ~29 μs. But... if I store a ref to the tree rather than the full tree structure, then I get all the functionality and it takes ~45 μs. I think the latter is worth doing. So I'll implement that and update the pull request. Not that all of this skips the region computations for the KDTree, so that will shorten the difference.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But... if I store a ref to the tree rather than the full tree structure

I don't fully understand what that means.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the updated structure. This stores a pointer to the tree information instead of a copy of all the information.

struct NNTreeNode{TreeRef <: Ref{ <: NNTree }, R}
     index::Int
     treeref::TreeRef
     region::R
 end 

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue I have with the iterate analogy is that iterate is designed to execute within a single function context -- and has some nice syntax to hide the complexity and different types of objects -- whereas most of the tree walking functions are designed to execute recursively, where there is no such affordance that I know of. So you'll have to pass the tree structure to any subfunction -- as well as the node structure.

The current design is just designed to be easy to use; it's also feasible to adapt to the AbstractTrees.jl interface (although I haven't done that yet...) where they do the same thing with parent/children/etc. functions.

But it seems like you are still leaning against it enough though there is minimal overhead, is that correct?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be precise, the interface you would like is:

children(T::Tree, n::Node) -> (nl::Node, nr::Node)
parent(T::Tree, n::Node) -> (p::Node)
region(n::Node) -> 
leaf_points(T::Tree, n::Node) -> something that iterates over points in the leaf node
etc...

where node is something simple like:

struct Node{R} 
  index::Int
  region::R
end 

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a quick nudge on this question of interface. Would love to get this wrapped up in the next week or so before some obligations for school starts.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, since I had a moment, I just implemented the interface above. As a check, we can do non-recursive exploration of the tree using the current children, parent, next/prev sibling structure, see, the e.g. points iterator...

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, sorry for the slow response here and sorry for being a bit "annoying" with trying to figure out the "best" interface to use.

A reason for this is that this is my first Julia package so it holds a bit of a special place in my heart and I have also worked quite a bit to reduce memory footprint and improve performance.

I can add your package so it is tested as part of the CI here (and you could then at any time also implement whatever tree walking interface you want there and it will not be broken, or at least it can be updated if changes are made here that would be incompatible with it).

region::R
end

# Show the info associated with the node.
function show(io::IO, node::NNTreeNode)
println(io, typeof(tree(node)))
println(io, " Region: ", region(node))
end



"""
tree(node)

Return the nearest neighbor search tree associated with the given node.
"""
@inline tree(node::NNTreeNode) = node.tree

"""
treeindex(node)

This returns the index of the given node. The indices of nodes are an
implementation specific feature but are externally useful to
associate metadata with nodes within the search tree.
The range of indices is given by `eachindex(node)`.
Nodes can be outside the range if they are leaf nodes.

## Example
```julia
function walktree(node)
println("Node index: ", treeindex(node), " and isleaf:", isleaf(ndoe) )
if !isleaf(node)
walktree.(children(node))
end
end
using StableRNGs, GeometryBasics, NearestNeighbors
T = KDTree(rand(StableRNG(1), Point2f, 25))
println("eachindex: ", eachindex(root(T)))
walktree(root(T))
```

## See Also
[`eachindex`](@ref)
"""
@inline treeindex(node::NNTreeNode) = node.index

"""
eachindex(node)

Get thee full range of indices associated with the nodes of the search
tree, this only depends on the tree the node is associated with, so all
nodes of that tree will return the same thing. The index range only
corresponds to the internal nodes of the tree.

## See Also
[`treeindex`](@ref)
"""
@inline eachindex(node::NNTreeNode) = 1:tree(node).tree_data.n_internal_nodes

"""
isleaf(node)

Return true if the node is a leaf node of a tree.
"""
@inline isleaf(node::NNTreeNode) = isleaf(tree(node).tree_data.n_internal_nodes, treeindex(node))

"""
region(node)

Return the region of space associated with a node in the tree.
"""
@inline region(node::NNTreeNode) = node.region

"""
children(node)

Return the children of a given node in the tree.
This throws an BoundsError if the node is a leaf.
"""
@inline function children(node::NNTreeNode)
if isleaf(node)
throw(BoundsError("Cannot call children on leaf nodes"))
end
T = tree(node)
i = treeindex(node)
r1, r2 = _split_regions(T, region(node), i)
i1, i2 = getleft(i), getright(i)
return (
NNTreeNode(i1, T, r1),
NNTreeNode(i2, T, r2)
)
end

"""
This function enables one to disable region computation by providing
a nothing type for that. it'll just omit the region computation
entirely.

You can enable it with
skip_regions(root)
"""
_split_regions(T::NNTree, r::Nothing, _) = nothing, nothing

"""
skip_regions(node)

Sometimes all you need to navigate the nearest neighbor tree is
the tree structure itself and not the regions associated with each
node. In some cases, computing the regions can be expensive. So
this call sets regions to `nothing` which propagates throughout
the tree and simply elides the region computations.

## Example
```julia
using BenchmarkTools, StableRNGs, GeometryBasics
function count_points(node)
count = 0
if NearestNeighbors.isleaf(node)
count += length(NearestNeighbors.points_indices(node))
else
left, right = NearestNeighbors.children(node)
count += count_points(left)
count += count_points(right)
end
return count
end
end
pts = rand(StableRNG(1), Point2f, 1_000_000)
T = KDTree(pts)
@btime count_points(root(T))
@btime count_points(skip_regions(root(T))
```
"""
@inline skip_regions(node::NNTreeNode) = NNTreeNode(treeindex(node), tree(node), nothing)


"""
root(T::NNTree)

Return the root node of the nearest neighbor search tree.
"""
function root(T::NNTree)
return NNTreeNode(1, T, region(T))
end

function _points(tree_data, data, index, indices, reordered)
if reordered
return (data[idx] for idx in get_leaf_range(tree_data, index))
else
return (data[indices[idx]] for idx in get_leaf_range(tree_data, index))
end
end

function leafpoints(node::NNTreeNode)
# redirect to possibly specialize
T = tree(node)
return _points(T.tree_data, T.data, treeindex(node), T.indices, T.reordered)
end

function leaf_points_indices(node::NNTreeNode)
T = tree(node)
tree_data = T.tree_data
indices = T.indices
return (indices[idx] for idx in get_leaf_range(tree_data, treeindex(node)))
end

# We split the tree such that one of the sub trees has exactly 2^p points
# and such that the left sub tree always has more points.
# This means that we can deterministally (with just some comparisons)
Expand Down Expand Up @@ -145,3 +314,24 @@ function addall(tree::NNTree, index::Int, idx_in_ball::Union{Nothing, Vector{<:I
end
return count
end

# Add all points in this subtree since we have determined
# they are all within the desired range
function addall(node::NNTreeNode, idx_in_ball::Union{Nothing, Vector{<:Integer}})
tree = node.tree
tree_data = tree.tree_data
count = 0
index = node.index
if isleaf(node)
for z in get_leaf_range(tree_data, index)
idx = tree.reordered ? z : tree.indices[z]
count += 1
idx_in_ball !== nothing && push!(idx_in_ball, idx)
end
else
left, right = children(node)
count += addall(left, idx_in_ball)
count += addall(right, idx_in_ball)
end
return count
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const fullmetrics = [metrics; Hamming(); CustomMetric1(); CustomMetric2()]
const trees = [KDTree, BallTree]
const trees_with_brute = [BruteTree; trees]

include("test_walk.jl")
include("test_knn.jl")
include("test_inrange.jl")
include("test_monkey.jl")
Expand Down
Loading