1+
2+ /// A k-d tree implementation supporting the following operations:
3+ ///
4+ /// Main functions:
5+ ///
6+ /// new() -> Create an empty k-d tree
7+ /// build() -> Generate a balance k-d tree from a vector of points
8+ /// insert() -> Add a point to a k-d tree
9+ /// delete() -> Remove a point from a k-d tree
10+ /// contains() -> Search for a point in a k-d tree
11+ /// n_nearest_neighbors -> Search the nearest neighbors of a given point from a k-d tree with their respective distances
12+ /// len() -> Determine the number of points stored in a kd-tree
13+ /// is_empty() -> Determine whether or not there are points in a k-d tree
14+ ///
15+ /// Helper functions:
16+ ///
17+ /// distance() -> Calculate the Euclidean distance between two points
18+ /// min_node() -> Determine the minimum node from a given k-d tree with respect to a given axis
19+ /// min_node_on_axis() -> Determine the minimum node among three nodes on a given axis
20+ ///
21+ /// Check each function's definition for more details
22+ ///
23+ /// TODO: Implement a `range_search` function to return a set of points found within a given boundary
24+
125use num_traits:: { abs, real:: Real , Signed } ;
226use std:: iter:: Sum ;
327
@@ -36,7 +60,6 @@ impl<T: PartialOrd + Copy, const K: usize> Default for KDTree<T, K> {
3660
3761impl < T : PartialOrd + Copy , const K : usize > KDTree < T , K > {
3862 // Create and empty kd-tree
39- // #[must_use]
4063 pub fn new ( ) -> Self {
4164 KDTree {
4265 root : None ,
@@ -49,7 +72,7 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
4972 search_rec ( & self . root , point, 0 )
5073 }
5174
52- // Returns true if successfully delete a point, false otherwise
75+ // Returns true if successfully insert a point, false otherwise
5376 pub fn insert ( & mut self , point : [ T ; K ] ) -> bool {
5477 let inserted: bool = insert_rec ( & mut self . root , point, 0 ) ;
5578 if inserted {
@@ -58,7 +81,7 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
5881 inserted
5982 }
6083
61- // Returns true if successfully delete a point
84+ // Returns true if successfully delete a point, false otherwise
6285 pub fn delete ( & mut self , point : & [ T ; K ] ) -> bool {
6386 let deleted = delete_rec ( & mut self . root , point, 0 ) ;
6487 if deleted {
@@ -78,25 +101,16 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
78101 }
79102
80103 // Returns the number of points in a kd-tree
81- // #[must_use]
82104 pub fn len ( & self ) -> usize {
83105 self . size
84106 }
85107
86- // Returns the depth a kd-tree
87- // #[must_use]
88- pub fn depth ( & self ) -> usize {
89- depth_rec ( & self . root , 0 , 0 )
90- }
91-
92108 // Determine whether there exist points in a kd-tree or not
93- // #[must_use]
94109 pub fn is_empty ( & self ) -> bool {
95110 self . root . is_none ( )
96111 }
97112
98113 // Returns a kd-tree built from a vector points
99- // #[must_use]
100114 pub fn build ( points : Vec < [ T ; K ] > ) -> KDTree < T , K > {
101115 let mut tree: KDTree < T , K > = KDTree :: new ( ) ;
102116 if points. is_empty ( ) {
@@ -109,15 +123,6 @@ impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> {
109123 }
110124 }
111125
112- /// Returns a `KDTree` containing both trees
113- /// Merging two KDTrees by collecting points and rebuilding
114- // #[must_use]
115- pub fn merge ( & mut self , other : & mut Self ) -> Self {
116- let mut points: Vec < [ T ; K ] > = Vec :: new ( ) ;
117- collect_points ( & self . root , & mut points) ;
118- collect_points ( & other. root , & mut points) ;
119- KDTree :: build ( points)
120- }
121126}
122127
123128// Helper functions ............................................................................
@@ -231,37 +236,6 @@ fn build_rec<T: PartialOrd + Copy, const K: usize>(
231236 }
232237}
233238
234- // Returns the depth of the deepest branch of a kd-tree.
235- fn depth_rec < T : PartialOrd + Copy , const K : usize > (
236- kd_tree : & Option < Box < KDNode < T , K > > > ,
237- left_depth : usize ,
238- right_depth : usize ,
239- ) -> usize {
240- if let Some ( kd_node) = kd_tree {
241- match ( & kd_node. left , & kd_node. right ) {
242- ( None , None ) => left_depth. max ( right_depth) ,
243- ( None , Some ( _) ) => depth_rec ( & kd_node. left , left_depth + 1 , right_depth) ,
244- ( Some ( _) , None ) => depth_rec ( & kd_node. right , left_depth, right_depth + 1 ) ,
245- ( Some ( _) , Some ( _) ) => depth_rec ( & kd_node. left , left_depth + 1 , right_depth)
246- . max ( depth_rec ( & kd_node. right , left_depth, right_depth + 1 ) ) ,
247- }
248- } else {
249- left_depth. max ( right_depth)
250- }
251- }
252-
253- // Collect all points from a given `KDTree` into a vector
254- fn collect_points < T : PartialOrd + Copy , const K : usize > (
255- kd_node : & Option < Box < KDNode < T , K > > > ,
256- points : & mut Vec < [ T ; K ] > ,
257- ) {
258- if let Some ( current_node) = kd_node {
259- points. push ( current_node. point ) ;
260- collect_points ( & current_node. left , points) ;
261- collect_points ( & current_node. right , points) ;
262- }
263- }
264-
265239// Calculate the distance between two points
266240fn distance < T , const K : usize > ( point_1 : & [ T ; K ] , point_2 : & [ T ; K ] ) -> T
267241where
@@ -383,88 +357,111 @@ fn n_nearest_neighbors<T, const K: usize>(
383357
384358#[ cfg( test) ]
385359mod test {
360+ /// Tests for the following operations:
361+ ///
362+ /// insert(), contains(), delete(), n_nearest_neighbors(), len(), is_empty()
363+ /// This test uses a 2-Dimensional point
364+ ///
365+ /// TODO: Create a global constant(K for example) to hold the dimension to be tested and adjust each test case to make use of K for points allocation.
366+
386367 use super :: KDTree ;
387368
388369 #[ test]
389370 fn insert ( ) {
390- let mut kd_tree: KDTree < f64 , 2 > = KDTree :: new ( ) ;
391- assert ! ( kd_tree. insert( [ 2.0 , 3.0 ] ) ) ;
392- // Cannot insert the same point again
393- assert ! ( !kd_tree. insert( [ 2.0 , 3.0 ] ) ) ;
394- assert ! ( kd_tree. insert( [ 2.0 , 3.1 ] ) ) ;
371+ let points = ( 0 ..100 ) . map ( |_| {
372+ [ ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 , ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 ]
373+ } ) . collect :: < Vec < [ f64 ; 2 ] > > ( ) ;
374+ let mut kd_tree = KDTree :: build ( points) ;
375+ let point = [ ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 , ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 ] ;
376+
377+ assert ! ( kd_tree. insert( point) ) ;
378+ // Cannot insert twice
379+ assert ! ( !kd_tree. insert( point) ) ;
395380 }
396381
397382 #[ test]
398383 fn contains ( ) {
399- let points = vec ! [ [ 2.0 , 3.0 ] , [ 5.0 , 4.0 ] , [ 9.0 , 6.0 ] , [ 4.0 , 7.0 ] ] ;
400- let kd_tree = KDTree :: build ( points) ;
401- assert ! ( kd_tree. contains( & [ 5.0 , 4.0 ] ) ) ;
402- assert ! ( !kd_tree. contains( & [ 5.0 , 4.1 ] ) ) ;
384+ let points = ( 0 ..100 ) . map ( |_| {
385+ [ ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 , ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 ]
386+ } ) . collect :: < Vec < [ f64 ; 2 ] > > ( ) ;
387+ let mut kd_tree = KDTree :: build ( points) ;
388+ let point = [ ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 , ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 ] ;
389+ kd_tree. insert ( point) ;
390+
391+ assert ! ( kd_tree. contains( & point) ) ;
403392 }
404393
405394 #[ test]
406- fn remove ( ) {
407- let points = vec ! [ [ 2.0 , 3.0 ] , [ 5.0 , 4.0 ] , [ 9.0 , 6.0 ] , [ 4.0 , 7.0 ] ] ;
395+ fn delete ( ) {
396+ let points = ( 0 ..100 ) . map ( |_| {
397+ [ ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 , ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 ]
398+ } ) . collect :: < Vec < [ f64 ; 2 ] > > ( ) ;
399+ let point = points[ ( rand:: random :: < f64 > ( ) * 100.0 ) . round ( ) as usize ] . clone ( ) ;
408400 let mut kd_tree = KDTree :: build ( points) ;
409- assert ! ( kd_tree. delete( & [ 5.0 , 4.0 ] ) ) ;
410- // Cannot remove twice
411- assert ! ( !kd_tree. delete( & [ 5.0 , 4.0 ] ) ) ;
412- assert ! ( !kd_tree. contains( & [ 5.0 , 4.0 ] ) ) ;
401+
402+ assert ! ( kd_tree. delete( & point) ) ;
403+ // Cannot delete twice
404+ assert ! ( !kd_tree. delete( & point) ) ;
405+ // Ensure point is no longer present in k-d tree
406+ assert ! ( !kd_tree. contains( & point) ) ;
413407 }
414408
415409 #[ test]
416410 fn nearest_neighbors ( ) {
417- let points = vec ! [
418- [ 2.0 , 3.0 ] ,
419- [ 5.0 , 4.0 ] ,
420- [ 9.0 , 6.0 ] ,
421- [ 4.0 , 7.0 ] ,
422- [ 8.0 , 1.0 ] ,
423- [ 7.0 , 2.0 ] ,
411+ // Test with large data set
412+ let points_1 = ( 0 ..1000 ) . map ( |_| {
413+ [ ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 , ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 ]
414+ } ) . collect :: < Vec < [ f64 ; 2 ] > > ( ) ;
415+ let kd_tree_1 = KDTree :: build ( points_1) ;
416+ let target = [ 50.0 , 50.0 ] ;
417+ let neighbors_1 = kd_tree_1. nearest_neighbors ( & target, 10 ) ;
418+
419+ // Confirm we have exactly 10 nearest neighbors
420+ assert_eq ! ( neighbors_1. len( ) , 10 ) ;
421+
422+ // `14.14` is the approximate distance between [40.0, 40.0] and [50.0, 50.0] &
423+ // [50.0, 50.0] and [60.0, 60.0]
424+ // so our closest neighbors are expected to be found between the bounding box [40.0, 40.0] - [60.0, 60.0]
425+ // with a distance from [50.0, 50.0] less than or equal 14.14
426+ for neighbor in neighbors_1 {
427+ assert ! ( neighbor. 0 <= 14.14 ) ;
428+ }
429+
430+ // Test with small data set
431+ let points_2 = vec ! [ [ 2.0 , 3.0 ] , [ 5.0 , 4.0 ] , [ 9.0 , 6.0 ] , [ 4.0 , 7.0 ] , [ 8.0 , 1.0 ] ,
432+ [ 7.0 , 2.0 ] ,
424433 ] ;
425- let kd_tree = KDTree :: build ( points) ;
426- // for the point [5.0, 3.0] it's obvious that [5.0, 4.0] is one of its closest neighbor with a distance of 1.0
427- assert ! ( kd_tree
428- . nearest_neighbors( & [ 5.0 , 3.0 ] , 2 )
429- . contains( & ( 1.0 , [ 5.0 , 4.0 ] ) ) ) ;
434+ let kd_tree_2 = KDTree :: build ( points_2) ;
435+ let neighbors_2 = kd_tree_2. nearest_neighbors ( & [ 6.0 , 3.0 ] , 3 ) ;
436+ let expected_neighbors = vec ! [ [ 7.0 , 2.0 ] , [ 5.0 , 4.0 ] , [ 8.0 , 1.0 ] ] ;
437+ let neighbors = neighbors_2. iter ( ) . map ( |a| a. 1 ) . collect :: < Vec < [ f64 ; 2 ] > > ( ) ;
438+
439+ // Confirm we have exactly 10 nearest neighbors
440+ assert_eq ! ( neighbors. len( ) , 3 ) ;
441+
442+ // With a small set of data, we can manually calculate our 3 nearest neighbors
443+ // and compare with those obtained from our method
444+ assert_eq ! ( neighbors, expected_neighbors) ;
430445 }
431446
432447 #[ test]
433448 fn is_empty ( ) {
434449 let mut kd_tree = KDTree :: new ( ) ;
450+
435451 assert ! ( kd_tree. is_empty( ) ) ;
452+
436453 kd_tree. insert ( [ 1.5 , 3.0 ] ) ;
437- assert ! ( !kd_tree. is_empty( ) ) ;
438- }
439454
440- #[ test]
441- fn len_and_depth ( ) {
442- let points = vec ! [
443- [ 2.0 , 3.0 ] ,
444- [ 5.0 , 4.0 ] ,
445- [ 9.0 , 6.0 ] ,
446- [ 4.0 , 7.0 ] ,
447- [ 8.0 , 1.0 ] ,
448- [ 7.0 , 2.0 ] ,
449- ] ;
450- let size = points. len ( ) ;
451- let tree = KDTree :: build ( points) ;
452- assert_eq ! ( tree. len( ) , size) ;
453- assert_eq ! ( tree. depth( ) , 2 ) ;
455+ assert ! ( !kd_tree. is_empty( ) ) ;
454456 }
455457
456458 #[ test]
457- fn merge ( ) {
458- let points_1 = vec ! [ [ 2.0 , 3.0 ] , [ 5.0 , 4.0 ] , [ 9.0 , 6.0 ] ] ;
459- let points_2 = vec ! [ [ 4.0 , 7.0 ] , [ 8.0 , 1.0 ] , [ 7.0 , 2.0 ] ] ;
460-
461- let mut kd_tree_1 = KDTree :: build ( points_1) ;
462- let mut kd_tree_2 = KDTree :: build ( points_2) ;
463-
464- let kd_tree_3 = kd_tree_1. merge ( & mut kd_tree_2) ;
465-
466- // Making sure the resulted kd-tree contains points from both kd-trees
467- assert ! ( kd_tree_3. contains( & [ 9.0 , 6.0 ] ) ) ;
468- assert ! ( kd_tree_3. contains( & [ 8.0 , 1.0 ] ) ) ;
459+ fn len ( ) {
460+ let points = ( 0 ..1000 ) . map ( |_| {
461+ [ ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 , ( rand:: random :: < f64 > ( ) * 1000.0 ) . round ( ) / 10.0 ]
462+ } ) . collect :: < Vec < [ f64 ; 2 ] > > ( ) ;
463+ let kd_tree = KDTree :: build ( points) ;
464+
465+ assert_eq ! ( kd_tree. len( ) , 1000 ) ;
469466 }
470467}
0 commit comments