@@ -467,20 +467,26 @@ Let's train the model:
467467final model = LinearRegressor(trainData, targetName);
468468```
469469
470- By default, ` LinearRegressor ` uses closed-form solution to train the model. It's possible to use a different solution type,
471- e.g. one can use gradient-based algorithm:
470+ By default, ` LinearRegressor ` uses closed-form solution to train the model. One can also use a different solution type,
471+ e.g. stochastic gradient descent algorithm:
472472
473473``` dart
474- final model = LinearRegressor(
474+ final model = LinearRegressor.SGD (
475475 samples
476476 targetName,
477- optimizerType: LinearOptimizerType.gradient,
478- iterationsLimit: 90,
479- learningRateType: LearningRateType.timeBased,
477+ iterationLimit: 90,
480478);
481479```
482480
483- As you may noticed, we have to provide a bunch of hyperparameters in case of gradient-based regression.
481+ or linear regression based on coordinate descent with Lasso regularization:
482+
483+ ``` dart
484+ final model = LinearRegressor.lasso(
485+ samples
486+ targetName,
487+ iterationLimit: 90,
488+ );
489+ ```
484490
485491Next, we should evaluate performance of our model:
486492
@@ -521,7 +527,8 @@ import 'package:ml_algo/ml_algo.dart';
521527import 'package:ml_dataframe/ml_dataframe.dart';
522528
523529void main() async {
524- final samples = await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' ');
530+ final samples = (await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' '))
531+ ..shuffle();
525532 final targetName = 'col_13';
526533 final splits = splitData(samples, [0.8]);
527534 final trainData = splits[0];
@@ -546,7 +553,8 @@ import 'package:ml_dataframe/ml_dataframe.dart';
546553
547554void main() async {
548555 final rawCsvContent = await rootBundle.loadString('assets/datasets/pima_indians_diabetes_database.csv');
549- final samples = DataFrame.fromRawCsv(rawCsvContent, fieldDelimiter: ' ');
556+ final samples = DataFrame.fromRawCsv(rawCsvContent, fieldDelimiter: ' ')
557+ ..shuffle();
550558 final targetName = 'col_13';
551559 final splits = splitData(samples, [0.8]);
552560 final trainData = splits[0];
0 commit comments