Skip to content

Commit 89e1063

Browse files
committed
Update kd_tree.rs
1 parent 1ee041f commit 89e1063

File tree

1 file changed

+104
-107
lines changed

1 file changed

+104
-107
lines changed

src/data_structures/kd_tree.rs

Lines changed: 104 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
1+
2+
/// A k-d tree implementation supporting the following operations:
3+
///
4+
/// Main functions:
5+
///
6+
/// new() -> Create an empty k-d tree
7+
/// build() -> Generate a balanced k-d tree from a vector of points
8+
/// insert() -> Add a point to a k-d tree
9+
/// delete() -> Remove a point from a k-d tree
10+
/// contains() -> Search for a point in a k-d tree
11+
/// nearest_neighbors() -> Search for the nearest neighbors of a given point in a k-d tree, along with their respective distances
12+
/// len() -> Determine the number of points stored in a kd-tree
13+
/// is_empty() -> Determine whether or not there are points in a k-d tree
14+
///
15+
/// Helper functions:
16+
///
17+
/// distance() -> Calculate the Euclidean distance between two points
18+
/// min_node() -> Determine the minimum node from a given k-d tree with respect to a given axis
19+
/// min_node_on_axis() -> Determine the minimum node among three nodes on a given axis
20+
///
21+
/// Check each function's definition for more details
22+
///
23+
/// TODO: Implement a `range_search` function to return a set of points found within a given boundary
24+
125
use num_traits::{abs, real::Real, Signed};
226
use std::iter::Sum;
327

@@ -36,7 +60,6 @@ impl<T: PartialOrd + Copy, const K: usize> Default for KDTree<T, K> {
3660

3761
impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
3862
// Create an empty kd-tree
39-
// #[must_use]
4063
pub fn new() -> Self {
4164
KDTree {
4265
root: None,
@@ -49,7 +72,7 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
4972
search_rec(&self.root, point, 0)
5073
}
5174

52-
// Returns true if successfully delete a point, false otherwise
75+
// Returns true if the point was successfully inserted, false otherwise
5376
pub fn insert(&mut self, point: [T; K]) -> bool {
5477
let inserted: bool = insert_rec(&mut self.root, point, 0);
5578
if inserted {
@@ -58,7 +81,7 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
5881
inserted
5982
}
6083

61-
// Returns true if successfully delete a point
84+
// Returns true if the point was successfully deleted, false otherwise
6285
pub fn delete(&mut self, point: &[T; K]) -> bool {
6386
let deleted = delete_rec(&mut self.root, point, 0);
6487
if deleted {
@@ -78,25 +101,16 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
78101
}
79102

80103
// Returns the number of points in a kd-tree
81-
// #[must_use]
82104
pub fn len(&self) -> usize {
83105
self.size
84106
}
85107

86-
// Returns the depth a kd-tree
87-
// #[must_use]
88-
pub fn depth(&self) -> usize {
89-
depth_rec(&self.root, 0, 0)
90-
}
91-
92108
// Determine whether there exist points in a kd-tree or not
93-
// #[must_use]
94109
pub fn is_empty(&self) -> bool {
95110
self.root.is_none()
96111
}
97112

98113
// Returns a kd-tree built from a vector of points
99-
// #[must_use]
100114
pub fn build(points: Vec<[T; K]>) -> KDTree<T, K> {
101115
let mut tree: KDTree<T, K> = KDTree::new();
102116
if points.is_empty() {
@@ -109,15 +123,6 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
109123
}
110124
}
111125

112-
/// Returns a `KDTree` containing both trees
113-
/// Merging two KDTrees by collecting points and rebuilding
114-
// #[must_use]
115-
pub fn merge(&mut self, other: &mut Self) -> Self {
116-
let mut points: Vec<[T; K]> = Vec::new();
117-
collect_points(&self.root, &mut points);
118-
collect_points(&other.root, &mut points);
119-
KDTree::build(points)
120-
}
121126
}
122127

123128
// Helper functions ............................................................................
@@ -231,37 +236,6 @@ fn build_rec<T: PartialOrd + Copy, const K: usize>(
231236
}
232237
}
233238

234-
// Returns the depth of the deepest branch of a kd-tree.
235-
fn depth_rec<T: PartialOrd + Copy, const K: usize>(
236-
kd_tree: &Option<Box<KDNode<T, K>>>,
237-
left_depth: usize,
238-
right_depth: usize,
239-
) -> usize {
240-
if let Some(kd_node) = kd_tree {
241-
match (&kd_node.left, &kd_node.right) {
242-
(None, None) => left_depth.max(right_depth),
243-
(None, Some(_)) => depth_rec(&kd_node.left, left_depth + 1, right_depth),
244-
(Some(_), None) => depth_rec(&kd_node.right, left_depth, right_depth + 1),
245-
(Some(_), Some(_)) => depth_rec(&kd_node.left, left_depth + 1, right_depth)
246-
.max(depth_rec(&kd_node.right, left_depth, right_depth + 1)),
247-
}
248-
} else {
249-
left_depth.max(right_depth)
250-
}
251-
}
252-
253-
// Collect all points from a given `KDTree` into a vector
254-
fn collect_points<T: PartialOrd + Copy, const K: usize>(
255-
kd_node: &Option<Box<KDNode<T, K>>>,
256-
points: &mut Vec<[T; K]>,
257-
) {
258-
if let Some(current_node) = kd_node {
259-
points.push(current_node.point);
260-
collect_points(&current_node.left, points);
261-
collect_points(&current_node.right, points);
262-
}
263-
}
264-
265239
// Calculate the distance between two points
266240
fn distance<T, const K: usize>(point_1: &[T; K], point_2: &[T; K]) -> T
267241
where
@@ -383,88 +357,111 @@ fn n_nearest_neighbors<T, const K: usize>(
383357

384358
#[cfg(test)]
385359
mod test {
360+
/// Tests for the following operations:
361+
///
362+
/// insert(), contains(), delete(), nearest_neighbors(), len(), is_empty()
363+
/// These tests use 2-dimensional points
364+
///
365+
/// TODO: Create a global constant (`K`, for example) to hold the dimension under test, and adjust each test case to use `K` when allocating points.
366+
386367
use super::KDTree;
387368

388369
#[test]
389370
fn insert() {
390-
let mut kd_tree: KDTree<f64, 2> = KDTree::new();
391-
assert!(kd_tree.insert([2.0, 3.0]));
392-
// Cannot insert the same point again
393-
assert!(!kd_tree.insert([2.0, 3.0]));
394-
assert!(kd_tree.insert([2.0, 3.1]));
371+
let points = (0..100).map(|_| {
372+
[(rand::random::<f64>() * 1000.0).round() / 10.0, (rand::random::<f64>() * 1000.0).round() / 10.0]
373+
}).collect::<Vec<[f64; 2]>>();
374+
let mut kd_tree = KDTree::build(points);
375+
let point = [(rand::random::<f64>() * 1000.0).round() / 10.0, (rand::random::<f64>() * 1000.0).round() / 10.0];
376+
377+
assert!(kd_tree.insert(point));
378+
// Cannot insert twice
379+
assert!(!kd_tree.insert(point));
395380
}
396381

397382
#[test]
398383
fn contains() {
399-
let points = vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0]];
400-
let kd_tree = KDTree::build(points);
401-
assert!(kd_tree.contains(&[5.0, 4.0]));
402-
assert!(!kd_tree.contains(&[5.0, 4.1]));
384+
let points = (0..100).map(|_| {
385+
[(rand::random::<f64>() * 1000.0).round() / 10.0, (rand::random::<f64>() * 1000.0).round() / 10.0]
386+
}).collect::<Vec<[f64; 2]>>();
387+
let mut kd_tree = KDTree::build(points);
388+
let point = [(rand::random::<f64>() * 1000.0).round() / 10.0, (rand::random::<f64>() * 1000.0).round() / 10.0];
389+
kd_tree.insert(point);
390+
391+
assert!(kd_tree.contains(&point));
403392
}
404393

405394
#[test]
406-
fn remove() {
407-
let points = vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0]];
395+
fn delete() {
396+
let points = (0..100).map(|_| {
397+
[(rand::random::<f64>() * 1000.0).round() / 10.0, (rand::random::<f64>() * 1000.0).round() / 10.0]
398+
}).collect::<Vec<[f64; 2]>>();
399+
// Pick a random existing point. `rand::random::<f64>()` is in [0, 1), so `floor()` keeps the index in 0..=99; the previous `round()` produced 100 for draws >= 0.995 and panicked on the 100-element vector. `[f64; 2]` is `Copy`, so no `.clone()` is needed.
let point = points[(rand::random::<f64>() * 100.0).floor() as usize];
408400
let mut kd_tree = KDTree::build(points);
409-
assert!(kd_tree.delete(&[5.0, 4.0]));
410-
// Cannot remove twice
411-
assert!(!kd_tree.delete(&[5.0, 4.0]));
412-
assert!(!kd_tree.contains(&[5.0, 4.0]));
401+
402+
assert!(kd_tree.delete(&point));
403+
// Cannot delete twice
404+
assert!(!kd_tree.delete(&point));
405+
// Ensure point is no longer present in k-d tree
406+
assert!(!kd_tree.contains(&point));
413407
}
414408

415409
#[test]
416410
fn nearest_neighbors() {
417-
let points = vec![
418-
[2.0, 3.0],
419-
[5.0, 4.0],
420-
[9.0, 6.0],
421-
[4.0, 7.0],
422-
[8.0, 1.0],
423-
[7.0, 2.0],
411+
// Test with large data set
412+
let points_1 = (0..1000).map(|_| {
413+
[(rand::random::<f64>() * 1000.0).round() / 10.0, (rand::random::<f64>() * 1000.0).round() / 10.0]
414+
}).collect::<Vec<[f64; 2]>>();
415+
let kd_tree_1 = KDTree::build(points_1);
416+
let target = [50.0, 50.0];
417+
let neighbors_1 = kd_tree_1.nearest_neighbors(&target, 10);
418+
419+
// Confirm we have exactly 10 nearest neighbors
420+
assert_eq!(neighbors_1.len(), 10);
421+
422+
// `14.14` is the approximate distance between [40.0, 40.0] and [50.0, 50.0] &
423+
// [50.0, 50.0] and [60.0, 60.0]
424+
// so our closest neighbors are expected to be found between the bounding box [40.0, 40.0] - [60.0, 60.0]
425+
// with a distance from [50.0, 50.0] less than or equal to 14.14
426+
for neighbor in neighbors_1 {
427+
assert!(neighbor.0 <= 14.14);
428+
}
429+
430+
// Test with small data set
431+
let points_2 = vec![[2.0, 3.0],[5.0, 4.0],[9.0, 6.0],[4.0, 7.0],[8.0, 1.0],
432+
[7.0, 2.0],
424433
];
425-
let kd_tree = KDTree::build(points);
426-
// for the point [5.0, 3.0] it's obvious that [5.0, 4.0] is one of its closest neighbor with a distance of 1.0
427-
assert!(kd_tree
428-
.nearest_neighbors(&[5.0, 3.0], 2)
429-
.contains(&(1.0, [5.0, 4.0])));
434+
let kd_tree_2 = KDTree::build(points_2);
435+
let neighbors_2 = kd_tree_2.nearest_neighbors(&[6.0, 3.0], 3);
436+
let expected_neighbors = vec![[7.0, 2.0], [5.0, 4.0], [8.0, 1.0]];
437+
let neighbors = neighbors_2.iter().map(|a| a.1).collect::<Vec<[f64; 2]>>();
438+
439+
// Confirm we have exactly 3 nearest neighbors
440+
assert_eq!(neighbors.len(), 3);
441+
442+
// With a small set of data, we can manually calculate our 3 nearest neighbors
443+
// and compare with those obtained from our method
444+
assert_eq!(neighbors, expected_neighbors);
430445
}
431446

432447
#[test]
433448
fn is_empty() {
434449
let mut kd_tree = KDTree::new();
450+
435451
assert!(kd_tree.is_empty());
452+
436453
kd_tree.insert([1.5, 3.0]);
437-
assert!(!kd_tree.is_empty());
438-
}
439454

440-
#[test]
441-
fn len_and_depth() {
442-
let points = vec![
443-
[2.0, 3.0],
444-
[5.0, 4.0],
445-
[9.0, 6.0],
446-
[4.0, 7.0],
447-
[8.0, 1.0],
448-
[7.0, 2.0],
449-
];
450-
let size = points.len();
451-
let tree = KDTree::build(points);
452-
assert_eq!(tree.len(), size);
453-
assert_eq!(tree.depth(), 2);
455+
assert!(!kd_tree.is_empty());
454456
}
455457

456458
#[test]
457-
fn merge() {
458-
let points_1 = vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0]];
459-
let points_2 = vec![[4.0, 7.0], [8.0, 1.0], [7.0, 2.0]];
460-
461-
let mut kd_tree_1 = KDTree::build(points_1);
462-
let mut kd_tree_2 = KDTree::build(points_2);
463-
464-
let kd_tree_3 = kd_tree_1.merge(&mut kd_tree_2);
465-
466-
// Making sure the resulted kd-tree contains points from both kd-trees
467-
assert!(kd_tree_3.contains(&[9.0, 6.0]));
468-
assert!(kd_tree_3.contains(&[8.0, 1.0]));
459+
fn len() {
460+
let points = (0..1000).map(|_| {
461+
[(rand::random::<f64>() * 1000.0).round() / 10.0, (rand::random::<f64>() * 1000.0).round() / 10.0]
462+
}).collect::<Vec<[f64; 2]>>();
463+
let kd_tree = KDTree::build(points);
464+
465+
assert_eq!(kd_tree.len(), 1000);
469466
}
470467
}

0 commit comments

Comments
 (0)