
Commit 56bc063

DecisionTreeClassifier example added to README (#214)
1 parent b5890e6 commit 56bc063

32 files changed, +264 -207 lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -1,5 +1,10 @@
# Changelog

## 16.7.2
- DecisionTreeClassifier:
  - TreeNode fields renamed
  - Added example of DecisionTreeClassifier usage to `README.md`

## 16.7.1
- DecisionTreeClassifier:
  - Fixed greedy splitter in case of a split column consisting of the same values

README.md

Lines changed: 70 additions & 0 deletions
@@ -19,6 +19,7 @@ The library is a part of ecosystem:
- [Examples](#examples)
- [Logistic regression](#logistic-regression)
- [Linear regression](#linear-regression)
- [Decision tree based classification](#decision-tree-based-classification)
- [Models retraining](#models-retraining)
- [Notes on gradient based optimisation algorithms](#a-couple-of-words-about-linear-models-which-use-gradient-optimisation-methods)

@@ -579,6 +580,75 @@ void main() async {

## Decision tree based classification

Let's try to classify data from the well-known [Iris](https://www.kaggle.com/datasets/uciml/iris) dataset using a non-linear algorithm - [decision trees](https://en.wikipedia.org/wiki/Decision_tree).

First, you need to download the dataset and place it in a proper location in your file system. To do so, follow the instructions given in the [Logistic regression](#logistic-regression) section.

After loading the data, it needs to be preprocessed. We should drop the `Id` column, since it doesn't carry any useful information. We also need to encode the `Species` column: originally it contains 3 repeated string labels, and to feed the data to the classifier we have to convert the labels into numbers:
```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';

void main() async {
  final samples = (await fromCsv('path/to/iris/dataset.csv'))
      .shuffle()
      .dropSeries(seriesNames: ['Id']);

  final pipeline = Pipeline(samples, [
    encodeAsIntegerLabels(
      featureNames: ['Species'], // Here we convert strings from 'Species' column into numbers
    ),
  ]);

  // Apply the pipeline to the data; `processed` is the DataFrame used below
  // to train the model (the `process` call is assumed to match the
  // ml_preprocessing version in use)
  final processed = pipeline.process(samples);
}
```

Next, let's create a model:

```dart
final model = DecisionTreeClassifier(
  processed,
  'Species',
  minError: 0.3,
  minSamplesCount: 5,
  maxDepth: 4,
);
```

As you can see, we specified 3 hyperparameters: `minError`, `minSamplesCount` and `maxDepth`. Let's look at these parameters in more detail:

- `minError`. A minimum error on a tree node. If the error on a node is less than or equal to this value, the node is considered a leaf.
- `minSamplesCount`. A minimum number of samples on a tree node. If the number of samples on a node is less than or equal to this value, the node is considered a leaf.
- `maxDepth`. A maximum depth of the resulting decision tree. Once the tree reaches `maxDepth`, all the nodes at that level are considered leaves.

All these parameters serve as stopping criteria for the tree-building algorithm.
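
At this point the model is trained and can classify samples. Below is a minimal sketch of making predictions; it assumes the classifier's `predict` method accepts a `DataFrame` with the feature columns and returns a `DataFrame` with the predicted labels (variable names here are illustrative):

```dart
// A hedged sketch: `processed` is the DataFrame produced by the preprocessing
// pipeline above; we drop the target column and feed only the features to
// `predict`, which is assumed to return a DataFrame with predicted 'Species' values.
final unlabelledFeatures = processed.dropSeries(seriesNames: ['Species']);
final prediction = model.predict(unlabelledFeatures);

print(prediction.header);
print(prediction.rows);
```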

Now we have a ready-to-use model. As usual, we can save the model to a JSON file:

```dart
await model.saveAsJson('path/to/json/file.json');
```
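
The saved model can later be restored from the file. A minimal sketch, assuming `DecisionTreeClassifier` provides a `fromJson` constructor analogous to the library's other serializable models:

```dart
import 'dart:io';

import 'package:ml_algo/ml_algo.dart';

void main() async {
  // A hedged sketch: `DecisionTreeClassifier.fromJson` is assumed here by
  // analogy with the library's other serializable models.
  final encodedModel = await File('path/to/json/file.json').readAsString();
  final restoredModel = DecisionTreeClassifier.fromJson(encodedModel);

  // The restored model can be used for prediction just like the original one.
  print(restoredModel);
}
```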

Unlike other models, in the case of a decision tree we can visualise the result of the algorithm - we can save the model as an SVG file:

```dart
await model.saveAsSvg('path/to/svg/file.svg');
```

Once the file is saved, we can open it in any image viewer, e.g. a web browser. An example of the resulting SVG image:

<p align="center">
  <img height="600" src="https://raw.github.com/gyrdym/ml_algo/master/e2e/decision_tree_classifier/iris_tree.svg?sanitize=true">
</p>

## Models retraining

Someday our previously shining model can degrade in terms of prediction accuracy - in this case we can retrain it.
