Skip to content

Commit a2fe388

Browse files
authored
KDTree: api extended (#211)
1 parent aeacf3a commit a2fe388

File tree

11 files changed

+185
-45
lines changed

11 files changed

+185
-45
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## 16.6.3
4+
- KDTree:
5+
- `fromIterable` constructor added
6+
- `splitStrategy` option added to all constructors
7+
38
## 16.6.2
49
- KDTree:
510
- KDTree build optimization: split algorithm changed

benchmark/kd_tree_building.dart renamed to benchmark/kd_tree/kd_tree_building.dart

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
// 0.8 sec (MacBook Air mid 2017)
2-
import 'dart:convert';
3-
import 'dart:io';
4-
1+
// 0.5 sec (MacBook Air mid 2017)
52
import 'package:benchmark_harness/benchmark_harness.dart';
63
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
74
import 'package:ml_dataframe/ml_dataframe.dart';
5+
import 'package:ml_linalg/matrix.dart';
86

97
late DataFrame trainData;
108

@@ -24,11 +22,9 @@ class KDTreeBuildingBenchmark extends BenchmarkBase {
2422
}
2523

2624
Future main() async {
27-
final file = File('benchmark/data/sample_data.json');
28-
final dataAsString = await file.readAsString();
29-
final decoded = jsonDecode(dataAsString) as Map<String, dynamic>;
25+
final points = Matrix.random(1000, 10, seed: 1, min: -5000, max: 5000);
3026

31-
trainData = DataFrame.fromJson(decoded);
27+
trainData = DataFrame.fromMatrix(points);
3228

3329
print(
3430
'Data dimension: ${trainData.rows.length}x${trainData.rows.first.length}');

benchmark/kd_tree_querying.dart renamed to benchmark/kd_tree/kd_tree_querying.dart

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
// 0.03 sec (MacBook Air mid 2017)
2-
import 'dart:convert';
3-
import 'dart:io';
4-
1+
// 0.1 sec (MacBook Air mid 2017)
52
import 'package:benchmark_harness/benchmark_harness.dart';
63
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
4+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
75
import 'package:ml_dataframe/ml_dataframe.dart';
86
import 'package:ml_linalg/linalg.dart';
9-
import 'package:ml_linalg/vector.dart';
107

118
final k = 10;
129

@@ -30,12 +27,10 @@ class KDTreeQueryingBenchmark extends BenchmarkBase {
3027
}
3128

3229
Future main() async {
33-
final file = File('benchmark/data/sample_data.json');
34-
final dataAsString = await file.readAsString();
35-
final decodedPoints = jsonDecode(dataAsString) as Map<String, dynamic>;
30+
final points = Matrix.random(20000, 10, seed: 1, min: -5000, max: 5000);
3631

37-
trainData = DataFrame.fromJson(decodedPoints);
38-
tree = KDTree(trainData, leafSie: 1);
32+
trainData = DataFrame.fromMatrix(points);
33+
tree = KDTree(trainData, leafSize: 1);
3934
point = Vector.randomFilled(trainData.rows.first.length,
4035
seed: 10, min: -5000, max: 5000);
4136

@@ -44,4 +39,7 @@ Future main() async {
4439
print('Number of neighbours: $k');
4540

4641
KDTreeQueryingBenchmark.main();
42+
43+
print(
44+
'Amount of search iterations: ${(tree as KDTreeImpl).searchIterationCount}');
4745
}

lib/ml_algo.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ export 'package:ml_algo/src/model_selection/split_data.dart';
1313
export 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart';
1414
export 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart';
1515
export 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
16+
export 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';

lib/src/retrieval/kd_tree/helpers/create_kd_tree.dart

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_builder.dart';
22
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_constants.dart';
33
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
4+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
45
import 'package:ml_dataframe/ml_dataframe.dart';
56
import 'package:ml_linalg/dtype.dart';
67

7-
KDTreeImpl createKDTree(DataFrame pointsSrc, int leafSize, DType dtype) {
8+
KDTreeImpl createKDTree(DataFrame pointsSrc, int leafSize, DType dtype,
9+
KDTreeSplitStrategy splitStrategy) {
810
final points = pointsSrc.toMatrix(dtype);
9-
final builder = KDTreeBuilder(leafSize, points);
11+
final builder = KDTreeBuilder(leafSize, points, splitStrategy);
1012
final root = builder.train();
1113

1214
return KDTreeImpl(points, leafSize, root, dtype, kdTreeJsonSchemaVersion);
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_builder.dart';
2+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_constants.dart';
3+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
4+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
5+
import 'package:ml_linalg/dtype.dart';
6+
import 'package:ml_linalg/matrix.dart';
7+
8+
KDTreeImpl createKDTreeFromIterable(Iterable<Iterable<num>> pointsSrc,
9+
int leafSize, DType dtype, KDTreeSplitStrategy splitStrategy) {
10+
final points = Matrix.fromList(
11+
pointsSrc
12+
.map((row) => row.map((element) => element.toDouble()).toList())
13+
.toList(),
14+
dtype: dtype);
15+
final builder = KDTreeBuilder(leafSize, points, splitStrategy);
16+
final root = builder.train();
17+
18+
return KDTreeImpl(points, leafSize, root, dtype, kdTreeJsonSchemaVersion);
19+
}

lib/src/retrieval/kd_tree/kd_tree.dart

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,83 @@
11
import 'package:ml_algo/src/common/serializable/serializable.dart';
22
import 'package:ml_algo/src/retrieval/kd_tree/helpers/create_kd_tree.dart';
3+
import 'package:ml_algo/src/retrieval/kd_tree/helpers/create_kd_tree_from_iterable.dart';
34
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
45
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_neighbour.dart';
6+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
57
import 'package:ml_dataframe/ml_dataframe.dart';
68
import 'package:ml_linalg/dtype.dart';
79
import 'package:ml_linalg/matrix.dart';
810
import 'package:ml_linalg/vector.dart';
911

10-
/// KD-tree - an algorithm that provides efficient data retrieval. It splits
11-
/// the whole searching space into partitions in binary tree form which means
12+
/// KD-tree - an algorithm that provides efficient data retrieval by splitting
13+
/// the whole search space into partitions in the form of a binary tree, meaning
1214
/// that data querying on average will take O(log(n)) time
15+
///
16+
/// One can use this algorithm to perform KNN-search. It's recommended to use
17+
/// [KDTree] when the number of the input data columns is much less than the
18+
/// number of rows of the data - in this case, the search will be more efficient
1319
abstract class KDTree implements Serializable {
20+
/// [points] Data points which will be used to build the tree.
21+
///
22+
/// [leafSize] A number of points on a leaf node.
23+
///
24+
/// The bigger the number, the less effective the search is. If [leafSize] is
25+
/// equal to the number of [points], a regular KNN-search will take place.
26+
///
27+
/// Extremely small [leafSize] leads to ineffective memory usage since in
28+
/// this case a lot of kd-tree nodes will be allocated
29+
///
30+
/// [dtype] A data type which will be used to convert raw data from [points]
31+
/// into internal numerical representation
32+
///
33+
/// [splitStrategy] Describes how to choose a split dimension. Default value
34+
/// is [KDTreeSplitStrategy.largestVariance]
35+
///
36+
/// if [splitStrategy] is [KDTreeSplitStrategy.largestVariance], dimension with
37+
/// the widest column (in terms of variance) will be chosen to split the data
38+
///
39+
/// if [splitStrategy] is [KDTreeSplitStrategy.inOrder], dimension for data
40+
/// splits will be chosen one by one in order
41+
///
42+
/// [KDTreeSplitStrategy.largestVariance] provides more accurate KNN-search,
43+
/// but this strategy takes much more time to build the tree than [KDTreeSplitStrategy.inOrder]
1444
factory KDTree(DataFrame points,
15-
{int leafSie = 10, DType dtype = DType.float32}) =>
16-
createKDTree(points, leafSie, dtype);
45+
{int leafSize = 1,
46+
DType dtype = DType.float32,
47+
KDTreeSplitStrategy splitStrategy =
48+
KDTreeSplitStrategy.largestVariance}) =>
49+
createKDTree(points, leafSize, dtype, splitStrategy);
50+
51+
/// [pointsSrc] Data points which will be used to build the tree.
52+
///
53+
/// [leafSize] A number of points on a leaf node.
54+
///
55+
/// The bigger the number, the less effective the search is. If [leafSize] is
56+
/// equal to the number of [pointsSrc], a regular KNN-search will take place.
57+
///
58+
/// Extremely small [leafSize] leads to ineffective memory usage since in
59+
/// this case a lot of kd-tree nodes will be allocated
60+
///
61+
/// [dtype] A data type which will be used to convert raw data from [pointsSrc]
62+
/// into internal numerical representation
63+
///
64+
/// [splitStrategy] Describes how to choose a split dimension. Default value
65+
/// is [KDTreeSplitStrategy.largestVariance]
66+
///
67+
/// if [splitStrategy] is [KDTreeSplitStrategy.largestVariance], dimension with
68+
/// the widest column (in terms of variance) will be chosen to split the data
69+
///
70+
/// if [splitStrategy] is [KDTreeSplitStrategy.inOrder], dimension for data
71+
/// splits will be chosen one by one in order
72+
///
73+
/// [KDTreeSplitStrategy.largestVariance] provides more accurate KNN-search,
74+
/// but this strategy takes much more time to build the tree than [KDTreeSplitStrategy.inOrder]
75+
factory KDTree.fromIterable(Iterable<Iterable<num>> pointsSrc,
76+
{int leafSize = 1,
77+
DType dtype = DType.float32,
78+
KDTreeSplitStrategy splitStrategy =
79+
KDTreeSplitStrategy.largestVariance}) =>
80+
createKDTreeFromIterable(pointsSrc, leafSize, dtype, splitStrategy);
1781

1882
factory KDTree.fromJson(Map<String, dynamic> json) =>
1983
KDTreeImpl.fromJson(json);
@@ -30,7 +94,7 @@ abstract class KDTree implements Serializable {
3094
/// this case a lot of kd-tree nodes will be allocated
3195
int get leafSize;
3296

33-
/// Data type for [points] matrix
97+
/// Data type for internal representation of [points]
3498
DType get dtype;
3599

36100
/// Returns [k] nearest neighbours for [point]

lib/src/retrieval/kd_tree/kd_tree_builder.dart

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_node.dart';
2+
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
23
import 'package:ml_linalg/matrix.dart';
34

45
class _Split {
@@ -10,17 +11,20 @@ class _Split {
1011
}
1112

1213
class KDTreeBuilder {
13-
KDTreeBuilder(this._leafSize, this._points);
14+
KDTreeBuilder(this._leafSize, this._points, this._splitStrategy);
1415

1516
final int _leafSize;
1617
final Matrix _points;
18+
final KDTreeSplitStrategy _splitStrategy;
1719

18-
KDTreeNode train() => _train(_points.rowIndices.toList());
20+
KDTreeNode train() => _train(_points.rowIndices.toList(), 0);
1921

20-
KDTreeNode _train(List<int> pointIndices) {
22+
KDTreeNode _train(List<int> pointIndices, int splitDim) {
2123
final isLeaf = pointIndices.length <= _leafSize;
2224
final points = _points.sample(rowIndices: pointIndices);
23-
final splitIdx = _getSplitIdx(points);
25+
final splitIdx = _splitStrategy == KDTreeSplitStrategy.largestVariance
26+
? _getSplitIdx(points)
27+
: splitDim % _points.columnsNum;
2428

2529
if (isLeaf) {
2630
return KDTreeNode(splitIndex: splitIdx, pointIndices: pointIndices);
@@ -31,8 +35,8 @@ class KDTreeBuilder {
3135
return KDTreeNode(
3236
pointIndices: [split.midPoint],
3337
splitIndex: splitIdx,
34-
left: _train(split.left),
35-
right: _train(split.right),
38+
left: _train(split.left, splitDim + 1),
39+
right: _train(split.right, splitDim + 1),
3640
);
3741
}
3842

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
enum KDTreeSplitStrategy {
2+
largestVariance,
3+
inOrder,
4+
}

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 16.6.2
3+
version: 16.6.3
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:

0 commit comments

Comments
 (0)