Skip to content

Commit dbbd672

Browse files
author
Zebin Yang
committed
precision control for split
1 parent 73cf29f commit dbbd672

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

simtree/mobtree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def screen_features(self, sample_indice):
137137
right_impurity = self.evaluate_estimator(self.base_estimator, node_x[right_indice], node_y[right_indice].ravel())
138138

139139
current_impurity = (len(left_indice) * left_impurity + len(right_indice) * right_impurity) / n_samples
140-
if current_impurity < best_impurity:
140+
if current_impurity < (best_impurity - self.EPSILON):
141141
best_impurity = current_impurity
142142
feature_impurity.append(best_impurity)
143143
split_feature_indices = np.argsort(feature_impurity)[:self.n_feature_search]
@@ -204,7 +204,7 @@ def node_split(self, sample_indice):
204204
right_impurity = self.evaluate_estimator(self.base_estimator, node_x[right_indice], node_y[right_indice].ravel())
205205

206206
current_impurity = (len(left_indice) * left_impurity + len(right_indice) * right_impurity) / n_samples
207-
if current_impurity < best_impurity:
207+
if current_impurity < (best_impurity - self.EPSILON):
208208
best_position = i + 1
209209
best_feature = feature_indice
210210
best_impurity = current_impurity
@@ -290,7 +290,7 @@ def fit(self, x, y):
290290
if not is_leaf:
291291
split = self.node_split(sample_indice)
292292
impurity_improvement = impurity - split["impurity"]
293-
is_leaf = (is_leaf or (impurity_improvement < self.min_impurity_decrease) or
293+
is_leaf = (is_leaf or (impurity_improvement < (self.min_impurity_decrease + self.EPSILON)) or
294294
(split["left"] is None) or (split["right"] is None))
295295

296296
if is_leaf:

0 commit comments

Comments
 (0)