Skip to content

Commit 6d1e127

Browse files
authored
Added linear regression examples to README.md (#205)
1 parent 06032ae commit 6d1e127

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## 16.5.1
4+
- Added linear regression examples to `README.md`
5+
36
## 16.5.0
47
- LinearRegressor:
58
- Added `LinearRegressor.SGD` constructor

README.md

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -467,20 +467,26 @@ Let's train the model:
467467
final model = LinearRegressor(trainData, targetName);
468468
```
469469

470-
By default, `LinearRegressor` uses closed-form solution to train the model. It's possible to use a different solution type,
471-
e.g. one can use gradient-based algorithm:
470+
By default, `LinearRegressor` uses closed-form solution to train the model. One can also use a different solution type,
471+
e.g. stochastic gradient descent algorithm:
472472

473473
```dart
474-
final model = LinearRegressor(
474+
final model = LinearRegressor.SGD(
475475
samples
476476
targetName,
477-
optimizerType: LinearOptimizerType.gradient,
478-
iterationsLimit: 90,
479-
learningRateType: LearningRateType.timeBased,
477+
iterationLimit: 90,
480478
);
481479
```
482480

483-
As you may noticed, we have to provide a bunch of hyperparameters in case of gradient-based regression.
481+
or linear regression based on coordinate descent with Lasso regularization:
482+
483+
```dart
484+
final model = LinearRegressor.lasso(
485+
samples
486+
targetName,
487+
iterationLimit: 90,
488+
);
489+
```
484490

485491
Next, we should evaluate performance of our model:
486492

@@ -521,7 +527,8 @@ import 'package:ml_algo/ml_algo.dart';
521527
import 'package:ml_dataframe/ml_dataframe.dart';
522528
523529
void main() async {
524-
final samples = await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' ');
530+
final samples = (await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' '))
531+
..shuffle();
525532
final targetName = 'col_13';
526533
final splits = splitData(samples, [0.8]);
527534
final trainData = splits[0];
@@ -546,7 +553,8 @@ import 'package:ml_dataframe/ml_dataframe.dart';
546553
547554
void main() async {
548555
final rawCsvContent = await rootBundle.loadString('assets/datasets/pima_indians_diabetes_database.csv');
549-
final samples = DataFrame.fromRawCsv(rawCsvContent, fieldDelimiter: ' ');
556+
final samples = DataFrame.fromRawCsv(rawCsvContent, fieldDelimiter: ' ')
557+
..shuffle();
550558
final targetName = 'col_13';
551559
final splits = splitData(samples, [0.8]);
552560
final trainData = splits[0];

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 16.5.0
3+
version: 16.5.1
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:

0 commit comments

Comments
 (0)