@@ -225,34 +225,30 @@ if the selected hyperparameters are good enough or not:
225225
226226``` dart
227227final createClassifier = (DataFrame samples) =>
228- LogisticRegressor(
228+ // BGD stands for "Batch Gradient Descent", meaning that the classifier will use the whole dataset on every
229+ // training iteration
230+ LogisticRegressor.BGD(
229231 samples,
230232 targetColumnName,
231- optimizerType: LinearOptimizerType.gradient,
232233 iterationsLimit: 90,
233234 learningRateType: LearningRateType.timeBased,
234- batchSize: samples.rows.length,
235235 probabilityThreshold: 0.7,
236236 );
237237```
238238
239239Let's describe our hyperparameters:
240- - ` optimizerType ` - a type of optimization algorithm that will be used to learn coefficients of our model, this time we
241- decided to use a vanilla gradient ascent algorithm
242240- ` iterationsLimit ` - number of learning iterations. The selected optimization algorithm (gradient ascent in our case) will
243241be cyclically run this amount of times
244242- ` learningRateType ` - a strategy for learning rate update. In our case, the learning rate will decrease after every
245243iteration
246- - ` batchSize ` - the size of data (in rows) that will be used per each iteration. As we have a really small dataset we may use
247- full-batch gradient ascent, that's why we used ` samples.rows.length ` here - the total amount of data.
248244- ` probabilityThreshold ` - lower bound for positive label probability
249245
250246If we want to evaluate the learning process more thoroughly, we may pass ` collectLearningData ` argument to the classifier
251247constructor:
252248
253249``` dart
254250final createClassifier = (DataFrame samples) =>
255- LogisticRegressor(
251+ LogisticRegressor.BGD(
256252 ...,
257253 collectLearningData: true,
258254 );
@@ -323,22 +319,26 @@ After that we can simply read the model from the file and make predictions:
323319``` dart
324320import 'dart:io';
325321
326- final fileName = 'diabetes_classifier.json';
327- final file = File(fileName);
328- final encodedModel = await file.readAsString();
329- final classifier = LogisticRegressor.fromJson(encodedModel);
330- final unlabelledData = await fromCsv('some_unlabelled_data.csv');
331- final prediction = classifier.predict(unlabelledData);
322+ void main() async {
323+ // ...
324+ final fileName = 'diabetes_classifier.json';
325+ final file = File(fileName);
326+ final encodedModel = await file.readAsString();
327+ final classifier = LogisticRegressor.fromJson(encodedModel);
328+ final unlabelledData = await fromCsv('some_unlabelled_data.csv');
329+ final prediction = classifier.predict(unlabelledData);
332330
333- print(prediction.header); // ('class variable (0 or 1)')
334- print(prediction.rows); // [
331+ print(prediction.header); // ('class variable (0 or 1)')
332+ print(prediction.rows); // [
335333 // (1),
336334 // (0),
337335 // (0),
338336 // (1),
339337 // ...,
340338 // (1),
341339 // ]
340+ // ...
341+ }
342342```
343343
344344Please note that all the hyperparameters that we used to generate the model are persisted as the model's read-only
@@ -368,13 +368,11 @@ void main() async {
368368 final testData = splits[1];
369369 final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
370370 final createClassifier = (DataFrame samples) =>
371- LogisticRegressor(
371+ LogisticRegressor.BGD(
372372 samples,
373373 targetColumnName,
374- optimizerType: LinearOptimizerType.gradient,
375374 iterationsLimit: 90,
376375 learningRateType: LearningRateType.timeBased,
377- batchSize: samples.rows.length,
378376 probabilityThreshold: 0.7,
379377 );
380378 final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
@@ -413,13 +411,11 @@ void main() async {
413411 final testData = splits[1];
414412 final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
415413 final createClassifier = (DataFrame samples) =>
416- LogisticRegressor(
414+ LogisticRegressor.BGD(
417415 samples,
418416 targetColumnName,
419- optimizerType: LinearOptimizerType.gradient,
420417 iterationsLimit: 90,
421418 learningRateType: LearningRateType.timeBased,
422- batchSize: samples.rows.length,
423419 probabilityThreshold: 0.7,
424420 );
425421 final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
0 commit comments