Skip to content

Commit b5890e6

Browse files
authored
Greedy splitter: the case of a split column consisting of the same values (#213)
1 parent 4436b17 commit b5890e6

File tree

11 files changed

+142
-27
lines changed

11 files changed

+142
-27
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 16.7.1
4+
- DecisionTreeClassifier:
5+
- Fixed the greedy splitter for the case where a split column consists entirely of identical values
6+
37
## 16.7.0
48
- DecisionTreeClassifier:
59
- Added `saveAsSvg` method which returns '.svg' file with a graphical representation of a tree
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// MacBook Air 13.3 mid 2017: ~ 6.5 sec
2+
3+
import 'dart:math';
4+
5+
import 'package:benchmark_harness/benchmark_harness.dart';
6+
import 'package:ml_algo/ml_algo.dart';
7+
import 'package:ml_dataframe/ml_dataframe.dart';
8+
import 'package:ml_linalg/linalg.dart';
9+
10+
const observationsNum = 300;
11+
const columnsNum = 11;
12+
13+
/// Benchmark that measures how long it takes to construct a
/// [DecisionTreeClassifier] on a synthetic data set.
///
/// The data set is built once in [setup]: `observationsNum` rows made of
/// `columnsNum - 1` randomly generated feature columns plus one column of
/// shuffled binary outcomes. Every [run] builds a fresh classifier from it.
class DecisionTreeClassifierBenchmark extends BenchmarkBase {
  DecisionTreeClassifierBenchmark()
      : super('Decision tree classifier benchmark');

  /// Training data; assigned in [setup], read in [run].
  late DataFrame _data;

  /// Creates the benchmark and reports its result through the harness.
  static void main() {
    DecisionTreeClassifierBenchmark().report();
  }

  /// Builds a classifier from [_data]; the resulting model is discarded
  /// because only the construction cost is of interest here.
  @override
  void run() {
    DecisionTreeClassifier(
      _data,
      // Presumably the auto-generated header of the outcomes column (index
      // `columnsNum - 1`) - TODO confirm against DataFrame's header naming.
      'col_10',
      maxDepth: 4,
      minError: 0.4,
      minSamplesCount: 10,
    );
  }

  /// Prepares the synthetic data set with fixed seeds.
  @override
  void setup() {
    final random = Random(1);
    final observations =
        Matrix.random(observationsNum, columnsNum - 1, seed: 1);

    // Size the second half as the remainder so that the outcome column always
    // has exactly `observationsNum` elements, even for an odd row count.
    final positivesNum = observationsNum ~/ 2;
    final negativesNum = observationsNum - positivesNum;
    final outcomes = Vector.fromList([
      ...List.filled(positivesNum, 1),
      ...List.filled(negativesNum, 0)
    ]..shuffle(random));

    _data = DataFrame(observations.insertColumns(columnsNum - 1, [outcomes]),
        headerExists: false);
  }

  /// Nothing to release after the measurement.
  ///
  /// The harness hook is spelled `teardown`; a method named `tearDown` would
  /// not override it and would never be invoked.
  @override
  void teardown() {}
}
50+
51+
/// Entry point of the benchmark script.
///
/// Delegates to [DecisionTreeClassifierBenchmark.main]. That call is
/// synchronous and nothing is awaited, so the function is a plain `void`
/// entry point rather than an `async` one.
void main() {
  DecisionTreeClassifierBenchmark.main();
}

benchmark/logistic_regressor.dart

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
// MacBook Air 13.3 mid 2017: ~ 4 sec
2+
13
import 'package:benchmark_harness/benchmark_harness.dart';
24
import 'package:ml_algo/ml_algo.dart';
35
import 'package:ml_dataframe/ml_dataframe.dart';
46
import 'package:ml_linalg/matrix.dart';
5-
import 'package:ml_linalg/vector.dart';
67

7-
const observationsNum = 200;
8-
const columnsNum = 21;
8+
const observationsNum = 20000;
9+
const columnsNum = 101;
910

1011
class LogisticRegressorBenchmark extends BenchmarkBase {
1112
LogisticRegressorBenchmark() : super('Logistic regressor');
@@ -28,8 +29,7 @@ class LogisticRegressorBenchmark extends BenchmarkBase {
2829

2930
@override
3031
void setup() {
31-
final observations = Matrix.fromRows(
32-
List.generate(observationsNum, (i) => Vector.randomFilled(columnsNum)));
32+
final observations = Matrix.random(observationsNum, columnsNum, seed: 1);
3333

3434
_data = DataFrame.fromMatrix(observations);
3535
}

e2e/decision_tree_classifier/decision_tree_classifier_save_as_svg_test.dart

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ void main() async {
3636
final classifier = DecisionTreeClassifier(
3737
samples,
3838
'class variable (0 or 1)',
39-
minError: 0.1,
40-
minSamplesCount: 2,
41-
maxDepth: 4,
39+
minError: 0.15,
40+
minSamplesCount: 1,
41+
maxDepth: 5,
4242
);
4343

4444
await classifier
Lines changed: 1 addition & 1 deletion
Loading

e2e/decision_tree_classifier/pima_indians_tree.svg

Lines changed: 1 addition & 1 deletion
Loading

lib/src/tree_trainer/split_assessor/majority_split_assessor.dart

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@ class MajorityTreeSplitAssessor implements TreeSplitAssessor {
1212
double getAggregatedError(Iterable<Matrix> splitObservations, int targetId) {
1313
var errorCount = 0;
1414
var totalCount = 0;
15-
for (final nodeObservations in splitObservations) {
16-
if (nodeObservations.columnsNum == 0) {
17-
throw Exception('Observations on the node are empty');
18-
}
15+
16+
for (final nodeObservations in splitObservations
17+
.where((observations) => observations.columnsNum > 0)) {
1918
if (targetId >= nodeObservations.columnsNum) {
2019
throw ArgumentError.value(
2120
targetId,
@@ -26,6 +25,7 @@ class MajorityTreeSplitAssessor implements TreeSplitAssessor {
2625
errorCount += _getErrorCount(nodeObservations.getColumn(targetId));
2726
totalCount += nodeObservations.rowsNum;
2827
}
28+
2929
return errorCount / totalCount;
3030
}
3131

lib/src/tree_trainer/splitter/greedy_splitter.dart

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,16 @@ class GreedyTreeSplitter implements TreeSplitter {
3535
Map<TreeNode, Matrix> _createByNumericalValue(
3636
Matrix samples, int splittingIdx, int targetId) {
3737
final errors = <double, List<Map<TreeNode, Matrix>>>{};
38-
final sortedRows = samples.sort((row) => row[splittingIdx], Axis.rows).rows;
38+
final sortedRows =
39+
samples.sort((row) => row[splittingIdx], Axis.rows).rows.toList();
40+
final rowsWithoutFirst = sortedRows.skip(1).toList();
3941
var prevValue = sortedRows.first[splittingIdx];
4042

41-
for (final row in sortedRows.skip(1)) {
43+
for (var i = 0; i < rowsWithoutFirst.length; i++) {
44+
final row = rowsWithoutFirst[i];
4245
final nextValue = row[splittingIdx];
4346

44-
if (prevValue == nextValue) {
47+
if (prevValue == nextValue && i < sortedRows.length - 2) {
4548
continue;
4649
}
4750

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 16.7.0
3+
version: 16.7.1
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:

test/tree_trainer/split_assessor/majority_split_assesor_test.dart

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ void main() {
120120
});
121121

122122
test(
123-
'should return majority-based error, that is equal to 0, if all '
123+
'should return majority-based error that is equal to 0, if all '
124124
'nodes in the stump have only one observation', () {
125125
final node1 = Matrix.fromList([
126126
[50, 70, 0],
@@ -141,9 +141,7 @@ void main() {
141141
expect(error, 0);
142142
});
143143

144-
test(
145-
'should throw an error if at least one node in the stump does not '
146-
'have observations at all', () {
144+
test('should ignore empty split matrices', () {
147145
final node1 = Matrix.fromList([]);
148146

149147
final node2 = Matrix.fromList([
@@ -156,10 +154,7 @@ void main() {
156154

157155
final stump = [node1, node2, node3];
158156

159-
expect(
160-
() => const MajorityTreeSplitAssessor().getAggregatedError(stump, 2),
161-
throwsException,
162-
);
157+
expect(const MajorityTreeSplitAssessor().getAggregatedError(stump, 2), 0);
163158
});
164159

165160
test(

0 commit comments

Comments
 (0)