@@ -92,6 +92,57 @@ public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft()
92
92
Assert . That ( tree . Predict ( new [ ] { 3 } ) , Is . EqualTo ( 0 ) ) ;
93
93
}
94
94
95
+ [ Test ]
96
+ public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft_MultipleLabels ( )
97
+ {
98
+ int [ ] [ ] X = { new [ ] { 0 } , new [ ] { 1 } , new [ ] { 2 } , new [ ] { 3 } } ;
99
+ int [ ] y = { 1 , 0 , 1 , 0 } ;
100
+ var tree = new DecisionTree ( ) ;
101
+ tree . Fit ( X , y ) ;
102
+ // Most common label is 0 (2 times)
103
+ Assert . That ( tree . Predict ( new [ ] { 4 } ) , Is . EqualTo ( 0 ) ) ;
104
+ }
105
+
106
+ [ Test ]
107
+ public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsZero ( )
108
+ {
109
+ int [ ] [ ] X = { new [ ] { 0 } , new [ ] { 1 } } ;
110
+ int [ ] y = { 0 , 0 } ;
111
+ var tree = new DecisionTree ( ) ;
112
+ tree . Fit ( X , y ) ;
113
+ Assert . That ( tree . Predict ( new [ ] { 0 } ) , Is . EqualTo ( 0 ) ) ;
114
+ Assert . That ( tree . Predict ( new [ ] { 1 } ) , Is . EqualTo ( 0 ) ) ;
115
+ }
116
+
117
+ [ Test ]
118
+ public void Entropy_ReturnsZero_WhenAllZeroOrAllOne ( )
119
+ {
120
+ var method = typeof ( DecisionTree ) . GetMethod ( "Entropy" , System . Reflection . BindingFlags . NonPublic | System . Reflection . BindingFlags . Static ) ;
121
+ Assert . That ( method ! . Invoke ( null , new object [ ] { new int [ ] { 0 , 0 , 0 } } ) , Is . EqualTo ( 0d ) ) ;
122
+ Assert . That ( method ! . Invoke ( null , new object [ ] { new int [ ] { 1 , 1 , 1 } } ) , Is . EqualTo ( 0d ) ) ;
123
+ }
124
+
125
+ [ Test ]
126
+ public void MostCommon_ReturnsCorrectLabel ( )
127
+ {
128
+ var method = typeof ( DecisionTree ) . GetMethod ( "MostCommon" , System . Reflection . BindingFlags . NonPublic | System . Reflection . BindingFlags . Static ) ;
129
+ Assert . That ( method ! . Invoke ( null , new object [ ] { new int [ ] { 1 , 0 , 1 , 1 , 0 , 0 , 0 } } ) , Is . EqualTo ( 0 ) ) ;
130
+ Assert . That ( method ! . Invoke ( null , new object [ ] { new int [ ] { 1 , 1 , 1 , 0 } } ) , Is . EqualTo ( 1 ) ) ;
131
+ }
132
+
133
+ [ Test ]
134
+ public void Traverse_FallbacksToZero_WhenChildrenIsNull ( )
135
+ {
136
+ // Create a node with Children = null and Label = null
137
+ var nodeType = typeof ( DecisionTree ) . GetNestedType ( "Node" , System . Reflection . BindingFlags . NonPublic ) ;
138
+ var node = Activator . CreateInstance ( nodeType ! ) ;
139
+ nodeType ! . GetProperty ( "Feature" ) ! . SetValue ( node , 0 ) ;
140
+ nodeType ! . GetProperty ( "Label" ) ! . SetValue ( node , null ) ;
141
+ nodeType ! . GetProperty ( "Children" ) ! . SetValue ( node , null ) ;
142
+ var method = typeof ( DecisionTree ) . GetMethod ( "Traverse" , System . Reflection . BindingFlags . NonPublic | System . Reflection . BindingFlags . Static ) ;
143
+ Assert . That ( method ! . Invoke ( null , new object [ ] { node ! , new int [ ] { 99 } } ) , Is . EqualTo ( 0 ) ) ;
144
+ }
145
+
95
146
[ Test ]
96
147
public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsSame ( )
97
148
{
@@ -120,8 +171,8 @@ public void BestFeature_SkipsEmptyIdxBranch()
120
171
int [ ] y = { 0 , 1 } ;
121
172
var method = typeof ( DecisionTree ) . GetMethod ( "BestFeature" , System . Reflection . BindingFlags . NonPublic | System . Reflection . BindingFlags . Static ) ;
122
173
var features = new System . Collections . Generic . List < int > { 0 , 1 } ;
123
- var resultObj = method ! . Invoke ( null , new object [ ] { X , y , features } ) ;
124
- Assert . That ( resultObj , Is . Not . Null ) ;
125
- Assert . That ( ( int ) resultObj ! , Is . EqualTo ( 0 ) ) ;
174
+ var resultObj = method ! . Invoke ( null , new object [ ] { X , y , features } ) ;
175
+ Assert . That ( resultObj , Is . Not . Null ) ;
176
+ Assert . That ( ( int ) resultObj ! , Is . EqualTo ( 0 ) ) ;
126
177
}
127
178
}
0 commit comments