|
| 1 | +import 'package:ml_algo/src/common/serializable/serializable.dart'; |
| 2 | +import 'package:ml_algo/src/retrieval/kd_tree/helpers/create_kd_tree.dart'; |
| 3 | +import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart'; |
| 4 | +import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_neighbour.dart'; |
| 5 | +import 'package:ml_dataframe/ml_dataframe.dart'; |
| 6 | +import 'package:ml_linalg/dtype.dart'; |
| 7 | +import 'package:ml_linalg/matrix.dart'; |
| 8 | +import 'package:ml_linalg/vector.dart'; |
| 9 | + |
| 10 | +/// KD-tree - an algorithm that provides efficient data retrieval. It splits |
| 11 | +/// the whole searching space into partitions in binary tree form which means |
| 12 | +/// that data querying on average will take O(log(n)) time |
| 13 | +abstract class KDTree implements Serializable { |
| 14 | + factory KDTree(DataFrame points, |
| 15 | + {int leafSie = 10, DType dtype = DType.float32}) => |
| 16 | + createKDTree(points, leafSie, dtype); |
| 17 | + |
| 18 | + factory KDTree.fromJson(Map<String, dynamic> json) => |
| 19 | + KDTreeImpl.fromJson(json); |
| 20 | + |
| 21 | + /// Points which were used to build the kd-tree |
| 22 | + Matrix get points; |
| 23 | + |
| 24 | + /// A number of points on a leaf node. |
| 25 | + /// |
| 26 | + /// The bigger the number, the less effective search is. If [leafSize] is |
| 27 | + /// equal to the number of [points], a regular KNN-search will take place. |
| 28 | + /// |
| 29 | + /// Extremely small [leafSize] leads to ineffective memory usage since in |
| 30 | + /// this case a lot of kd-tree nodes will be allocated |
| 31 | + int get leafSize; |
| 32 | + |
| 33 | + /// Data type for [points] matrix |
| 34 | + DType get dtype; |
| 35 | + |
| 36 | + /// Returns [k] nearest neighbours for [point] |
| 37 | + /// |
| 38 | + /// The neighbour is represented by an index and the distance between [point] |
| 39 | + /// and the neighbour itself. The index is a zero-based index of a point in |
| 40 | + /// the source [points] matrix. Example: |
| 41 | + /// |
| 42 | + /// ```dart |
| 43 | + /// final data = DataFrame([ |
| 44 | + /// [21, 34, 22, 11], |
| 45 | + /// [11, 33, 44, 55], |
| 46 | + /// ..., |
| 47 | + /// ], headerExists: false); |
| 48 | + /// final kdTree = KDTree(data); |
| 49 | + /// final neighbours = kdTree.query([1, 2, 3, 4], 2); |
| 50 | + /// |
| 51 | + /// print(neighbours.index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3] |
| 52 | + /// ``` |
| 53 | + Iterable<KDTreeNeighbour> query(Vector point, int k); |
| 54 | +} |
0 commit comments