Skip to content

Commit e47d6ef

Browse files
authored
'Random Binary Projection' algorithm added (#228)
1 parent 1ced4c0 commit e47d6ef

27 files changed

+931
-85
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## 16.12.0
4+
- `RandomBinaryProjectionSearcher` class added
5+
36
## 16.11.4
47
- `getPimaIndiansDiabetesDataFrame`, `getIrisDataFrame` used
58

README.md

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,44 +35,48 @@ it in web applications.
3535
## The library content
3636

3737
- #### Model selection
38-
- [CrossValidator](https://github.com/gyrdym/ml_algo/blob/master/lib/src/model_selection/cross_validator/cross_validator.dart).
38+
- [CrossValidator](https://pub.dev/documentation/ml_algo/latest/ml_algo/CrossValidator-class.html).
3939
A factory that creates instances of cross validators. Cross-validation allows researchers to fit different
4040
[hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) of machine learning algorithms
4141
assessing prediction quality on different parts of a dataset.
4242

4343
- #### Classification algorithms
44-
- [LogisticRegressor](https://github.com/gyrdym/ml_algo/blob/master/lib/src/classifier/logistic_regressor/logistic_regressor.dart).
44+
- [LogisticRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor-class.html).
4545
A class that performs linear binary classification of data. To use this kind of classifier your data has to be
4646
[linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
4747

48-
- [SoftmaxRegressor](https://github.com/gyrdym/ml_algo/blob/master/lib/src/classifier/softmax_regressor/softmax_regressor.dart).
48+
- [SoftmaxRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/SoftmaxRegressor-class.html).
4949
A class that performs linear multiclass classification of data. To use this kind of classifier your data has to be
5050
[linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
5151

52-
- [DecisionTreeClassifier](https://github.com/gyrdym/ml_algo/blob/master/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart)
52+
- [DecisionTreeClassifier](https://pub.dev/documentation/ml_algo/latest/ml_algo/DecisionTreeClassifier-class.html)
5353
A class that performs classification using decision trees. May work with data with non-linear patterns.
5454

55-
- [KnnClassifier](https://github.com/gyrdym/ml_algo/blob/master/lib/src/classifier/knn_classifier/knn_classifier.dart)
55+
- [KnnClassifier](https://pub.dev/documentation/ml_algo/latest/ml_algo/KnnClassifier-class.html)
5656
A class that performs classification using `k nearest neighbours algorithm` - it makes predictions based on
5757
the first `k` closest observations to the given one.
5858

5959
- #### Regression algorithms
60-
- [LinearRegressor](https://github.com/gyrdym/ml_algo/blob/master/lib/src/regressor/linear_regressor/linear_regressor.dart).
60+
- [LinearRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor-class.html).
6161
A general class for finding a linear pattern in training data and predicting outcomes as real numbers.
6262

63-
- [LinearRegressor.lasso](https://github.com/gyrdym/ml_algo/blob/85f1e2f19b946beb2b594a62e0e3c999d1c31608/lib/src/regressor/linear_regressor/linear_regressor.dart#L219)
63+
- [LinearRegressor.lasso](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor/LinearRegressor.lasso.html)
6464
Implementation of the linear regression algorithm based on coordinate descent with lasso regularisation
6565

66-
- [LinearRegressor.SGD](https://github.com/gyrdym/ml_algo/blob/c0ffc71676c1ad14927448fe9bbf984a425ce27a/lib/src/regressor/linear_regressor/linear_regressor.dart#L322)
66+
- [LinearRegressor.SGD](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor/LinearRegressor.SGD.html)
6767
Implementation of the linear regression algorithm based on stochastic gradient descent with L2 regularisation
6868

69-
- [KnnRegressor](https://github.com/gyrdym/ml_algo/blob/master/lib/src/regressor/knn_regressor/knn_regressor.dart)
69+
- [KnnRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/KnnRegressor-class.html)
7070
A class that makes predictions for each new observation based on the first `k` closest observations from
7171
training data. It may catch non-linear patterns of the data.
7272

7373
- #### Clustering and retrieval algorithms
74-
- [KDTree](https://github.com/gyrdym/ml_algo/blob/master/lib/src/retrieval/kd_tree/kd_tree.dart) An algorithm for
74+
- [KDTree](https://pub.dev/documentation/ml_algo/latest/kd_tree/KDTree-class.html) An algorithm for
7575
efficient data retrieval.
76+
- **Locality sensitive hashing.** A family of algorithms that randomly partition all reference data points into
77+
different bins, which makes it possible to perform efficient K Nearest Neighbours search, since there is no need
78+
to search for the neighbours through the entire data. The family is represented by the following classes:
79+
- [RandomBinaryProjectionSearcher](https://pub.dev/documentation/ml_algo/latest/random_binary_projection_searcher/RandomBinaryProjectionSearcher-class.html)
7680

7781
For more information on the library's API, please visit the [API reference](https://pub.dev/documentation/ml_algo/latest/ml_algo/ml_algo-library.html)
7882

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// 0.03 sec (MacBook Air mid 2017)
2+
import 'package:benchmark_harness/benchmark_harness.dart';
3+
import 'package:ml_algo/ml_algo.dart';
4+
import 'package:ml_dataframe/ml_dataframe.dart';
5+
import 'package:ml_linalg/matrix.dart';
6+
7+
late DataFrame trainData;
8+
9+
class RandomBinaryProjectionSearcherBuildingBenchmark extends BenchmarkBase {
10+
RandomBinaryProjectionSearcherBuildingBenchmark()
11+
: super('RandomBinaryProjectionSearcher building benchmark');
12+
13+
static void main() {
14+
RandomBinaryProjectionSearcherBuildingBenchmark().report();
15+
}
16+
17+
@override
18+
void run() {
19+
RandomBinaryProjectionSearcher(trainData, 4, seed: 10);
20+
}
21+
22+
void tearDown() {}
23+
}
24+
25+
Future main() async {
26+
final points = Matrix.random(1000, 10, seed: 1, min: -5000, max: 5000);
27+
28+
trainData = DataFrame.fromMatrix(points);
29+
30+
print(
31+
'Data dimension: ${trainData.rows.length}x${trainData.rows.first.length}');
32+
33+
RandomBinaryProjectionSearcherBuildingBenchmark.main();
34+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// 0.04 sec (MacBook Air mid 2017)
2+
import 'package:benchmark_harness/benchmark_harness.dart';
3+
import 'package:ml_algo/ml_algo.dart';
4+
import 'package:ml_algo/src/retrieval/random_binary_projection_searcher/random_binary_projection_searcher_impl.dart';
5+
import 'package:ml_dataframe/ml_dataframe.dart';
6+
import 'package:ml_linalg/linalg.dart';
7+
8+
final k = 10;
9+
final digitCapacity = 10;
10+
final searchRadius = 3;
11+
12+
late DataFrame trainData;
13+
late RandomBinaryProjectionSearcher searcher;
14+
late Vector point;
15+
16+
class RandomBinaryProjectionSearcherQueryingBenchmark extends BenchmarkBase {
17+
RandomBinaryProjectionSearcherQueryingBenchmark()
18+
: super('RandomBinaryProjectionSearcher querying benchmark');
19+
20+
static void main() {
21+
RandomBinaryProjectionSearcherQueryingBenchmark().report();
22+
}
23+
24+
@override
25+
void run() {
26+
searcher.query(point, k, searchRadius);
27+
}
28+
29+
void tearDown() {}
30+
}
31+
32+
Future main() async {
33+
final points = Matrix.random(20000, 10, seed: 1, min: -5000, max: 5000);
34+
35+
trainData = DataFrame.fromMatrix(points);
36+
searcher = RandomBinaryProjectionSearcher(trainData, digitCapacity, seed: 10);
37+
point = Vector.randomFilled(trainData.rows.first.length,
38+
seed: 10, min: -5000, max: 5000);
39+
40+
print(
41+
'Data dimension: ${trainData.rows.length}x${trainData.rows.first.length}');
42+
print('Number of neighbours: $k');
43+
44+
RandomBinaryProjectionSearcherQueryingBenchmark.main();
45+
46+
print(
47+
'Amount of search iterations: ${(searcher as RandomBinaryProjectionSearcherImpl).searchIterationCount}');
48+
}

e2e/random_binary_projection_searcher/random_binary_projection_searcher_32_v1.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

e2e/random_binary_projection_searcher/random_binary_projection_searcher_64_v1.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import 'dart:io';
2+
3+
import 'package:ml_algo/ml_algo.dart';
4+
import 'package:ml_linalg/dtype.dart';
5+
import 'package:ml_linalg/vector.dart';
6+
import 'package:test/test.dart';
7+
8+
void main() async {
9+
group('RandomBinaryProjectionSearcher', () {
10+
test('should restore from a JSON file, dtype=DType.float32', () async {
11+
final jsonSource = await File(
12+
'e2e/random_binary_projection_searcher/random_binary_projection_searcher_32_v1.json')
13+
.readAsString();
14+
final searcher = RandomBinaryProjectionSearcher.fromJson(jsonSource);
15+
final k = 5;
16+
final searchRadius = 3;
17+
final neighbours = searcher.query(
18+
Vector.fromList(
19+
[6.5, 3.01, 4.5, 1.5, 2.3, 22.3, 14.09, 2.9, 22.0, 11.22]),
20+
k,
21+
searchRadius);
22+
23+
expect(searcher.columns, [
24+
'feature_1',
25+
'feature_2',
26+
'feature_3',
27+
'feature_4',
28+
'feature_5',
29+
'feature_6',
30+
'feature_7',
31+
'feature_8',
32+
'feature_9',
33+
'feature_10'
34+
]);
35+
expect(searcher.digitCapacity, 6);
36+
expect(searcher.seed, 15);
37+
expect(neighbours, hasLength(5));
38+
expect(neighbours.toString(),
39+
'((Index: 160, Distance: 4646.84383479798), (Index: 562, Distance: 5358.518638579137), (Index: 648, Distance: 5403.194564329513), (Index: 591, Distance: 5522.45033929686), (Index: 938, Distance: 6083.799141983568))');
40+
});
41+
42+
test('should restore from a JSON file, dtype=DType.float64', () async {
43+
final jsonSource = await File(
44+
'e2e/random_binary_projection_searcher/random_binary_projection_searcher_64_v1.json')
45+
.readAsString();
46+
final searcher = RandomBinaryProjectionSearcher.fromJson(jsonSource);
47+
final k = 5;
48+
final searchRadius = 3;
49+
final neighbours = searcher.query(
50+
Vector.fromList(
51+
[6.5, 3.01, 4.5, 1.5, 2.3, 22.3, 14.09, 2.9, 22.0, 11.22],
52+
dtype: DType.float64),
53+
k,
54+
searchRadius);
55+
56+
expect(searcher.columns, [
57+
'feature_1',
58+
'feature_2',
59+
'feature_3',
60+
'feature_4',
61+
'feature_5',
62+
'feature_6',
63+
'feature_7',
64+
'feature_8',
65+
'feature_9',
66+
'feature_10'
67+
]);
68+
expect(searcher.digitCapacity, 6);
69+
expect(searcher.seed, 15);
70+
expect(neighbours, hasLength(5));
71+
expect(neighbours.toString(),
72+
'((Index: 160, Distance: 4646.843905472263), (Index: 562, Distance: 5358.51853463377), (Index: 648, Distance: 5403.194350179125), (Index: 591, Distance: 5522.45040389316), (Index: 938, Distance: 6083.79924961192))');
73+
});
74+
});
75+
}

lib/ml_algo.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ export 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart';
1414
export 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart';
1515
export 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
1616
export 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
17+
export 'package:ml_algo/src/retrieval/random_binary_projection_searcher/random_binary_projection_searcher.dart';
1718
export 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_type.dart';

lib/src/model_selection/split_indices_provider/lpo_indices_provider.dart

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
22

33
class LpoIndicesProvider implements SplitIndicesProvider {
4-
LpoIndicesProvider([this._p = 2]) {
5-
if (_p == 0) {
6-
throw UnsupportedError('Value `$_p` for parameter `p` is unsupported');
7-
}
8-
}
4+
LpoIndicesProvider([this._p = 2]);
95

106
final int _p;
117

lib/src/retrieval/kd_tree/kd_tree.dart

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ import 'package:ml_algo/src/common/serializable/serializable.dart';
22
import 'package:ml_algo/src/retrieval/kd_tree/helpers/create_kd_tree.dart';
33
import 'package:ml_algo/src/retrieval/kd_tree/helpers/create_kd_tree_from_iterable.dart';
44
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
5-
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_neighbour.dart';
65
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
6+
import 'package:ml_algo/src/retrieval/neighbour.dart';
77
import 'package:ml_dataframe/ml_dataframe.dart';
88
import 'package:ml_linalg/distance.dart';
99
import 'package:ml_linalg/dtype.dart';
@@ -151,7 +151,7 @@ abstract class KDTree implements Serializable {
151151
/// print(neighbours[0].index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3]
152152
/// }
153153
/// ```
154-
Iterable<KDTreeNeighbour> query(Vector point, int k,
154+
Iterable<Neighbour> query(Vector point, int k,
155155
[Distance distance = Distance.euclidean]);
156156

157157
/// Returns [k] nearest neighbours for [point], [point] is [Iterable] unlike
@@ -179,6 +179,6 @@ abstract class KDTree implements Serializable {
179179
/// print(neighbours[0].index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3]
180180
/// }
181181
/// ```
182-
Iterable<KDTreeNeighbour> queryIterable(Iterable<num> point, int k,
182+
Iterable<Neighbour> queryIterable(Iterable<num> point, int k,
183183
[Distance distance = Distance.euclidean]);
184184
}

0 commit comments

Comments
 (0)