Skip to content

Commit d8f9bd2

Browse files
committed
kd-tree implementation in Rust
1 parent be27f2c commit d8f9bd2

File tree

1 file changed

+386
-0
lines changed

1 file changed

+386
-0
lines changed

src/data_structures/kd_tree.rs

Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
use std::iter::Sum;
2+
use num_traits::{abs, real::Real, Signed};
3+
4+
/// A single node of a kd-tree: one `K`-dimensional point together with its
/// two optional child subtrees.
#[derive(Debug, Clone, PartialEq)]
struct KDNode<T, const K: usize>
where
    T: PartialOrd + Copy,
{
    /// Coordinates stored at this node.
    point: [T; K],
    /// Left child subtree, if any.
    left: Option<Box<KDNode<T, K>>>,
    /// Right child subtree, if any.
    right: Option<Box<KDNode<T, K>>>,
}
10+
11+
impl<T: PartialOrd + Copy, const K: usize> KDNode<T, K> {
12+
fn new(point: [T; K]) -> Self {
13+
KDNode {
14+
point,
15+
left: None,
16+
right: None
17+
}
18+
}
19+
20+
}
21+
22+
#[derive(Debug)]
23+
struct KDTree<T: PartialOrd + Copy, const K: usize> {
24+
root: Option<Box<KDNode<T, K>>>,
25+
size: usize
26+
}
27+
28+
impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
29+
// Create and empty kd-tree
30+
pub fn new() -> Self {
31+
KDTree {
32+
root: None,
33+
size: 0
34+
}
35+
}
36+
37+
// Returns true if point found, false otherwise
38+
pub fn search(&self, point: &[T; K]) -> bool {
39+
search_rec(&self.root, point, 0)
40+
}
41+
42+
// Returns true if successfully delete a point, false otherwise
43+
pub fn insert(&mut self, point: [T; K]) -> bool {
44+
let inserted: bool = insert_rec(&mut self.root, point, 0);
45+
if inserted {
46+
self.size += 1;
47+
}
48+
inserted
49+
}
50+
51+
// Returns true if successfully delete a point
52+
pub fn delete(&mut self, point: &[T; K]) -> bool {
53+
let deleted = delete_rec(&mut self.root, point, 0);
54+
if deleted {
55+
self.size -= 1;
56+
}
57+
deleted
58+
}
59+
60+
// Returns the nearest neighbors of a given point with their respective disatances
61+
pub fn nearest_neighbors(&self, point: &[T; K], n: usize) -> Vec<(T, [T; K])>
62+
where
63+
T: Sum + Signed + Real
64+
{
65+
let mut neighbors: Vec<(T, [T; K])> = Vec::new();
66+
n_nearest_neighbors(&self.root, point, n, 0, &mut neighbors);
67+
neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
68+
neighbors
69+
}
70+
71+
// Returns the number of points in a kd-tree
72+
pub fn len(&self) -> usize {
73+
self.size
74+
}
75+
76+
// Returns the depth a kd-tree
77+
pub fn depth(&self) -> usize {
78+
depth_rec(&self.root, 0, 0)
79+
}
80+
81+
// Determine whether there exist points in a kd-tree or not
82+
pub fn is_empty(&self) -> bool {
83+
self.root.is_none()
84+
}
85+
86+
// Returns a kd-tree built from a vector points
87+
pub fn build(points: Vec<[T; K]>) -> KDTree<T, K> {
88+
let mut tree = KDTree::new();
89+
if points.is_empty() {
90+
tree
91+
} else {
92+
tree.size = points.len();
93+
tree.root = build_rec(points, 0);
94+
95+
tree
96+
}
97+
}
98+
99+
// Merging two KDTrees by collecting points and rebuilding
100+
pub fn merge(&mut self, other: &mut Self) -> Self {
101+
let mut points: Vec<[T; K]> = Vec::new();
102+
collect_points(&self.root, &mut points);
103+
collect_points(&other.root, &mut points);
104+
KDTree::build(points)
105+
}
106+
}
107+
108+
// Helper functions ............................................................................
109+
110+
// Recursively insert a point in a kd-tree
111+
fn insert_rec<T: PartialOrd + Copy, const K: usize>(kd_tree: &mut Option<Box<KDNode<T, K>>>, point: [T; K], depth: usize) -> bool {
112+
if let Some(ref mut kd_node) = kd_tree {
113+
let axis: usize = depth % K;
114+
if point[axis] < kd_node.point[axis] {
115+
insert_rec(&mut kd_node.left, point, depth + 1)
116+
} else if point == kd_node.point {
117+
false
118+
} else {
119+
insert_rec(&mut kd_node.right, point, depth + 1)
120+
}
121+
} else {
122+
*kd_tree = Some(Box::new(KDNode::new(point)));
123+
true
124+
}
125+
}
126+
127+
// Recursively search for a given point in a kd-tree
128+
fn search_rec<T: PartialOrd + Copy, const K: usize>(kd_tree: &Option<Box<KDNode<T, K>>>, point: &[T; K], depth: usize) -> bool {
129+
if let Some(kd_node) = kd_tree {
130+
if point == &kd_node.point {
131+
true
132+
} else {
133+
let axis: usize = depth % K;
134+
if point[axis] < kd_node.point[axis] {
135+
search_rec(&kd_node.left, point, depth + 1)
136+
} else {
137+
search_rec(&kd_node.right, point, depth + 1)
138+
}
139+
}
140+
} else {
141+
false
142+
}
143+
}
144+
145+
// Recursively delete a point from a kd-tree
146+
fn delete_rec<T: PartialOrd + Copy, const K: usize>(kd_node: &mut Option<Box<KDNode<T, K>>>, point: &[T; K], depth: usize) -> bool {
147+
if let Some(current_node) = kd_node {
148+
let axis: usize = depth % K;
149+
if current_node.point == *point {
150+
if current_node.right.is_some() {
151+
// safe to use `unwrap()` since we know for sure there exist a node
152+
let min_point = min_node(current_node.right.as_ref(), axis, 0).unwrap().point;
153+
154+
current_node.point = min_point;
155+
delete_rec(&mut current_node.right, &current_node.point, depth + 1)
156+
} else if current_node.left.is_some() {
157+
let min_point: [T; K] = min_node(current_node.left.as_ref(), axis, 0).unwrap().point;
158+
159+
current_node.point = min_point;
160+
current_node.right = current_node.left.take();
161+
delete_rec(&mut current_node.right, &current_node.point, depth + 1)
162+
}else {
163+
*kd_node = None;
164+
true
165+
}
166+
} else if point[axis].lt(&current_node.point[axis]) {
167+
delete_rec(&mut current_node.left, point, depth + 1)
168+
} else {
169+
delete_rec(&mut current_node.right, point, depth + 1)
170+
}
171+
} else {
172+
false
173+
}
174+
}
175+
176+
/// Recursively build a kd-tree from a vector of points by taking the median of a sorted
177+
/// vector of points as the root to maintain a balance kd-tree
178+
fn build_rec<T: PartialOrd + Copy, const K: usize>(mut points: Vec<[T; K]>, depth: usize) -> Option<Box<KDNode<T, K>>> {
179+
if points.is_empty() {
180+
None
181+
} else {
182+
let axis = depth % K;
183+
points.sort_by(|a, b| a[axis].partial_cmp(&b[axis]).unwrap_or(std::cmp::Ordering::Equal));
184+
185+
let median: usize = points.len() / 2;
186+
let mut node: KDNode<T, K> = KDNode::new(points[median]);
187+
188+
node.left = build_rec(points[..median].to_vec(), depth + 1);
189+
node.right = build_rec(points[median + 1..].to_vec(), depth + 1);
190+
191+
Some(Box::new(node))
192+
}
193+
}
194+
195+
// Returns the depth of the deepest branch of a kd-tree.
196+
fn depth_rec<T: PartialOrd + Copy, const K: usize>(kd_tree: &Option<Box<KDNode<T, K>>>, left_depth: usize, right_depth: usize) -> usize {
197+
if let Some(kd_node) = kd_tree {
198+
match (&kd_node.left, &kd_node.right) {
199+
(None, None) => left_depth.max(right_depth),
200+
(None, Some(_)) => depth_rec(&kd_node.left, left_depth + 1, right_depth),
201+
(Some(_), None) => depth_rec(&kd_node.right, left_depth, right_depth + 1),
202+
(Some(_), Some(_)) => depth_rec(&kd_node.left, left_depth + 1, right_depth).max(depth_rec(&kd_node.right, left_depth, right_depth + 1)),
203+
}
204+
} else {
205+
left_depth.max(right_depth)
206+
}
207+
}
208+
209+
// Collect all points from a given KDTree into a vector
210+
fn collect_points<T: PartialOrd + Copy, const K: usize>(kd_node: &Option<Box<KDNode<T, K>>>, points: &mut Vec<[T; K]>) {
211+
if let Some(current_node) = kd_node {
212+
points.push(current_node.point);
213+
collect_points(&current_node.left, points);
214+
collect_points(&current_node.right, points);
215+
}
216+
}
217+
218+
// Calculate the distance between two points
219+
fn distance<T, const K: usize>(point_1: &[T; K], point_2: &[T; K]) -> T
220+
where
221+
T: PartialOrd + Copy + Sum + Real
222+
{
223+
point_1
224+
.iter()
225+
.zip(point_2.iter())
226+
.map(|(&a, &b)| {
227+
let diff: T = a - b;
228+
diff * diff
229+
})
230+
.sum::<T>()
231+
.sqrt()
232+
}
233+
234+
// Returns the minimum nodes among three kd-nodes on a given axis
235+
fn min_node_on_axis<'a, T: PartialOrd + Copy, const K: usize>(kd_node: &'a KDNode<T, K>, left: Option<&'a KDNode<T, K>>, right: Option<&'a KDNode<T, K>>, axis: usize) -> &'a KDNode<T, K> {
236+
let mut current_min_node = kd_node;
237+
if let Some(left_node) = left {
238+
if left_node.point[axis].lt(&current_min_node.point[axis]) {
239+
current_min_node = left_node;
240+
}
241+
}
242+
if let Some(right_node) = right {
243+
if right_node.point[axis].lt(&current_min_node.point[axis]) {
244+
current_min_node = right_node;
245+
}
246+
}
247+
current_min_node
248+
}
249+
250+
// Returns the minimum node of a kd-tree with respect to an axis
251+
fn min_node<T: PartialOrd + Copy, const K: usize>(kd_node: Option<&Box<KDNode<T, K>>>, axis: usize, depth: usize) -> Option<&KDNode<T, K>> {
252+
if let Some(current_node) = kd_node {
253+
let current_axis = depth % K;
254+
if current_axis == axis {
255+
if current_node.left.is_some() {
256+
min_node(current_node.left.as_ref(), axis, depth + 1)
257+
} else {
258+
Some(current_node)
259+
}
260+
} else {
261+
let (left_min, right_min): (Option<&KDNode<T, K>>, Option<&KDNode<T, K>>) = (
262+
min_node(current_node.left.as_ref(), axis, depth + 1),
263+
min_node(current_node.right.as_ref(), axis, depth + 1)
264+
);
265+
Some(min_node_on_axis(current_node, left_min, right_min, axis))
266+
}
267+
} else {
268+
None
269+
}
270+
}
271+
272+
// Find the nearest neighbors of a given point. The number neighbors is determine by the variable `n`.
273+
fn n_nearest_neighbors<T, const K: usize>(kd_tree: &Option<Box<KDNode<T, K>>>, point: &[T; K], n: usize, depth: usize, neighbors: &mut Vec<(T, [T; K])>)
274+
where
275+
T: PartialOrd + Copy + Sum + Real + Signed
276+
{
277+
if let Some(kd_node) = kd_tree {
278+
let distance: T = distance(&kd_node.point, point);
279+
if neighbors.len() < n {
280+
neighbors.push((distance, kd_node.point));
281+
} else {
282+
// safe to call unwrap() since we know our neighbors is ont empty in this scope
283+
let max_distance = neighbors.iter().max_by(|a, b| a.0.partial_cmp(&b.0).unwrap()).unwrap().0;
284+
if distance < max_distance {
285+
if let Some(pos) = neighbors.iter().position(|x| x.0 == max_distance) {
286+
neighbors[pos] = (distance, kd_node.point);
287+
}
288+
}
289+
}
290+
291+
let axis: usize = depth % K;
292+
let target_axis: T = point[axis];
293+
let split_axis: T = kd_node.point[axis];
294+
295+
let (look_first, look_second) = if target_axis < split_axis {
296+
(&kd_node.left, &kd_node.right)
297+
} else {
298+
(&kd_node.right, &kd_node.left)
299+
};
300+
301+
if look_first.is_some() {
302+
n_nearest_neighbors(&look_first, point, n, depth + 1, neighbors);
303+
}
304+
305+
// Check if it's necessary to look on the other branch by computing the distance between our current point with the nearest point on the other branch
306+
if look_second.is_some() {
307+
let max_distance = neighbors.iter().max_by(|a, b| a.0.partial_cmp(&b.0).unwrap()).unwrap().0;
308+
if neighbors.len() < n || abs(target_axis - split_axis) < max_distance {
309+
n_nearest_neighbors(&look_second, point, n, depth + 1, neighbors);
310+
}
311+
}
312+
}
313+
}
314+
315+
316+
#[cfg(test)]
mod test {
    use super::KDTree;

    /// A point can be inserted once; inserting the same point again is rejected.
    #[test]
    fn insert() {
        let mut kd_tree: KDTree<f64, 2> = KDTree::new();
        assert!(kd_tree.insert([2.0, 3.0]));
        // Cannot insert the same point again
        assert!(!kd_tree.insert([2.0, 3.0]));
        assert!(kd_tree.insert([2.0, 3.1]));
    }

    /// Searching finds stored points and rejects points that were never added.
    #[test]
    fn contains() {
        let kd_tree = KDTree::build(vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0]]);
        assert!(kd_tree.search(&[5.0, 4.0]));
        assert!(!kd_tree.search(&[5.0, 4.1]));
    }

    /// A point can be deleted exactly once and is gone afterwards.
    #[test]
    fn remove() {
        let mut kd_tree = KDTree::build(vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0]]);
        let target = [5.0, 4.0];
        assert!(kd_tree.delete(&target));
        // Cannot remove twice
        assert!(!kd_tree.delete(&target));
        assert!(!kd_tree.search(&target));
    }

    /// [5.0, 4.0] lies at distance 1.0 from [5.0, 3.0] and must be among its two nearest neighbors.
    #[test]
    fn nearest_neighbors() {
        let kd_tree = KDTree::build(vec![
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ]);
        let found = kd_tree.nearest_neighbors(&[5.0, 3.0], 2);
        assert!(found.contains(&(1.0, [5.0, 4.0])));
    }

    /// A fresh tree is empty and stops being empty after the first insert.
    #[test]
    fn is_empty() {
        let mut kd_tree = KDTree::new();
        assert!(kd_tree.is_empty());
        kd_tree.insert([1.5, 3.0]);
        assert!(!kd_tree.is_empty());
    }

    /// Building from six points gives length 6 and depth 2.
    #[test]
    fn len_and_depth() {
        let points = vec![
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ];
        let expected_len = points.len();
        let tree = KDTree::build(points);
        assert_eq!(tree.len(), expected_len);
        assert_eq!(tree.depth(), 2);
    }

    /// The merged tree contains points from both source trees.
    #[test]
    fn merge() {
        let mut kd_tree_1 = KDTree::build(vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0]]);
        let mut kd_tree_2 = KDTree::build(vec![[4.0, 7.0], [8.0, 1.0], [7.0, 2.0]]);

        let kd_tree_3 = kd_tree_1.merge(&mut kd_tree_2);

        // Making sure the resulting kd-tree contains points from both kd-trees
        assert!(kd_tree_3.search(&[9.0, 6.0]));
        assert!(kd_tree_3.search(&[8.0, 1.0]));
    }
}

0 commit comments

Comments
 (0)