Skip to content

Commit c2e771e

Browse files
authored
Added LinearRegressor.SGD constructor (#231)
1 parent 8aa109f commit c2e771e

File tree

6 files changed

+219
-52
lines changed

6 files changed

+219
-52
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## 16.14.0
4+
- `LinearRegressor.SGD` constructor added
5+
36
## 16.13.0
47
- `RandomBinaryProjectionSearcher`:
58
- Distance type considered

README.md

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ it in web applications.
4444
- [LogisticRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor-class.html).
4545
A class that performs linear binary classification of data. To use this kind of classifier your data has to be
4646
[linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
47+
48+
- [LogisticRegressor.SGD](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor/LogisticRegressor.SGD.html).
49+
Implementation of the logistic regression algorithm based on stochastic gradient descent with L2 regularisation.
50+
To use this kind of classifier your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
4751

4852
- [SoftmaxRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/SoftmaxRegressor-class.html).
4953
A class that performs linear multiclass classification of data. To use this kind of classifier your data has to be
@@ -100,7 +104,7 @@ in your dependencies:
100104

101105
````
102106
dependencies:
103-
ml_dataframe: ^1.4.2
107+
ml_dataframe: ^1.5.0
104108
ml_preprocessing: ^7.0.2
105109
````
106110

@@ -125,11 +129,8 @@ We have 2 options here:
125129

126130
- Download the dataset from [Pima Indians Diabetes Database](https://www.kaggle.com/uciml/pima-indians-diabetes-database).
127131

128-
- Or we may simply use [getPimaIndiansDiabetesDataFrame](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/getPimaIndiansDiabetesDataFrame.html) function
129-
from [ml_dataframe](https://pub.dev/packages/ml_dataframe) package. The function returns a ready to use [DataFrame](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/DataFrame-class.html) instance
130-
filled with `Pima Indians Diabetes Database` data.
131-
132-
If we chose the first option, we should do the following:
132+
<details>
133+
<summary>Instructions</summary>
133134

134135
#### For a desktop application:
135136

@@ -142,18 +143,7 @@ final samples = await fromCsv('datasets/pima_indians_diabetes_database.csv');
142143

143144
#### For a flutter application:
144145

145-
Be sure that you have ml_dataframe package version at least 1.0.0 and ml_algo package version at least 16.0.0
146-
in your pubspec.yaml:
147-
148-
````
149-
dependencies:
150-
...
151-
ml_algo: ^16.11.2
152-
ml_dataframe: ^1.4.2
153-
...
154-
````
155-
156-
Then it's needed to add the dataset to the flutter assets by adding the following config in the pubspec.yaml:
146+
It's needed to add the dataset to the flutter assets by adding the following config in the pubspec.yaml:
157147

158148
````
159149
flutter:
@@ -168,10 +158,30 @@ can access the dataset:
168158
import 'package:flutter/services.dart' show rootBundle;
169159
import 'package:ml_dataframe/ml_dataframe.dart';
170160
171-
final rawCsvContent = await rootBundle.loadString('assets/datasets/pima_indians_diabetes_database.csv');
172-
final samples = DataFrame.fromRawCsv(rawCsvContent);
161+
void main() async {
162+
final rawCsvContent = await rootBundle.loadString('assets/datasets/pima_indians_diabetes_database.csv');
163+
final samples = DataFrame.fromRawCsv(rawCsvContent);
164+
}
165+
```
166+
</details>
167+
168+
- Or we may simply use [getPimaIndiansDiabetesDataFrame](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/getPimaIndiansDiabetesDataFrame.html) function
169+
from [ml_dataframe](https://pub.dev/packages/ml_dataframe) package. The function returns a ready to use [DataFrame](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/DataFrame-class.html) instance
170+
filled with `Pima Indians Diabetes Database` data.
171+
172+
<details>
173+
<summary>Instructions</summary>
174+
175+
```dart
176+
import 'package:ml_dataframe/ml_dataframe.dart';
177+
178+
void main() {
179+
final samples = getPimaIndiansDiabetesDataFrame();
180+
}
173181
```
174182

183+
</details>
184+
175185
### Prepare datasets for training and testing
176186

177187
Data in this file is represented by 768 records and 8 features. The 9th column is a label column, it contains either 0 or 1
@@ -475,7 +485,7 @@ final targetName = 'col_13';
475485
then let's shuffle the data:
476486

477487
```dart
478-
samples.shuffle();
488+
final shuffledSamples = samples.shuffle();
479489
```
480490

481491
Now it's the time to prepare data splits. Let's split the data into train and test subsets using the library's [splitData](https://github.com/gyrdym/ml_algo/blob/master/lib/src/model_selection/split_data.dart)
@@ -501,7 +511,7 @@ e.g. stochastic gradient descent algorithm:
501511

502512
```dart
503513
final model = LinearRegressor.SGD(
504-
samples
514+
shuffledSamples
505515
targetName,
506516
iterationLimit: 90,
507517
);
@@ -511,7 +521,7 @@ or linear regression based on coordinate descent with Lasso regularization:
511521

512522
```dart
513523
final model = LinearRegressor.lasso(
514-
samples
524+
shuffledSamples,
515525
targetName,
516526
iterationLimit: 90,
517527
);
@@ -538,14 +548,16 @@ import 'dart:io';
538548
import 'package:ml_algo/ml_algo.dart';
539549
import 'package:ml_dataframe/ml_dataframe.dart';
540550
541-
final file = File('housing_model.json');
542-
final encodedModel = await file.readAsString();
543-
final model = LinearRegressor.fromJson(encodedModel);
544-
final unlabelledData = await fromCsv('some_unlabelled_data.csv');
545-
final prediction = model.predict(unlabelledData);
546-
547-
print(prediction.header);
548-
print(prediction.rows);
551+
void main() async {
552+
final file = File('housing_model.json');
553+
final encodedModel = await file.readAsString();
554+
final model = LinearRegressor.fromJson(encodedModel);
555+
final unlabelledData = await fromCsv('some_unlabelled_data.csv');
556+
final prediction = model.predict(unlabelledData);
557+
558+
print(prediction.header);
559+
print(prediction.rows);
560+
}
549561
```
550562

551563
<details>
@@ -556,8 +568,7 @@ import 'package:ml_algo/ml_algo.dart';
556568
import 'package:ml_dataframe/ml_dataframe.dart';
557569
558570
void main() async {
559-
final samples = (await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' '))
560-
..shuffle();
571+
final samples = (await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' ')).shuffle();
561572
final targetName = 'col_13';
562573
final splits = splitData(samples, [0.8]);
563574
final trainData = splits[0];
@@ -582,8 +593,7 @@ import 'package:ml_dataframe/ml_dataframe.dart';
582593
583594
void main() async {
584595
final rawCsvContent = await rootBundle.loadString('assets/datasets/housing.csv');
585-
final samples = DataFrame.fromRawCsv(rawCsvContent, fieldDelimiter: ' ')
586-
..shuffle();
596+
final samples = DataFrame.fromRawCsv(rawCsvContent, fieldDelimiter: ' ').shuffle();
587597
final targetName = 'col_13';
588598
final splits = splitData(samples, [0.8]);
589599
final trainData = splits[0];

e2e/logistic_regressor/logistic_regressor_test.dart renamed to e2e/logistic_regressor/logistic_regressor_sgd_test.dart

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,19 @@ import 'package:ml_linalg/vector.dart';
44
import 'package:test/test.dart';
55

66
Future<Vector> evaluateLogisticRegressor(MetricType metric, DType dtype) {
7-
final samples = getPimaIndiansDiabetesDataFrame().shuffle();
7+
final samples = getPimaIndiansDiabetesDataFrame().shuffle(seed: 12);
88
final numberOfFolds = 5;
9-
final targetNames = ['Outcome'];
109
final validator = CrossValidator.kFold(
1110
samples,
1211
numberOfFolds: numberOfFolds,
1312
);
14-
final createClassifier = (DataFrame trainSamples) => LogisticRegressor(
13+
final createClassifier = (DataFrame trainSamples) => LogisticRegressor.SGD(
1514
trainSamples,
16-
targetNames.first,
17-
optimizerType: LinearOptimizerType.gradient,
18-
iterationsLimit: 100,
19-
learningRateType: LearningRateType.exponential,
20-
batchSize: trainSamples.rows.length,
21-
probabilityThreshold: 0.5,
15+
'Outcome',
16+
seed: 10,
17+
iterationsLimit: 50,
18+
initialLearningRate: 1e-4,
19+
learningRateType: LearningRateType.constant,
2220
dtype: dtype,
2321
);
2422

@@ -29,7 +27,7 @@ Future<Vector> evaluateLogisticRegressor(MetricType metric, DType dtype) {
2927
}
3028

3129
Future main() async {
32-
group('LogisticRegressor', () {
30+
group('LogisticRegressor.SGD', () {
3331
test(
3432
'should return adequate score on pima indians diabetes dataset using '
3533
'accuracy metric, dtype=DType.float32', () async {

lib/src/classifier/classifier.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ abstract class Classifier extends Predictor {
3131
/// 908 | 404 | 503 | -100 | 100 | -100
3232
///
3333
/// If a prediction algorithm meets 100 in a target column, it will
34-
/// interpret the value as a positive outcome for the appropriate class
34+
/// interpret the value as a positive outcome for the corresponding class
3535
num get positiveLabel;
3636

3737
/// A value using to encode negative class.
@@ -57,6 +57,6 @@ abstract class Classifier extends Predictor {
5757
/// 908 | 404 | 503 | -100 | 100 | -100
5858
///
5959
/// If a prediction algorithm meets -100 in a target column, it will
60-
/// interpret the value as a negative outcome for the appropriate class
60+
/// interpret the value as a negative outcome for the corresponding class
6161
num get negativeLabel;
6262
}

0 commit comments

Comments
 (0)