Skip to content

Commit b3e3b0d

Browse files
committed
Add more unit tests for Decision Tree and handle the exception
1 parent 0d1f766 commit b3e3b0d

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

Algorithms.Tests/MachineLearning/DecisionTreeTests.cs

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,57 @@ public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft()
9292
Assert.That(tree.Predict(new[] { 3 }), Is.EqualTo(0));
9393
}
9494

95+
[Test]
96+
public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft_MultipleLabels()
97+
{
98+
int[][] X = { new[] { 0 }, new[] { 1 }, new[] { 2 }, new[] { 3 } };
99+
int[] y = { 1, 0, 1, 0 };
100+
var tree = new DecisionTree();
101+
tree.Fit(X, y);
102+
// Most common label is 0 (2 times)
103+
Assert.That(tree.Predict(new[] { 4 }), Is.EqualTo(0));
104+
}
105+
106+
[Test]
107+
public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsZero()
108+
{
109+
int[][] X = { new[] { 0 }, new[] { 1 } };
110+
int[] y = { 0, 0 };
111+
var tree = new DecisionTree();
112+
tree.Fit(X, y);
113+
Assert.That(tree.Predict(new[] { 0 }), Is.EqualTo(0));
114+
Assert.That(tree.Predict(new[] { 1 }), Is.EqualTo(0));
115+
}
116+
117+
[Test]
118+
public void Entropy_ReturnsZero_WhenAllZeroOrAllOne()
119+
{
120+
var method = typeof(DecisionTree).GetMethod("Entropy", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
121+
Assert.That(method!.Invoke(null, new object[] { new int[] { 0, 0, 0 } }), Is.EqualTo(0d));
122+
Assert.That(method!.Invoke(null, new object[] { new int[] { 1, 1, 1 } }), Is.EqualTo(0d));
123+
}
124+
125+
[Test]
126+
public void MostCommon_ReturnsCorrectLabel()
127+
{
128+
var method = typeof(DecisionTree).GetMethod("MostCommon", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
129+
Assert.That(method!.Invoke(null, new object[] { new int[] { 1, 0, 1, 1, 0, 0, 0 } }), Is.EqualTo(0));
130+
Assert.That(method!.Invoke(null, new object[] { new int[] { 1, 1, 1, 0 } }), Is.EqualTo(1));
131+
}
132+
133+
[Test]
134+
public void Traverse_FallbacksToZero_WhenChildrenIsNull()
135+
{
136+
// Create a node with Children = null and Label = null
137+
var nodeType = typeof(DecisionTree).GetNestedType("Node", System.Reflection.BindingFlags.NonPublic);
138+
var node = Activator.CreateInstance(nodeType!);
139+
nodeType!.GetProperty("Feature")!.SetValue(node, 0);
140+
nodeType!.GetProperty("Label")!.SetValue(node, null);
141+
nodeType!.GetProperty("Children")!.SetValue(node, null);
142+
var method = typeof(DecisionTree).GetMethod("Traverse", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
143+
Assert.That(method!.Invoke(null, new object[] { node!, new int[] { 99 } }), Is.EqualTo(0));
144+
}
145+
95146
[Test]
96147
public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsSame()
97148
{
@@ -120,8 +171,8 @@ public void BestFeature_SkipsEmptyIdxBranch()
120171
int[] y = { 0, 1 };
121172
var method = typeof(DecisionTree).GetMethod("BestFeature", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
122173
var features = new System.Collections.Generic.List<int> { 0, 1 };
123-
var resultObj = method!.Invoke(null, new object[] { X, y, features });
124-
Assert.That(resultObj, Is.Not.Null);
125-
Assert.That((int)resultObj!, Is.EqualTo(0));
174+
var resultObj = method!.Invoke(null, new object[] { X, y, features });
175+
Assert.That(resultObj, Is.Not.Null);
176+
Assert.That((int)resultObj!, Is.EqualTo(0));
126177
}
127178
}

Algorithms/MachineLearning/DecisionTree.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ private static int Traverse(Node node, int[] x)
9696
}
9797

9898
int v = x[node.Feature!.Value];
99-
if (node.Children!.TryGetValue(v, out var child))
99+
if (node.Children != null && node.Children.TryGetValue(v, out var child))
100100
{
101101
return Traverse(child, x);
102102
}
103103

104-
// fallback to 0 if unseen value
104+
// fallback to 0 if unseen value or Children is null
105105
return 0;
106106
}
107107

0 commit comments

Comments
 (0)