Commit 0de05a2

Logistic Regression: newton method added (#239)
1 parent: dc2c80a

File tree: 19 files changed, +269 −83 lines

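At a glance, the commit adds a `LogisticRegressor.newton` constructor. A minimal usage sketch, assembled from the e2e test added in this commit (the dataset helpers `getPimaIndiansDiabetesDataFrame` and `splitData` and the `Outcome` target column all come from that test; this is not additional code from the commit itself):

```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

void main() {
  // Load and shuffle the bundled Pima Indians diabetes dataset.
  final data = getPimaIndiansDiabetesDataFrame().shuffle();

  // Hold out the last 20% of the rows for evaluation.
  final splits = splitData(data, [0.8]);
  final trainSamples = splits.first;
  final testSamples = splits.last;

  // Train using the Newton-Raphson optimizer introduced by this commit.
  final model = LogisticRegressor.newton(
    trainSamples,
    'Outcome',
  );

  // Accuracy on the held-out samples.
  print(model.assess(testSamples, MetricType.accuracy));
}
```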

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -1,8 +1,12 @@
 # Changelog
 
+## 16.17.0
+- LogisticRegressor:
+    - Newton method added
+
 ## 16.16.0
 - LinearRegressor:
-    - Newton method added
+    - Newton method added
 
 ## 16.15.2
 - LinearRegressor, LogisticRegressor, SoftmaxRegressor:

README.md

Lines changed: 11 additions & 25 deletions
@@ -51,6 +51,10 @@ it in web applications.
 
 - [LogisticRegressor.BGD](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor/LogisticRegressor.BGD.html).
 Implementation of the logistic regression algorithm based on batch gradient descent with L2 regularisation.
+To use this kind of classifier your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
+
+- [LogisticRegressor.newton](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor/LogisticRegressor.newton.html).
+Implementation of the logistic regression algorithm based on Newton-Raphson method with L2 regularisation.
 To use this kind of classifier your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
 
 - [SoftmaxRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/SoftmaxRegressor-class.html).
@@ -78,7 +82,7 @@ it in web applications.
 Implementation of the linear regression algorithm based on batch gradient descent with L2 regularisation
 
 - [LinearRegressor.newton](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor/LinearRegressor.newton.html)
-Implementation of the linear regression algorithm based on Newton method with L2 regularisation
+Implementation of the linear regression algorithm based on Newton-Raphson method with L2 regularisation
 
 - [KnnRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/KnnRegressor-class.html)
 A class that makes predictions for each new observation based on the first `k` closest observations from
@@ -231,30 +235,18 @@ if the selected hyperparameters are good enough or not:
 
 ```dart
 final createClassifier = (DataFrame samples) =>
-    // BGD stands for "Batch Gradient Descent" that's meaning that the classifier will use the whole dataset on every
-    // training iteration
-    LogisticRegressor.BGD(
+    LogisticRegressor(
       samples
       targetColumnName,
-      iterationsLimit: 90,
-      learningRateType: LearningRateType.timeBased,
-      probabilityThreshold: 0.7,
     );
 ```
 
-Let's describe our hyperparameters:
-- `iterationsLimit` - number of learning iterations. The selected optimization algorithm (gradient ascent in our case) will
-be cyclically run this amount of times
-- `learningRateType` - a strategy for learning rate update. In our case, the learning rate will decrease after every
-iteration
-- `probabilityThreshold` - lower bound for positive label probability
-
 If we want to evaluate the learning process more thoroughly, we may pass `collectLearningData` argument to the classifier
 constructor:
 
 ```dart
 final createClassifier = (DataFrame samples) =>
-    LogisticRegressor.BGD(
+    LogisticRegressor(
       ...,
       collectLearningData: true,
     );
@@ -265,7 +257,7 @@ the model creation.
 
 ### Evaluate the performance of the model
 
-Assume, we chose really good hyperparameters. In order to validate this hypothesis let's use CrossValidator instance
+Assume, we chose perfect hyperparameters. In order to validate this hypothesis, let's use CrossValidator instance
 created before:
 
 ````dart
@@ -287,7 +279,7 @@ print('accuracy on k fold validation: ${accuracy.toStringAsFixed(2)}');
 We can see something like this:
 
 ````
-accuracy on k fold validation: 0.65
+accuracy on k fold validation: 0.75
 ````
 
 Let's assess our hyperparameters on the test set in order to evaluate the model's generalization error:
@@ -374,12 +366,9 @@ void main() async {
   final testData = splits[1];
   final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
   final createClassifier = (DataFrame samples) =>
-      LogisticRegressor.BGD(
+      LogisticRegressor(
        samples
        targetColumnName,
-        iterationsLimit: 90,
-        learningRateType: LearningRateType.timeBased,
-        probabilityThreshold: 0.7,
      );
   final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
   final accuracy = scores.mean();
@@ -417,12 +406,9 @@ void main() async {
  final testData = splits[1];
  final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
  final createClassifier = (DataFrame samples) =>
-      LogisticRegressor.BGD(
+      LogisticRegressor(
        samples
        targetColumnName,
-        iterationsLimit: 90,
-        learningRateType: LearningRateType.timeBased,
-        probabilityThreshold: 0.7,
      );
  final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
  final accuracy = scores.mean();
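For reference, the update rule behind the Newton-Raphson constructors described in the README above, in its standard textbook form for L2-regularised logistic regression; the commit does not spell the formula out, so the notation below is ours:

```latex
% One Newton-Raphson step for L2-regularised logistic regression.
% w: weight vector, X: feature matrix, y: binary labels,
% p = sigma(X w): vector of predicted probabilities,
% S = diag(p_i (1 - p_i)), lambda: L2 regularisation strength.
\begin{align*}
  \nabla J(w) &= X^{\top}(p - y) + \lambda w \\
  H(w)        &= X^{\top} S X + \lambda I \\
  w_{t+1}     &= w_t - H(w_t)^{-1}\,\nabla J(w_t)
\end{align*}
```

Because each step uses second-order curvature information, the method typically converges in far fewer iterations than batch gradient descent, at the cost of forming and inverting the Hessian on every step.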

analysis_options.yaml

Lines changed: 0 additions & 1 deletion
@@ -11,4 +11,3 @@ analyzer:
     dead_code: error
     duplicate_import: error
     unused_import: error
-    deprecated_member_use_from_same_package: warning
e2e/decision_tree_classifier/pima_indians_tree.svg

Lines changed: 1 addition & 1 deletion (image diff not shown)

e2e/logistic_regressor/logistic_regressor_bgd_test.dart

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@ Future main() async {
         'should return adequate score on pima indians diabetes dataset using '
         'accuracy metric, dtype=DType.float64', () async {
       final scores =
-          await evaluateLogisticRegressor(MetricType.accuracy, DType.float32);
+          await evaluateLogisticRegressor(MetricType.accuracy, DType.float64);
 
       expect(scores.mean(), greaterThan(0.5));
     });
@@ -58,7 +58,7 @@ Future main() async {
         'should return adequate score on pima indians diabetes dataset using '
         'precision metric, dtype=DType.float64', () async {
       final scores =
-          await evaluateLogisticRegressor(MetricType.precision, DType.float32);
+          await evaluateLogisticRegressor(MetricType.precision, DType.float64);
 
       expect(scores.mean(), greaterThan(0.5));
     });
@@ -76,7 +76,7 @@ Future main() async {
         'should return adequate score on pima indians diabetes dataset using '
         'recall metric, dtype=DType.float64', () async {
       final scores =
-          await evaluateLogisticRegressor(MetricType.recall, DType.float32);
+          await evaluateLogisticRegressor(MetricType.recall, DType.float64);
 
       expect(scores.mean(), greaterThan(0.5));
     });
e2e/logistic_regressor/logistic_regressor_newton_test.dart (new file; name inferred from the `LogisticRegressor.newton` test group and the sibling bgd/sgd test files)

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+import 'package:ml_algo/ml_algo.dart';
+import 'package:ml_dataframe/ml_dataframe.dart';
+import 'package:test/test.dart';
+
+num evaluateLogisticRegressor(MetricType metric, DType dtype) {
+  final data = getPimaIndiansDiabetesDataFrame().shuffle();
+  final samples = splitData(data, [0.8]);
+  final trainSamples = samples.first;
+  final testSamples = samples.last;
+  final model = LogisticRegressor.newton(
+    trainSamples,
+    'Outcome',
+    dtype: dtype,
+  );
+
+  return model.assess(testSamples, metric);
+}
+
+Future main() async {
+  group('LogisticRegressor.newton', () {
+    test(
+        'should return adequate score on pima indians diabetes dataset using '
+        'accuracy metric, dtype=DType.float32', () {
+      final score =
+          evaluateLogisticRegressor(MetricType.accuracy, DType.float32);
+
+      print('float32, accuracy is $score');
+
+      expect(score, greaterThan(0.7));
+    });
+
+    test(
+        'should return adequate score on pima indians diabetes dataset using '
+        'accuracy metric, dtype=DType.float64', () {
+      final score =
+          evaluateLogisticRegressor(MetricType.accuracy, DType.float64);
+
+      print('float64, accuracy is $score');
+
+      expect(score, greaterThan(0.7));
+    });
+
+    test(
+        'should return adequate score on pima indians diabetes dataset using '
+        'precision metric, dtype=DType.float32', () {
+      final score =
+          evaluateLogisticRegressor(MetricType.precision, DType.float32);
+
+      print('float32, precision is $score');
+
+      expect(score, greaterThan(0.65));
+    });
+
+    test(
+        'should return adequate score on pima indians diabetes dataset using '
+        'precision metric, dtype=DType.float64', () {
+      final score =
+          evaluateLogisticRegressor(MetricType.precision, DType.float64);
+
+      print('float64, precision is $score');
+
+      expect(score, greaterThan(0.65));
+    });
+
+    test(
+        'should return adequate score on pima indians diabetes dataset using '
+        'recall metric, dtype=DType.float32', () {
+      final score = evaluateLogisticRegressor(MetricType.recall, DType.float32);
+
+      print('float32, recall is $score');
+
+      expect(score, greaterThan(0.65));
+    });
+
+    test(
+        'should return adequate score on pima indians diabetes dataset using '
+        'recall metric, dtype=DType.float64', () {
+      final score = evaluateLogisticRegressor(MetricType.recall, DType.float64);
+
+      print('float64, recall is $score');
+
+      expect(score, greaterThan(0.65));
+    });
+  });
+}

e2e/logistic_regressor/logistic_regressor_sgd_test.dart

Lines changed: 3 additions & 3 deletions
@@ -41,7 +41,7 @@ Future main() async {
         'should return adequate score on pima indians diabetes dataset using '
         'accuracy metric, dtype=DType.float64', () async {
       final scores =
-          await evaluateLogisticRegressor(MetricType.accuracy, DType.float32);
+          await evaluateLogisticRegressor(MetricType.accuracy, DType.float64);
 
       expect(scores.mean(), greaterThan(0.5));
     });
@@ -59,7 +59,7 @@ Future main() async {
         'should return adequate score on pima indians diabetes dataset using '
        'precision metric, dtype=DType.float64', () async {
       final scores =
-          await evaluateLogisticRegressor(MetricType.precision, DType.float32);
+          await evaluateLogisticRegressor(MetricType.precision, DType.float64);
 
       expect(scores.mean(), greaterThan(0.5));
     });
@@ -77,7 +77,7 @@ Future main() async {
         'should return adequate score on pima indians diabetes dataset using '
         'recall metric, dtype=DType.float64', () async {
       final scores =
-          await evaluateLogisticRegressor(MetricType.recall, DType.float32);
+          await evaluateLogisticRegressor(MetricType.recall, DType.float64);
 
       expect(scores.mean(), greaterThan(0.5));
     });

example/logistic_regression.dart

Lines changed: 4 additions & 6 deletions
@@ -7,10 +7,6 @@ void main() {
   final model = LogisticRegressor(
     splits.first,
     'Outcome',
-    batchSize: splits.first.rows.length,
-    learningRateType: LearningRateType.exponential,
-    decay: 0.7,
-    collectLearningData: true,
   );
 
   print('ACURACY:');
@@ -22,8 +18,10 @@ void main() {
   print('PRECISION:');
   print(model.assess(splits.last, MetricType.precision));
 
-  print('LD: ');
-  print(splits.last['Outcome'].data.take(10));
+  print('Results (first row - actual values, second row - predicted values):');
+  print(splits.last['Outcome'].data
+      .take(10)
+      .map((val) => num.parse(val.toString()).toDouble()));
   print(model
       .predict(splits.last.dropSeries(names: ['Outcome']))
       .series

lib/src/classifier/_constants/supported_linear_optimizer_types.dart

Lines changed: 1 addition & 0 deletions
@@ -2,4 +2,5 @@ import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart';
 
 const supportedLinearOptimizerTypes = [
   LinearOptimizerType.gradient,
+  LinearOptimizerType.newton,
 ];
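This constants change suggests the Newton optimizer can also be requested through the default `LogisticRegressor` constructor via an optimizer-type parameter. A hedged sketch, assuming the parameter is named `optimizerType` as in `LinearRegressor` (an assumption, not confirmed by this diff; check the published API docs):

```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

void main() {
  final data = getPimaIndiansDiabetesDataFrame().shuffle();

  // Assumption: the default constructor exposes an `optimizerType`
  // parameter mirroring this constants file; LogisticRegressor.newton
  // is the documented shortcut for the same configuration.
  final model = LogisticRegressor(
    data,
    'Outcome',
    optimizerType: LinearOptimizerType.newton,
  );

  // Predict on the feature columns (target column dropped).
  final predictions = model.predict(data.dropSeries(names: ['Outcome']));

  print(predictions.header);
}
```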
