Skip to content

Commit 4436b17

Browse files
authored
Decision Tree: save tree as svg image (#212)
1 parent a2fe388 commit 4436b17

19 files changed

+380
-2
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 16.7.0
4+
- DecisionTreeClassifier:
5+
- Added `saveAsSvg` method which returns '.svg' file with a graphical representation of a tree
6+
37
## 16.6.3
48
- KDTree:
59
- `fromIterable` constructor added

benchmark/knn_regressor.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// 10.0 sec (MacBook Air mid 2017)
1+
// 6.0 sec (MacBook Air mid 2017)
22
import 'package:benchmark_harness/benchmark_harness.dart';
33
import 'package:ml_algo/ml_algo.dart';
44
import 'package:ml_dataframe/ml_dataframe.dart';
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import 'package:ml_algo/ml_algo.dart';
2+
import 'package:ml_dataframe/ml_dataframe.dart';
3+
import 'package:ml_preprocessing/ml_preprocessing.dart';
4+
import 'package:test/test.dart';
5+
6+
void main() async {
7+
group('DecisionTreeClassifier', () {
8+
test('should save graphical representation as svg image, iris dataset',
9+
() async {
10+
final samples = (await fromCsv('e2e/_datasets/iris.csv'))
11+
.shuffle()
12+
.dropSeries(seriesNames: ['Id']);
13+
final pipeline = Pipeline(samples, [
14+
encodeAsIntegerLabels(
15+
featureNames: ['Species'],
16+
),
17+
]);
18+
final processed = pipeline.process(samples);
19+
final classifier = DecisionTreeClassifier(
20+
processed,
21+
'Species',
22+
minError: 0.3,
23+
minSamplesCount: 5,
24+
maxDepth: 4,
25+
);
26+
27+
await classifier.saveAsSvg('e2e/decision_tree_classifier/iris_tree.svg');
28+
});
29+
30+
test(
31+
'should save graphical representation as svg image, pima indians diabetes dataset',
32+
() async {
33+
final samples =
34+
(await fromCsv('e2e/_datasets/pima_indians_diabetes_database.csv'))
35+
.shuffle();
36+
final classifier = DecisionTreeClassifier(
37+
samples,
38+
'class variable (0 or 1)',
39+
minError: 0.1,
40+
minSamplesCount: 2,
41+
maxDepth: 4,
42+
);
43+
44+
await classifier
45+
.saveAsSvg('e2e/decision_tree_classifier/pima_indians_tree.svg');
46+
});
47+
});
48+
}
Lines changed: 1 addition & 0 deletions
Loading
Lines changed: 1 addition & 0 deletions
Loading

lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import 'dart:io';
2+
13
import 'package:ml_algo/src/classifier/classifier.dart';
24
import 'package:ml_algo/src/classifier/decision_tree_classifier/_init_module.dart';
35
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
@@ -122,4 +124,22 @@ abstract class DecisionTreeClassifier
122124
///
123125
/// The value is read-only, it's a hyperparameter of the model
124126
int get maxDepth;
127+
128+
/// Saves tree as SVG-image. Example:
129+
///
130+
/// ```dart
131+
/// final samples = (await fromCsv('path/to/dataset.csv'));
132+
/// final classifier = DecisionTreeClassifier(
133+
/// samples,
134+
/// 'target',
135+
/// minError: 0.3,
136+
/// minSamplesCount: 5,
137+
/// maxDepth: 4,
138+
/// );
139+
//
140+
// await classifier.saveAsSvg('tree.svg');
141+
/// ```
142+
///
143+
/// The file 'tree.svg' now contains a graphical representation of the tree
144+
Future<File> saveAsSvg(String filePath);
125145
}

lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import 'dart:io';
2+
13
import 'package:json_annotation/json_annotation.dart';
24
import 'package:ml_algo/src/classifier/_mixins/assessable_classifier_mixin.dart';
35
import 'package:ml_algo/src/classifier/decision_tree_classifier/_injector.dart';
46
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
57
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_constants.dart';
68
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
79
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_json_keys.dart';
10+
import 'package:ml_algo/src/classifier/decision_tree_classifier/helpers/create_tree_svg_markup/create_tree_svg_markup.dart';
811
import 'package:ml_algo/src/common/constants/common_json_keys.dart';
912
import 'package:ml_algo/src/common/json_converter/dtype_json_converter.dart';
1013
import 'package:ml_algo/src/common/serializable/serializable_mixin.dart';
@@ -143,4 +146,12 @@ class DecisionTreeClassifierImpl
143146
maxDepth,
144147
);
145148
}
149+
150+
@override
151+
Future<File> saveAsSvg(String filePath) async {
152+
final markup = createTreeSvgMarkup(treeRootNode);
153+
final file = await File(filePath).create(recursive: true);
154+
155+
return file.writeAsString(markup);
156+
}
146157
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import 'package:ml_algo/src/classifier/decision_tree_classifier/helpers/create_tree_svg_markup/create_tree_svg_markup_constants.dart';
2+
import 'package:ml_algo/src/classifier/decision_tree_classifier/helpers/create_tree_svg_markup/get_tree_levels.dart';
3+
import 'package:ml_algo/src/classifier/decision_tree_classifier/helpers/create_tree_svg_markup/get_tree_node_distance_by_level.dart';
4+
import 'package:ml_algo/src/classifier/decision_tree_classifier/helpers/create_tree_svg_markup/get_tree_node_markup.dart';
5+
import 'package:ml_algo/src/classifier/decision_tree_classifier/helpers/create_tree_svg_markup/get_tree_width.dart';
6+
import 'package:ml_algo/src/tree_trainer/tree_node/tree_node.dart';
7+
8+
class _NodesMarkupData {
9+
_NodesMarkupData(
10+
{required this.markup, required this.level, required this.y});
11+
12+
final String markup;
13+
final int level;
14+
final num y;
15+
}
16+
17+
String createTreeSvgMarkup(TreeNode node) {
18+
final shape = node.shape;
19+
final levels = getTreeLevels(node, shape.length);
20+
final nodeDistanceByLevel =
21+
getTreeNodeDistanceByLevel(levels, nodeWidth, minNodeHorizontalDistance);
22+
final totalWidth = getTreeWidth(levels, nodeWidth, minNodeHorizontalDistance);
23+
final totalHeight = shape.length * (nodeHeight + nodeVerticalDistance);
24+
final markup = _generateMarkup(levels, node, nodeDistanceByLevel);
25+
26+
return '<svg xmlns="http://www.w3.org/2000/svg" width="$totalWidth" height="$totalHeight">$textStyles$markup</svg>';
27+
}
28+
29+
String _generateMarkup(
30+
List<List<TreeNode>> levels, TreeNode root, Map<int, num> distByLevel) {
31+
return levels.fold<_NodesMarkupData>(
32+
_NodesMarkupData(markup: '', level: 0, y: 20), (data, nodes) {
33+
final spacing = distByLevel[data.level]!;
34+
final childSpacing = distByLevel.containsKey(data.level + 1)
35+
? distByLevel[data.level + 1]!
36+
: null;
37+
final getX =
38+
(int idx) => spacing / 2 + (idx == 0 ? 0 : idx * (nodeWidth + spacing));
39+
40+
var nodeIdx = 0;
41+
final nodesMarkup = nodes
42+
.map((node) =>
43+
getTreeNodeMarkup(node, getX(nodeIdx++), data.y, childSpacing))
44+
.join();
45+
46+
return _NodesMarkupData(
47+
markup: '${data.markup}$nodesMarkup',
48+
level: data.level + 1,
49+
y: data.y + nodeHeight + nodeVerticalDistance,
50+
);
51+
}).markup;
52+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
const nodeWidth = 200;
2+
const nodeHeight = 80;
3+
const minNodeHorizontalDistance = 20;
4+
const nodeVerticalDistance = 150;
5+
6+
const textStyles = '<style>'
7+
'.label { font: 14px sans-serif; }'
8+
'.value { font: bold 14px sans-serif; }'
9+
'.root-node-label { font: 24px sans-serif; }'
10+
'</style>';
11+
12+
const nodeStyle =
13+
'fill:blue;fill-opacity:.25;stroke:cornflowerblue;stroke-width:2';
14+
const nodeLineStyle = 'stroke:cornflowerblue;stroke-width:1';
15+
16+
const labelMargin = 20;
17+
const labelWidth = 100;
18+
const labelHeight = 20;
19+
const noValue = '-';

0 commit comments

Comments
 (0)