Skip to content

Commit aeacf3a

Browse files
authored
KDTree building optimisation: split algorithm changed (#210)
1 parent df8c28c commit aeacf3a

File tree

11 files changed

+137
-96
lines changed

11 files changed

+137
-96
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 16.6.2
4+
- KDTree:
5+
- KDTree build optimization: split algorithm changed
6+
37
## 16.6.1
48
- `KDTree` class added to library export file
59

benchmark/kd_tree_building.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// 0.5 sec (MacBook Air mid 2017)
1+
// 0.8 sec (MacBook Air mid 2017)
22
import 'dart:convert';
33
import 'dart:io';
44

benchmark/kd_tree_querying.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// 0.04 sec (MacBook Air mid 2017)
1+
// 0.03 sec (MacBook Air mid 2017)
22
import 'dart:convert';
33
import 'dart:io';
44

@@ -35,7 +35,7 @@ Future main() async {
3535
final decodedPoints = jsonDecode(dataAsString) as Map<String, dynamic>;
3636

3737
trainData = DataFrame.fromJson(decodedPoints);
38-
tree = KDTree(trainData);
38+
tree = KDTree(trainData, leafSie: 1);
3939
point = Vector.randomFilled(trainData.rows.first.length,
4040
seed: 10, min: -5000, max: 5000);
4141

lib/src/retrieval/kd_tree/kd_tree.dart

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,16 @@ abstract class KDTree implements Serializable {
4040
/// the source [points] matrix. Example:
4141
///
4242
/// ```dart
43+
/// import 'package:ml_dataframe/ml_dataframe.dart';
44+
/// import 'package:ml_linalg/vector.dart';
45+
///
4346
/// final data = DataFrame([
4447
/// [21, 34, 22, 11],
4548
/// [11, 33, 44, 55],
4649
/// ...,
4750
/// ], headerExists: false);
4851
/// final kdTree = KDTree(data);
49-
/// final neighbours = kdTree.query([1, 2, 3, 4], 2);
52+
/// final neighbours = kdTree.query(Vector.fromList([1, 2, 3, 4]), 2);
5053
///
5154
/// print(neighbours.index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3]
5255
/// ```

lib/src/retrieval/kd_tree/kd_tree_builder.dart

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@ class KDTreeBuilder {
1919

2020
KDTreeNode _train(List<int> pointIndices) {
2121
final isLeaf = pointIndices.length <= _leafSize;
22+
final points = _points.sample(rowIndices: pointIndices);
23+
final splitIdx = _getSplitIdx(points);
2224

2325
if (isLeaf) {
24-
return KDTreeNode(pointIndices: pointIndices);
26+
return KDTreeNode(splitIndex: splitIdx, pointIndices: pointIndices);
2527
}
2628

27-
final points = _points.sample(rowIndices: pointIndices);
28-
final splitIdx = _getSplitIdx(points);
29-
final splitValue = points.getColumn(splitIdx).median();
30-
final split = _splitPoints(pointIndices, splitIdx, splitValue);
29+
final split = _splitPoints(pointIndices, splitIdx);
3130

3231
return KDTreeNode(
3332
pointIndices: [split.midPoint],
@@ -55,27 +54,15 @@ class KDTreeBuilder {
5554
return maxIdx;
5655
}
5756

58-
_Split _splitPoints(List<int> pointIndices, int splitIdx, num splitValue) {
59-
final left = <int>[];
60-
final right = <int>[];
61-
int? midPoint;
62-
63-
for (var i = 0; i < pointIndices.length; i++) {
64-
final pointIndex = pointIndices[i];
65-
final point = _points[pointIndex];
57+
_Split _splitPoints(List<int> pointIndices, int splitIdx) {
58+
pointIndices.sort((firstIdx, secondIdx) =>
59+
_points[firstIdx][splitIdx].compareTo(_points[secondIdx][splitIdx]));
6660

67-
if (point[splitIdx] < splitValue) {
68-
left.add(pointIndex);
69-
continue;
70-
}
71-
72-
if (midPoint == null || point[splitIdx] < _points[midPoint][splitIdx]) {
73-
midPoint = pointIndex;
74-
}
75-
76-
right.add(pointIndex);
77-
}
61+
final midPointIdx = (pointIndices.length / 2).floor();
62+
final midPoint = pointIndices[midPointIdx];
63+
final left = pointIndices.sublist(0, midPointIdx);
64+
final right = pointIndices.sublist(midPointIdx + 1);
7865

79-
return _Split(left, right, midPoint!);
66+
return _Split(left, right, midPoint);
8067
}
8168
}

lib/src/retrieval/kd_tree/kd_tree_impl.dart

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ class KDTreeImpl with SerializableMixin implements KDTree {
6161
return neighbours.toList().reversed;
6262
}
6363

64-
void _findKNNRecursively(KDTreeNode node, Vector point, int k,
64+
void _findKNNRecursively(KDTreeNode? node, Vector point, int k,
6565
HeapPriorityQueue<KDTreeNeighbour> neighbours) {
66-
searchIterationCount++;
66+
if (node == null) {
67+
return;
68+
}
6769

6870
if (node.isLeaf) {
6971
_knnSearch(point, node.pointIndices, neighbours, k);
@@ -72,31 +74,30 @@ class KDTreeImpl with SerializableMixin implements KDTree {
7274
}
7375

7476
final nodePoint = points[node.pointIndices[0]];
75-
final boundaryPoint = _getBoundary(point, node);
7677
final isNodeTooFar = neighbours.length > 0 &&
77-
boundaryPoint.distanceTo(nodePoint) > neighbours.first.distance;
78+
(point[node.splitIndex] - nodePoint[node.splitIndex]).abs() >
79+
neighbours.first.distance;
7880
final isQueueFilled = neighbours.length == k;
7981

8082
if (isQueueFilled && isNodeTooFar) {
8183
return;
8284
}
8385

84-
if (node.left == null) {
85-
_findKNNRecursively(node.right!, point, k, neighbours);
86-
} else if (node.right == null) {
87-
_findKNNRecursively(node.left!, point, k, neighbours);
88-
} else if (point[node.splitIndex!] < nodePoint[node.splitIndex!]) {
89-
_findKNNRecursively(node.left!, point, k, neighbours);
90-
_findKNNRecursively(node.right!, point, k, neighbours);
86+
_knnSearch(point, node.pointIndices, neighbours, k);
87+
88+
if (point[node.splitIndex] < nodePoint[node.splitIndex]) {
89+
_findKNNRecursively(node.left, point, k, neighbours);
90+
_findKNNRecursively(node.right, point, k, neighbours);
9191
} else {
92-
_findKNNRecursively(node.right!, point, k, neighbours);
93-
_findKNNRecursively(node.left!, point, k, neighbours);
92+
_findKNNRecursively(node.right, point, k, neighbours);
93+
_findKNNRecursively(node.left, point, k, neighbours);
9494
}
9595
}
9696

9797
void _knnSearch(Vector point, List<int> pointIndices,
9898
HeapPriorityQueue<KDTreeNeighbour> neighbours, int k) {
9999
pointIndices.forEach((candidateIdx) {
100+
searchIterationCount++;
100101
final candidate = points[candidateIdx];
101102
final candidateDistance = candidate.distanceTo(point);
102103
final lastNeighbourDistance =
@@ -113,13 +114,4 @@ class KDTreeImpl with SerializableMixin implements KDTree {
113114
}
114115
});
115116
}
116-
117-
Vector _getBoundary(Vector point, KDTreeNode node) {
118-
final nodePoint = points[node.pointIndices[0]];
119-
final boundarySrc = [...nodePoint];
120-
121-
boundarySrc[node.splitIndex!] = point[node.splitIndex!];
122-
123-
return Vector.fromList(boundarySrc, dtype: point.dtype);
124-
}
125117
}

lib/src/retrieval/kd_tree/kd_tree_neighbour.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ class KDTreeNeighbour {
1414
}
1515

1616
@override
17-
String toString() => 'Index: $index, Distance: $distance';
17+
String toString() => '(Index: $index, Distance: $distance)';
1818
}

lib/src/retrieval/kd_tree/kd_tree_node.dart

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@ part 'kd_tree_node.g.dart';
77
@JsonSerializable()
88
class KDTreeNode {
99
KDTreeNode(
10-
{this.splitIndex, this.left, this.right, required this.pointIndices});
10+
{required this.splitIndex,
11+
this.left,
12+
this.right,
13+
required this.pointIndices});
1114

1215
factory KDTreeNode.fromJson(Map<String, dynamic> json) =>
1316
_$KDTreeNodeFromJson(json);
1417

1518
Map<String, dynamic> toJson() => _$KDTreeNodeToJson(this);
1619

1720
@JsonKey(name: kdTreeNodeIndexJsonKey)
18-
final int? splitIndex;
21+
final int splitIndex;
1922

2023
@JsonKey(name: kdTreeNodeLeftJsonKey)
2124
final KDTreeNode? left;

lib/src/retrieval/kd_tree/kd_tree_node.g.dart

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 16.6.1
3+
version: 16.6.2
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:

0 commit comments

Comments
 (0)