Skip to content

Commit a304f19

Browse files
committed
Add some special cases for Decision Tree
1 parent d431208 commit a304f19

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

Algorithms.Tests/MachineLearning/DecisionTreeTests.cs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using NUnit.Framework;
22
using Algorithms.MachineLearning;
3-
using System;
3+
using System;
44

55
namespace Algorithms.Tests.MachineLearning;
66

@@ -118,16 +118,16 @@ public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsZero()
118118
public void Entropy_ReturnsZero_WhenAllZeroOrAllOne()
119119
{
120120
var method = typeof(DecisionTree).GetMethod("Entropy", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
121-
Assert.That(method!.Invoke(null, new [] { new int[] { 0, 0, 0 } }), Is.EqualTo(0d));
122-
Assert.That(method!.Invoke(null, new [] { new int[] { 1, 1, 1 } }), Is.EqualTo(0d));
121+
Assert.That(method!.Invoke(null, new[] { new int[] { 0, 0, 0 } }), Is.EqualTo(0d));
122+
Assert.That(method!.Invoke(null, new[] { new int[] { 1, 1, 1 } }), Is.EqualTo(0d));
123123
}
124124

125125
[Test]
126126
public void MostCommon_ReturnsCorrectLabel()
127127
{
128128
var method = typeof(DecisionTree).GetMethod("MostCommon", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
129-
Assert.That(method!.Invoke(null, new [] { new int[] { 1, 0, 1, 1, 0, 0, 0 } }), Is.EqualTo(0));
130-
Assert.That(method!.Invoke(null, new [] { new int[] { 1, 1, 1, 0 } }), Is.EqualTo(1));
129+
Assert.That(method!.Invoke(null, new[] { new int[] { 1, 0, 1, 1, 0, 0, 0 } }), Is.EqualTo(0));
130+
Assert.That(method!.Invoke(null, new[] { new int[] { 1, 1, 1, 0 } }), Is.EqualTo(1));
131131
}
132132

133133
[Test]
@@ -175,4 +175,34 @@ public void BestFeature_SkipsEmptyIdxBranch()
175175
Assert.That(resultObj, Is.Not.Null);
176176
Assert.That((int)resultObj!, Is.EqualTo(0));
177177
}
178+
179+
[Test]
180+
public void BuildTree_MostCommonLabelBranch_IsCovered()
181+
{
182+
int[][] X = { new[] { 0 }, new[] { 1 } };
183+
int[] y = { 0, 1 };
184+
var tree = new DecisionTree();
185+
tree.Fit(X, y);
186+
Assert.That(tree.Predict(new[] { 2 }), Is.EqualTo(0));
187+
}
188+
189+
[Test]
190+
public void BuildTree_ContinueBranch_IsCovered()
191+
{
192+
int[][] X = { new[] { 0 }, new[] { 1 } };
193+
int[] y = { 0, 1 };
194+
var method = typeof(DecisionTree).GetMethod("BuildTree", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
195+
var features = new System.Collections.Generic.List<int> { 0 };
196+
Assert.DoesNotThrow(() => method!.Invoke(null, new object[] { X, y, features }));
197+
}
198+
199+
[Test]
200+
public void BestFeature_ContinueBranch_IsCovered()
201+
{
202+
int[][] X = { new[] { 0, 1 }, new[] { 1, 1 } };
203+
int[] y = { 0, 1 };
204+
var method = typeof(DecisionTree).GetMethod("BestFeature", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
205+
var features = new System.Collections.Generic.List<int> { 0, 1 };
206+
Assert.DoesNotThrow(() => method!.Invoke(null, new object[] { X, y, features }));
207+
}
178208
}

0 commit comments

Comments
 (0)