Skip to content

Commit e189c7c

Browse files
authored
Added KDTree algorithm (#208)
1 parent 71c1832 commit e189c7c

20 files changed

+699
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## 16.6.0
4+
- Added `KDTree` algorithm
5+
36
## 16.5.2
47
- Add ecosystem notes to `README.md`
58

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ it in the web applications.
6868
A class that makes prediction for each new observation basing on first `k` closest observations from
6969
training data. It may catch non-linear pattern of the data.
7070

71+
- #### Clustering and retrieval algorithms
72+
- [KDTree](https://github.com/gyrdym/ml_algo/blob/master/lib/src/retrieval/kd_tree/kd_tree.dart)
73+
7174
For more information on the library's API, please visit [API reference](https://pub.dev/documentation/ml_algo/latest/ml_algo/ml_algo-library.html)
7275

7376
## Examples
File renamed without changes.

benchmark/kd_tree_building.dart

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// 0.5 sec (MacBook Air mid 2017)
2+
import 'dart:convert';
3+
import 'dart:io';
4+
5+
import 'package:benchmark_harness/benchmark_harness.dart';
6+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
7+
import 'package:ml_dataframe/ml_dataframe.dart';
8+
9+
late DataFrame trainData;
10+
11+
class KDTreeBuildingBenchmark extends BenchmarkBase {
12+
KDTreeBuildingBenchmark() : super('KDTree building benchmark');
13+
14+
static void main() {
15+
KDTreeBuildingBenchmark().report();
16+
}
17+
18+
@override
19+
void run() {
20+
KDTree(trainData);
21+
}
22+
23+
void tearDown() {}
24+
}
25+
26+
Future main() async {
27+
final file = File('benchmark/data/sample_data.json');
28+
final dataAsString = await file.readAsString();
29+
final decoded = jsonDecode(dataAsString) as Map<String, dynamic>;
30+
31+
trainData = DataFrame.fromJson(decoded);
32+
33+
print(
34+
'Data dimension: ${trainData.rows.length}x${trainData.rows.first.length}');
35+
36+
KDTreeBuildingBenchmark.main();
37+
}

benchmark/kd_tree_querying.dart

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// 0.04 sec (MacBook Air mid 2017)
2+
import 'dart:convert';
3+
import 'dart:io';
4+
5+
import 'package:benchmark_harness/benchmark_harness.dart';
6+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
7+
import 'package:ml_dataframe/ml_dataframe.dart';
8+
import 'package:ml_linalg/linalg.dart';
9+
import 'package:ml_linalg/vector.dart';
10+
11+
final k = 10;
12+
13+
late DataFrame trainData;
14+
late KDTree tree;
15+
late Vector point;
16+
17+
class KDTreeQueryingBenchmark extends BenchmarkBase {
18+
KDTreeQueryingBenchmark() : super('KDTree querying benchmark');
19+
20+
static void main() {
21+
KDTreeQueryingBenchmark().report();
22+
}
23+
24+
@override
25+
void run() {
26+
tree.query(point, k);
27+
}
28+
29+
void tearDown() {}
30+
}
31+
32+
Future main() async {
33+
final file = File('benchmark/data/sample_data.json');
34+
final dataAsString = await file.readAsString();
35+
final decodedPoints = jsonDecode(dataAsString) as Map<String, dynamic>;
36+
37+
trainData = DataFrame.fromJson(decodedPoints);
38+
tree = KDTree(trainData);
39+
point = Vector.randomFilled(trainData.rows.first.length,
40+
seed: 10, min: -5000, max: 5000);
41+
42+
print(
43+
'Data dimension: ${trainData.rows.length}x${trainData.rows.first.length}');
44+
print('Number of neighbours: $k');
45+
46+
KDTreeQueryingBenchmark.main();
47+
}

benchmark/lasso_regressor.dart

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import 'package:benchmark_harness/benchmark_harness.dart';
55
import 'package:ml_algo/ml_algo.dart';
66
import 'package:ml_dataframe/ml_dataframe.dart';
77

8-
const observationsNum = 1000;
9-
const featuresNum = 100;
108
late DataFrame trainData;
119

1210
class LassoRegressorBenchmark extends BenchmarkBase {
@@ -23,7 +21,7 @@ class LassoRegressorBenchmark extends BenchmarkBase {
2321
}
2422

2523
Future main() async {
26-
final file = File('benchmark/data/sample_regression_data.json');
24+
final file = File('benchmark/data/sample_data.json');
2725
final dataAsString = await file.readAsString();
2826
final decoded = jsonDecode(dataAsString) as Map<String, dynamic>;
2927

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class InvalidQueryPointLength implements Exception {
2+
InvalidQueryPointLength(int pointLength, int expectedLength)
3+
: message =
4+
'Invalid query point length: expected length is $expectedLength, but given point\'s length is $pointLength';
5+
6+
final String message;
7+
8+
@override
9+
String toString() => message;
10+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_builder.dart';
2+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_constants.dart';
3+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
4+
import 'package:ml_dataframe/ml_dataframe.dart';
5+
import 'package:ml_linalg/dtype.dart';
6+
7+
KDTreeImpl createKDTree(DataFrame pointsSrc, int leafSize, DType dtype) {
8+
final points = pointsSrc.toMatrix(dtype);
9+
final builder = KDTreeBuilder(leafSize, points);
10+
final root = builder.train();
11+
12+
return KDTreeImpl(points, leafSize, root, dtype, kdTreeJsonSchemaVersion);
13+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import 'package:ml_algo/src/common/serializable/serializable.dart';
2+
import 'package:ml_algo/src/retrieval/kd_tree/helpers/create_kd_tree.dart';
3+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
4+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_neighbour.dart';
5+
import 'package:ml_dataframe/ml_dataframe.dart';
6+
import 'package:ml_linalg/dtype.dart';
7+
import 'package:ml_linalg/matrix.dart';
8+
import 'package:ml_linalg/vector.dart';
9+
10+
/// KD-tree - an algorithm that provides efficient data retrieval. It splits
11+
/// the whole searching space into partitions in binary tree form which means
12+
/// that data querying on average will take O(log(n)) time
13+
abstract class KDTree implements Serializable {
14+
factory KDTree(DataFrame points,
15+
{int leafSie = 10, DType dtype = DType.float32}) =>
16+
createKDTree(points, leafSie, dtype);
17+
18+
factory KDTree.fromJson(Map<String, dynamic> json) =>
19+
KDTreeImpl.fromJson(json);
20+
21+
/// Points which were used to build the kd-tree
22+
Matrix get points;
23+
24+
/// A number of points on a leaf node.
25+
///
26+
/// The bigger the number, the less effective search is. If [leafSize] is
27+
/// equal to the number of [points], a regular KNN-search will take place.
28+
///
29+
/// Extremely small [leafSize] leads to ineffective memory usage since in
30+
/// this case a lot of kd-tree nodes will be allocated
31+
int get leafSize;
32+
33+
/// Data type for [points] matrix
34+
DType get dtype;
35+
36+
/// Returns [k] nearest neighbours for [point]
37+
///
38+
/// The neighbour is represented by an index and the distance between [point]
39+
/// and the neighbour itself. The index is a zero-based index of a point in
40+
/// the source [points] matrix. Example:
41+
///
42+
/// ```dart
43+
/// final data = DataFrame([
44+
/// [21, 34, 22, 11],
45+
/// [11, 33, 44, 55],
46+
/// ...,
47+
/// ], headerExists: false);
48+
/// final kdTree = KDTree(data);
49+
/// final neighbours = kdTree.query([1, 2, 3, 4], 2);
50+
///
51+
/// print(neighbours.index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3]
52+
/// ```
53+
Iterable<KDTreeNeighbour> query(Vector point, int k);
54+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_node.dart';
2+
import 'package:ml_linalg/matrix.dart';
3+
4+
class _Split {
5+
_Split(this.left, this.right, this.midPoint);
6+
7+
final List<int> left;
8+
final List<int> right;
9+
final int midPoint;
10+
}
11+
12+
class KDTreeBuilder {
13+
KDTreeBuilder(this._leafSize, this._points);
14+
15+
final int _leafSize;
16+
final Matrix _points;
17+
18+
KDTreeNode train() => _train(_points.rowIndices.toList());
19+
20+
KDTreeNode _train(List<int> pointIndices) {
21+
final isLeaf = pointIndices.length <= _leafSize;
22+
23+
if (isLeaf) {
24+
return KDTreeNode(pointIndices: pointIndices);
25+
}
26+
27+
final points = _points.sample(rowIndices: pointIndices);
28+
final splitIdx = _getSplitIdx(points);
29+
final splitValue = points.getColumn(splitIdx).median();
30+
final split = _splitPoints(pointIndices, splitIdx, splitValue);
31+
32+
return KDTreeNode(
33+
pointIndices: [split.midPoint],
34+
splitIndex: splitIdx,
35+
left: _train(split.left),
36+
right: _train(split.right),
37+
);
38+
}
39+
40+
int _getSplitIdx(Matrix points) {
41+
final variances = points.variance();
42+
43+
var colIdx = 0;
44+
var maxIdx = colIdx;
45+
var max = variances[maxIdx];
46+
47+
variances.forEach((variance) {
48+
if (variance > max) {
49+
max = variance;
50+
maxIdx = colIdx;
51+
}
52+
colIdx++;
53+
});
54+
55+
return maxIdx;
56+
}
57+
58+
_Split _splitPoints(List<int> pointIndices, int splitIdx, num splitValue) {
59+
final left = <int>[];
60+
final right = <int>[];
61+
int? midPoint;
62+
63+
for (var i = 0; i < pointIndices.length; i++) {
64+
final pointIndex = pointIndices[i];
65+
final point = _points[pointIndex];
66+
67+
if (point[splitIdx] < splitValue) {
68+
left.add(pointIndex);
69+
continue;
70+
}
71+
72+
if (midPoint == null || point[splitIdx] < _points[midPoint][splitIdx]) {
73+
midPoint = pointIndex;
74+
}
75+
76+
right.add(pointIndex);
77+
}
78+
79+
return _Split(left, right, midPoint!);
80+
}
81+
}

0 commit comments

Comments
 (0)