
Commit cdfe598

README: LogisticRegressor example corrected (#233)
1 parent 2d2b38d commit cdfe598

4 files changed: +25 additions, −26 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

````diff
@@ -1,5 +1,8 @@
 # Changelog
 
+## 16.15.1
+- README: LogisticRegressor example corrected
+
 ## 16.15.0
 - `LinearRegressor.BGD` constructor added
 
````

README.md

Lines changed: 18 additions & 22 deletions

````diff
@@ -225,34 +225,30 @@ if the selected hyperparameters are good enough or not:
 
 ```dart
 final createClassifier = (DataFrame samples) =>
-    LogisticRegressor(
+    // BGD stands for "Batch Gradient Descent", meaning that the classifier will use the whole dataset on every
+    // training iteration
+    LogisticRegressor.BGD(
       samples,
       targetColumnName,
-      optimizerType: LinearOptimizerType.gradient,
       iterationsLimit: 90,
       learningRateType: LearningRateType.timeBased,
-      batchSize: samples.rows.length,
       probabilityThreshold: 0.7,
     );
 ```
 
 Let's describe our hyperparameters:
-- `optimizerType` - a type of optimization algorithm that will be used to learn coefficients of our model, this time we
-decided to use a vanilla gradient ascent algorithm
 - `iterationsLimit` - number of learning iterations. The selected optimization algorithm (gradient ascent in our case) will
 be cyclically run this amount of times
 - `learningRateType` - a strategy for learning rate update. In our case, the learning rate will decrease after every
 iteration
-- `batchSize` - the size of data (in rows) that will be used per each iteration. As we have a really small dataset we may use
-full-batch gradient ascent, that's why we used `samples.rows.length` here - the total amount of data.
 - `probabilityThreshold` - lower bound for positive label probability
 
 If we want to evaluate the learning process more thoroughly, we may pass `collectLearningData` argument to the classifier
 constructor:
 
 ```dart
 final createClassifier = (DataFrame samples) =>
-    LogisticRegressor(
+    LogisticRegressor.BGD(
       ...,
       collectLearningData: true,
     );
````
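The `BGD` constructor in this hunk bundles what the removed hyperparameters expressed explicitly: full-batch gradient optimization with a learning rate that shrinks over time. As a rough, library-free sketch of what those settings mean (one-feature toy data and a `rate0 / (1 + decay·t)` schedule are assumptions for illustration, not ml_algo's internals):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train_logistic_bgd(xs, ys, iterations_limit=90, initial_rate=0.1, decay=0.1):
    """Full-batch gradient descent: every iteration uses the whole dataset."""
    w, b = 0.0, 0.0
    for t in range(iterations_limit):
        # Time-based schedule: the learning rate decreases on every iteration.
        rate = initial_rate / (1.0 + decay * t)
        grad_w = sum((sigmoid(w * x + b) - y) * x for x, y in zip(xs, ys)) / len(xs)
        grad_b = sum((sigmoid(w * x + b) - y) for x, y in zip(xs, ys)) / len(xs)
        w -= rate * grad_w
        b -= rate * grad_b
    return w, b

def predict(w, b, x, probability_threshold=0.7):
    # probabilityThreshold: a sample is labelled positive only if its
    # predicted probability reaches the threshold.
    return 1 if sigmoid(w * x + b) >= probability_threshold else 0

# Toy, linearly separable data (hypothetical).
xs = [-2.0, -1.5, -1.0, 1.0, 1.5, 2.0]
ys = [0, 0, 0, 1, 1, 1]
w, b = train_logistic_bgd(xs, ys)
```

Note the effect of `probabilityThreshold: 0.7`: borderline samples whose predicted probability falls between 0.5 and 0.7 are labelled negative, trading recall for precision on the positive class.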
````diff
@@ -323,22 +319,26 @@ After that we can simply read the model from the file and make predictions:
 ```dart
 import 'dart:io';
 
-final fileName = 'diabetes_classifier.json';
-final file = File(fileName);
-final encodedModel = await file.readAsString();
-final classifier = LogisticRegressor.fromJson(encodedModel);
-final unlabelledData = await fromCsv('some_unlabelled_data.csv');
-final prediction = classifier.predict(unlabelledData);
+void main() async {
+  // ...
+  final fileName = 'diabetes_classifier.json';
+  final file = File(fileName);
+  final encodedModel = await file.readAsString();
+  final classifier = LogisticRegressor.fromJson(encodedModel);
+  final unlabelledData = await fromCsv('some_unlabelled_data.csv');
+  final prediction = classifier.predict(unlabelledData);
 
-print(prediction.header); // ('class variable (0 or 1)')
-print(prediction.rows); // [
+  print(prediction.header); // ('class variable (0 or 1)')
+  print(prediction.rows); // [
 //   (1),
 //   (0),
 //   (0),
 //   (1),
 //   ...,
 //   (1),
 // ]
+  // ...
+}
 ```
 
 Please note that all the hyperparameters that we used to generate the model are persisted as the model's read-only
````
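The restored classifier needs no retraining because the serialized JSON carries the whole model state. A minimal sketch of the same round-trip idea, with a deliberately simplified, hypothetical schema (ml_algo's real JSON layout is richer and also persists every training hyperparameter):

```python
import json
import math
import os
import tempfile

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical minimal model state (field names are illustrative only).
model = {
    "coefficients": [0.8, -0.4],
    "intercept": 0.25,
    "probabilityThreshold": 0.7,
}

# Persist the model to a JSON file...
path = os.path.join(tempfile.mkdtemp(), "diabetes_classifier.json")
with open(path, "w") as f:
    json.dump(model, f)

# ...and restore it later, ready to predict without retraining.
with open(path) as f:
    restored = json.load(f)

def predict(m, features):
    z = m["intercept"] + sum(w * x for w, x in zip(m["coefficients"], features))
    return 1 if sigmoid(z) >= m["probabilityThreshold"] else 0
```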
````diff
@@ -368,13 +368,11 @@ void main() async {
   final testData = splits[1];
   final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
   final createClassifier = (DataFrame samples) =>
-      LogisticRegressor(
+      LogisticRegressor.BGD(
         samples,
         targetColumnName,
-        optimizerType: LinearOptimizerType.gradient,
         iterationsLimit: 90,
         learningRateType: LearningRateType.timeBased,
-        batchSize: samples.rows.length,
         probabilityThreshold: 0.7,
       );
   final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
@@ -413,13 +411,11 @@ void main() async {
   final testData = splits[1];
   final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
   final createClassifier = (DataFrame samples) =>
-      LogisticRegressor(
+      LogisticRegressor.BGD(
         samples,
         targetColumnName,
-        optimizerType: LinearOptimizerType.gradient,
         iterationsLimit: 90,
         learningRateType: LearningRateType.timeBased,
-        batchSize: samples.rows.length,
         probabilityThreshold: 0.7,
       );
   final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
````
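These hunks keep the `createClassifier` callback pattern: `CrossValidator.kFold` calls the factory once per fold, training a fresh model on k−1 folds and scoring it on the held-out one. A library-free sketch of that loop, with a deliberately trivial majority-vote "classifier" standing in for the factory (all names here are illustrative, not ml_algo's API):

```python
def k_fold_indices(n, k):
    """Split indices 0..n-1 into k near-equal contiguous folds.
    (Real splitters usually shuffle first; omitted here for determinism.)"""
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds

def majority_classifier(train_labels):
    """Stand-in for createClassifier: always predicts the majority class."""
    label = 1 if 2 * sum(train_labels) >= len(train_labels) else 0
    return lambda _sample: label

def evaluate_accuracy(samples, labels, create_classifier, k=5):
    """Rough analogue of validator.evaluate(createClassifier, MetricType.accuracy):
    returns one accuracy score per fold."""
    scores = []
    for fold in k_fold_indices(len(samples), k):
        held_out = set(fold)
        train_labels = [labels[i] for i in range(len(labels)) if i not in held_out]
        model = create_classifier(train_labels)           # fresh model per fold
        correct = sum(1 for i in fold if model(samples[i]) == labels[i])
        scores.append(correct / len(fold))
    return scores

samples = list(range(10))  # feature values are irrelevant for the stand-in model
labels = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
scores = evaluate_accuracy(samples, labels, majority_classifier)
```

The per-fold scores make the variance of the estimate visible, which a single train/test split would hide.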

e2e/logistic_regressor/logistic_regressor_bgd_test.dart

Lines changed: 3 additions & 3 deletions

````diff
@@ -4,7 +4,7 @@ import 'package:ml_linalg/vector.dart';
 import 'package:test/test.dart';
 
 Future<Vector> evaluateLogisticRegressor(MetricType metric, DType dtype) {
-  final samples = getPimaIndiansDiabetesDataFrame().shuffle(seed: 12);
+  final samples = getPimaIndiansDiabetesDataFrame().shuffle();
   final numberOfFolds = 5;
   final validator = CrossValidator.kFold(
     samples,
@@ -14,8 +14,8 @@ Future<Vector> evaluateLogisticRegressor(MetricType metric, DType dtype) {
     trainSamples,
     'Outcome',
     iterationsLimit: 50,
-    initialLearningRate: 1e-4,
-    learningRateType: LearningRateType.constant,
+    decay: .1,
+    learningRateType: LearningRateType.timeBased,
     dtype: dtype,
   );
 
````

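The e2e test now uses a time-based schedule with `decay: .1` instead of a constant learning rate. A small sketch of how the two strategies differ over iterations (the `rate0 / (1 + decay·t)` formula is a common convention and an assumption here, not necessarily ml_algo's exact schedule):

```python
def constant_rate(rate0, t):
    # LearningRateType.constant: the step size never changes.
    return rate0

def time_based_rate(rate0, t, decay=0.1):
    # LearningRateType.timeBased: the step size shrinks on every iteration,
    # so early steps explore and later steps fine-tune.
    return rate0 / (1.0 + decay * t)

constant_schedule = [constant_rate(1.0, t) for t in range(4)]
decayed_schedule = [round(time_based_rate(1.0, t), 3) for t in range(4)]
```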
pubspec.yaml

Lines changed: 1 addition & 1 deletion

````diff
@@ -1,6 +1,6 @@
 name: ml_algo
 description: Machine learning algorithms, Machine learning models performance evaluation functionality
-version: 16.15.0
+version: 16.15.1
 homepage: https://github.com/gyrdym/ml_algo
 
 environment:
````
