|
| 1 | +use std::iter::Sum; |
| 2 | +use num_traits::{abs, real::Real, Signed}; |
| 3 | + |
/// A single node of a kd-tree: one `K`-dimensional point plus optional child subtrees.
#[derive(Debug, Clone, PartialEq)]
struct KDNode<T: PartialOrd + Copy, const K: usize> {
    /// Coordinates stored at this node.
    point: [T; K],
    /// Left child subtree, if any.
    left: Option<Box<KDNode<T, K>>>,
    /// Right child subtree, if any.
    right: Option<Box<KDNode<T, K>>>,
}

impl<T: PartialOrd + Copy, const K: usize> KDNode<T, K> {
    /// Creates a node holding `point` with both children empty.
    fn new(point: [T; K]) -> Self {
        let (left, right) = (None, None);
        Self { point, left, right }
    }
}
| 21 | + |
| 22 | +#[derive(Debug)] |
| 23 | +struct KDTree<T: PartialOrd + Copy, const K: usize> { |
| 24 | + root: Option<Box<KDNode<T, K>>>, |
| 25 | + size: usize |
| 26 | +} |
| 27 | + |
| 28 | +impl<T: PartialOrd + Copy, const K: usize> KDTree<T, K> { |
| 29 | + // Create and empty kd-tree |
| 30 | + pub fn new() -> Self { |
| 31 | + KDTree { |
| 32 | + root: None, |
| 33 | + size: 0 |
| 34 | + } |
| 35 | + } |
| 36 | + |
| 37 | + // Returns true if point found, false otherwise |
| 38 | + pub fn search(&self, point: &[T; K]) -> bool { |
| 39 | + search_rec(&self.root, point, 0) |
| 40 | + } |
| 41 | + |
| 42 | + // Returns true if successfully delete a point, false otherwise |
| 43 | + pub fn insert(&mut self, point: [T; K]) -> bool { |
| 44 | + let inserted: bool = insert_rec(&mut self.root, point, 0); |
| 45 | + if inserted { |
| 46 | + self.size += 1; |
| 47 | + } |
| 48 | + inserted |
| 49 | + } |
| 50 | + |
| 51 | + // Returns true if successfully delete a point |
| 52 | + pub fn delete(&mut self, point: &[T; K]) -> bool { |
| 53 | + let deleted = delete_rec(&mut self.root, point, 0); |
| 54 | + if deleted { |
| 55 | + self.size -= 1; |
| 56 | + } |
| 57 | + deleted |
| 58 | + } |
| 59 | + |
| 60 | + // Returns the nearest neighbors of a given point with their respective disatances |
| 61 | + pub fn nearest_neighbors(&self, point: &[T; K], n: usize) -> Vec<(T, [T; K])> |
| 62 | + where |
| 63 | + T: Sum + Signed + Real |
| 64 | + { |
| 65 | + let mut neighbors: Vec<(T, [T; K])> = Vec::new(); |
| 66 | + n_nearest_neighbors(&self.root, point, n, 0, &mut neighbors); |
| 67 | + neighbors.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); |
| 68 | + neighbors |
| 69 | + } |
| 70 | + |
| 71 | + // Returns the number of points in a kd-tree |
| 72 | + pub fn len(&self) -> usize { |
| 73 | + self.size |
| 74 | + } |
| 75 | + |
| 76 | + // Returns the depth a kd-tree |
| 77 | + pub fn depth(&self) -> usize { |
| 78 | + depth_rec(&self.root, 0, 0) |
| 79 | + } |
| 80 | + |
| 81 | + // Determine whether there exist points in a kd-tree or not |
| 82 | + pub fn is_empty(&self) -> bool { |
| 83 | + self.root.is_none() |
| 84 | + } |
| 85 | + |
| 86 | + // Returns a kd-tree built from a vector points |
| 87 | + pub fn build(points: Vec<[T; K]>) -> KDTree<T, K> { |
| 88 | + let mut tree = KDTree::new(); |
| 89 | + if points.is_empty() { |
| 90 | + tree |
| 91 | + } else { |
| 92 | + tree.size = points.len(); |
| 93 | + tree.root = build_rec(points, 0); |
| 94 | + |
| 95 | + tree |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + // Merging two KDTrees by collecting points and rebuilding |
| 100 | + pub fn merge(&mut self, other: &mut Self) -> Self { |
| 101 | + let mut points: Vec<[T; K]> = Vec::new(); |
| 102 | + collect_points(&self.root, &mut points); |
| 103 | + collect_points(&other.root, &mut points); |
| 104 | + KDTree::build(points) |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +// Helper functions ............................................................................ |
| 109 | + |
| 110 | +// Recursively insert a point in a kd-tree |
| 111 | +fn insert_rec<T: PartialOrd + Copy, const K: usize>(kd_tree: &mut Option<Box<KDNode<T, K>>>, point: [T; K], depth: usize) -> bool { |
| 112 | + if let Some(ref mut kd_node) = kd_tree { |
| 113 | + let axis: usize = depth % K; |
| 114 | + if point[axis] < kd_node.point[axis] { |
| 115 | + insert_rec(&mut kd_node.left, point, depth + 1) |
| 116 | + } else if point == kd_node.point { |
| 117 | + false |
| 118 | + } else { |
| 119 | + insert_rec(&mut kd_node.right, point, depth + 1) |
| 120 | + } |
| 121 | + } else { |
| 122 | + *kd_tree = Some(Box::new(KDNode::new(point))); |
| 123 | + true |
| 124 | + } |
| 125 | +} |
| 126 | + |
| 127 | +// Recursively search for a given point in a kd-tree |
| 128 | +fn search_rec<T: PartialOrd + Copy, const K: usize>(kd_tree: &Option<Box<KDNode<T, K>>>, point: &[T; K], depth: usize) -> bool { |
| 129 | + if let Some(kd_node) = kd_tree { |
| 130 | + if point == &kd_node.point { |
| 131 | + true |
| 132 | + } else { |
| 133 | + let axis: usize = depth % K; |
| 134 | + if point[axis] < kd_node.point[axis] { |
| 135 | + search_rec(&kd_node.left, point, depth + 1) |
| 136 | + } else { |
| 137 | + search_rec(&kd_node.right, point, depth + 1) |
| 138 | + } |
| 139 | + } |
| 140 | + } else { |
| 141 | + false |
| 142 | + } |
| 143 | +} |
| 144 | + |
| 145 | +// Recursively delete a point from a kd-tree |
| 146 | +fn delete_rec<T: PartialOrd + Copy, const K: usize>(kd_node: &mut Option<Box<KDNode<T, K>>>, point: &[T; K], depth: usize) -> bool { |
| 147 | + if let Some(current_node) = kd_node { |
| 148 | + let axis: usize = depth % K; |
| 149 | + if current_node.point == *point { |
| 150 | + if current_node.right.is_some() { |
| 151 | + // safe to use `unwrap()` since we know for sure there exist a node |
| 152 | + let min_point = min_node(current_node.right.as_ref(), axis, 0).unwrap().point; |
| 153 | + |
| 154 | + current_node.point = min_point; |
| 155 | + delete_rec(&mut current_node.right, ¤t_node.point, depth + 1) |
| 156 | + } else if current_node.left.is_some() { |
| 157 | + let min_point: [T; K] = min_node(current_node.left.as_ref(), axis, 0).unwrap().point; |
| 158 | + |
| 159 | + current_node.point = min_point; |
| 160 | + current_node.right = current_node.left.take(); |
| 161 | + delete_rec(&mut current_node.right, ¤t_node.point, depth + 1) |
| 162 | + }else { |
| 163 | + *kd_node = None; |
| 164 | + true |
| 165 | + } |
| 166 | + } else if point[axis].lt(¤t_node.point[axis]) { |
| 167 | + delete_rec(&mut current_node.left, point, depth + 1) |
| 168 | + } else { |
| 169 | + delete_rec(&mut current_node.right, point, depth + 1) |
| 170 | + } |
| 171 | + } else { |
| 172 | + false |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +/// Recursively build a kd-tree from a vector of points by taking the median of a sorted |
| 177 | +/// vector of points as the root to maintain a balance kd-tree |
| 178 | +fn build_rec<T: PartialOrd + Copy, const K: usize>(mut points: Vec<[T; K]>, depth: usize) -> Option<Box<KDNode<T, K>>> { |
| 179 | + if points.is_empty() { |
| 180 | + None |
| 181 | + } else { |
| 182 | + let axis = depth % K; |
| 183 | + points.sort_by(|a, b| a[axis].partial_cmp(&b[axis]).unwrap_or(std::cmp::Ordering::Equal)); |
| 184 | + |
| 185 | + let median: usize = points.len() / 2; |
| 186 | + let mut node: KDNode<T, K> = KDNode::new(points[median]); |
| 187 | + |
| 188 | + node.left = build_rec(points[..median].to_vec(), depth + 1); |
| 189 | + node.right = build_rec(points[median + 1..].to_vec(), depth + 1); |
| 190 | + |
| 191 | + Some(Box::new(node)) |
| 192 | + } |
| 193 | +} |
| 194 | + |
| 195 | +// Returns the depth of the deepest branch of a kd-tree. |
| 196 | +fn depth_rec<T: PartialOrd + Copy, const K: usize>(kd_tree: &Option<Box<KDNode<T, K>>>, left_depth: usize, right_depth: usize) -> usize { |
| 197 | + if let Some(kd_node) = kd_tree { |
| 198 | + match (&kd_node.left, &kd_node.right) { |
| 199 | + (None, None) => left_depth.max(right_depth), |
| 200 | + (None, Some(_)) => depth_rec(&kd_node.left, left_depth + 1, right_depth), |
| 201 | + (Some(_), None) => depth_rec(&kd_node.right, left_depth, right_depth + 1), |
| 202 | + (Some(_), Some(_)) => depth_rec(&kd_node.left, left_depth + 1, right_depth).max(depth_rec(&kd_node.right, left_depth, right_depth + 1)), |
| 203 | + } |
| 204 | + } else { |
| 205 | + left_depth.max(right_depth) |
| 206 | + } |
| 207 | +} |
| 208 | + |
| 209 | +// Collect all points from a given KDTree into a vector |
| 210 | +fn collect_points<T: PartialOrd + Copy, const K: usize>(kd_node: &Option<Box<KDNode<T, K>>>, points: &mut Vec<[T; K]>) { |
| 211 | + if let Some(current_node) = kd_node { |
| 212 | + points.push(current_node.point); |
| 213 | + collect_points(¤t_node.left, points); |
| 214 | + collect_points(¤t_node.right, points); |
| 215 | + } |
| 216 | +} |
| 217 | + |
| 218 | +// Calculate the distance between two points |
| 219 | +fn distance<T, const K: usize>(point_1: &[T; K], point_2: &[T; K]) -> T |
| 220 | +where |
| 221 | + T: PartialOrd + Copy + Sum + Real |
| 222 | +{ |
| 223 | + point_1 |
| 224 | + .iter() |
| 225 | + .zip(point_2.iter()) |
| 226 | + .map(|(&a, &b)| { |
| 227 | + let diff: T = a - b; |
| 228 | + diff * diff |
| 229 | + }) |
| 230 | + .sum::<T>() |
| 231 | + .sqrt() |
| 232 | +} |
| 233 | + |
| 234 | +// Returns the minimum nodes among three kd-nodes on a given axis |
| 235 | +fn min_node_on_axis<'a, T: PartialOrd + Copy, const K: usize>(kd_node: &'a KDNode<T, K>, left: Option<&'a KDNode<T, K>>, right: Option<&'a KDNode<T, K>>, axis: usize) -> &'a KDNode<T, K> { |
| 236 | + let mut current_min_node = kd_node; |
| 237 | + if let Some(left_node) = left { |
| 238 | + if left_node.point[axis].lt(¤t_min_node.point[axis]) { |
| 239 | + current_min_node = left_node; |
| 240 | + } |
| 241 | + } |
| 242 | + if let Some(right_node) = right { |
| 243 | + if right_node.point[axis].lt(¤t_min_node.point[axis]) { |
| 244 | + current_min_node = right_node; |
| 245 | + } |
| 246 | + } |
| 247 | + current_min_node |
| 248 | +} |
| 249 | + |
| 250 | +// Returns the minimum node of a kd-tree with respect to an axis |
| 251 | +fn min_node<T: PartialOrd + Copy, const K: usize>(kd_node: Option<&Box<KDNode<T, K>>>, axis: usize, depth: usize) -> Option<&KDNode<T, K>> { |
| 252 | + if let Some(current_node) = kd_node { |
| 253 | + let current_axis = depth % K; |
| 254 | + if current_axis == axis { |
| 255 | + if current_node.left.is_some() { |
| 256 | + min_node(current_node.left.as_ref(), axis, depth + 1) |
| 257 | + } else { |
| 258 | + Some(current_node) |
| 259 | + } |
| 260 | + } else { |
| 261 | + let (left_min, right_min): (Option<&KDNode<T, K>>, Option<&KDNode<T, K>>) = ( |
| 262 | + min_node(current_node.left.as_ref(), axis, depth + 1), |
| 263 | + min_node(current_node.right.as_ref(), axis, depth + 1) |
| 264 | + ); |
| 265 | + Some(min_node_on_axis(current_node, left_min, right_min, axis)) |
| 266 | + } |
| 267 | + } else { |
| 268 | + None |
| 269 | + } |
| 270 | +} |
| 271 | + |
| 272 | +// Find the nearest neighbors of a given point. The number neighbors is determine by the variable `n`. |
| 273 | +fn n_nearest_neighbors<T, const K: usize>(kd_tree: &Option<Box<KDNode<T, K>>>, point: &[T; K], n: usize, depth: usize, neighbors: &mut Vec<(T, [T; K])>) |
| 274 | +where |
| 275 | + T: PartialOrd + Copy + Sum + Real + Signed |
| 276 | +{ |
| 277 | + if let Some(kd_node) = kd_tree { |
| 278 | + let distance: T = distance(&kd_node.point, point); |
| 279 | + if neighbors.len() < n { |
| 280 | + neighbors.push((distance, kd_node.point)); |
| 281 | + } else { |
| 282 | + // safe to call unwrap() since we know our neighbors is ont empty in this scope |
| 283 | + let max_distance = neighbors.iter().max_by(|a, b| a.0.partial_cmp(&b.0).unwrap()).unwrap().0; |
| 284 | + if distance < max_distance { |
| 285 | + if let Some(pos) = neighbors.iter().position(|x| x.0 == max_distance) { |
| 286 | + neighbors[pos] = (distance, kd_node.point); |
| 287 | + } |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + let axis: usize = depth % K; |
| 292 | + let target_axis: T = point[axis]; |
| 293 | + let split_axis: T = kd_node.point[axis]; |
| 294 | + |
| 295 | + let (look_first, look_second) = if target_axis < split_axis { |
| 296 | + (&kd_node.left, &kd_node.right) |
| 297 | + } else { |
| 298 | + (&kd_node.right, &kd_node.left) |
| 299 | + }; |
| 300 | + |
| 301 | + if look_first.is_some() { |
| 302 | + n_nearest_neighbors(&look_first, point, n, depth + 1, neighbors); |
| 303 | + } |
| 304 | + |
| 305 | + // Check if it's necessary to look on the other branch by computing the distance between our current point with the nearest point on the other branch |
| 306 | + if look_second.is_some() { |
| 307 | + let max_distance = neighbors.iter().max_by(|a, b| a.0.partial_cmp(&b.0).unwrap()).unwrap().0; |
| 308 | + if neighbors.len() < n || abs(target_axis - split_axis) < max_distance { |
| 309 | + n_nearest_neighbors(&look_second, point, n, depth + 1, neighbors); |
| 310 | + } |
| 311 | + } |
| 312 | + } |
| 313 | +} |
| 314 | + |
| 315 | + |
#[cfg(test)]
mod test {
    use super::KDTree;

    /// Inserting a new point succeeds; inserting the same point again is rejected.
    #[test]
    fn insert() {
        let mut kd_tree: KDTree<f64, 2> = KDTree::new();
        assert!(kd_tree.insert([2.0, 3.0]));
        // Cannot insert the same point again
        assert!(!kd_tree.insert([2.0, 3.0]));
        assert!(kd_tree.insert([2.0, 3.1]));
    }

    /// A built tree finds a stored point and reports a missing one as absent.
    #[test]
    fn contains() {
        let kd_tree = KDTree::build(vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0]]);
        assert!(kd_tree.search(&[5.0, 4.0]));
        assert!(!kd_tree.search(&[5.0, 4.1]));
    }

    /// A point can be deleted once, after which it is no longer found.
    #[test]
    fn remove() {
        let mut kd_tree = KDTree::build(vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0]]);
        assert!(kd_tree.delete(&[5.0, 4.0]));
        // Cannot remove twice
        assert!(!kd_tree.delete(&[5.0, 4.0]));
        assert!(!kd_tree.search(&[5.0, 4.0]));
    }

    /// The two nearest neighbors of [5.0, 3.0] include [5.0, 4.0] at distance 1.0.
    #[test]
    fn nearest_neighbors() {
        let kd_tree = KDTree::build(vec![
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ]);
        let nearest = kd_tree.nearest_neighbors(&[5.0, 3.0], 2);
        assert!(nearest.contains(&(1.0, [5.0, 4.0])));
    }

    /// A new tree is empty and stops being empty after an insertion.
    #[test]
    fn is_empty() {
        let mut kd_tree: KDTree<f64, 2> = KDTree::new();
        assert!(kd_tree.is_empty());
        kd_tree.insert([1.5, 3.0]);
        assert!(!kd_tree.is_empty());
    }

    /// A tree built from six points reports six points and a depth of two.
    #[test]
    fn len_and_depth() {
        let points = vec![
            [2.0, 3.0],
            [5.0, 4.0],
            [9.0, 6.0],
            [4.0, 7.0],
            [8.0, 1.0],
            [7.0, 2.0],
        ];
        let expected_len = points.len();
        let tree = KDTree::build(points);
        assert_eq!(tree.len(), expected_len);
        assert_eq!(tree.depth(), 2);
    }

    /// The merged tree contains points from both input trees.
    #[test]
    fn merge() {
        let mut kd_tree_1 = KDTree::build(vec![[2.0, 3.0], [5.0, 4.0], [9.0, 6.0]]);
        let mut kd_tree_2 = KDTree::build(vec![[4.0, 7.0], [8.0, 1.0], [7.0, 2.0]]);

        let merged = kd_tree_1.merge(&mut kd_tree_2);

        // Making sure the resulting kd-tree contains points from both kd-trees
        assert!(merged.search(&[9.0, 6.0]));
        assert!(merged.search(&[8.0, 1.0]));
    }
}
0 commit comments