Skip to content

Commit a2f4ea9

Browse files
committed
Avoid double impurity calculation
1 parent a0054d1 commit a2f4ea9

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

econml/tree/_criterion.pyx

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ cdef class Criterion:
209209
return (- self.weighted_n_right * impurity_right
210210
- self.weighted_n_left * impurity_left)
211211

212-
cdef double impurity_improvement(self, double impurity) nogil:
212+
cdef double impurity_improvement(self, float64_t impurity_parent,
213+
float64_t impurity_left,
214+
float64_t impurity_right) nogil:
213215
"""Compute the improvement in impurity
214216
This method computes the improvement in impurity when a split occurs.
215217
The weighted impurity improvement equation is the following:
@@ -218,22 +220,24 @@ cdef class Criterion:
218220
where N is the total number of samples, N_t is the number of samples
219221
at the current node, N_t_L is the number of samples in the left child,
220222
and N_t_R is the number of samples in the right child.
223+
221224
Parameters
222225
----------
223-
impurity : double
224-
The initial impurity of the node before the split
226+
impurity_parent : float64_t
227+
The initial impurity of the parent node before the split
228+
229+
impurity_left : float64_t
230+
The impurity of the left child
231+
232+
impurity_right : float64_t
233+
The impurity of the right child
234+
225235
Returns
226236
-------
227237
double : improvement in impurity after the split occurs
228238
"""
229-
230-
cdef double impurity_left
231-
cdef double impurity_right
232-
233-
self.children_impurity(&impurity_left, &impurity_right)
234-
235239
return ((self.weighted_n_node_samples / self.weighted_n_samples) *
236-
(impurity - (self.weighted_n_right /
240+
(impurity_parent - (self.weighted_n_right /
237241
self.weighted_n_node_samples * impurity_right)
238242
- (self.weighted_n_left /
239243
self.weighted_n_node_samples * impurity_left)))

econml/tree/_splitter.pyx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -615,11 +615,6 @@ cdef class BestSplitter(Splitter):
615615
if self.honest:
616616
self.criterion_val.reset()
617617
self.criterion_val.update(best.pos_val)
618-
# Calculate a more accurate version of impurity improvement using the input baseline impurity
619-
# passed here by the TreeBuilder. The TreeBuilder uses the proxy_node_impurity() to calculate
620-
# this baseline if self.is_children_impurity_proxy(), else uses the call to children_impurity()
621-
# on the parent node, when that node was split.
622-
best.improvement = self.criterion.impurity_improvement(impurity)
623618
# if we need children impurities by the builder, then we populate these entries
624619
# otherwise, we leave them blank to avoid the extra computation.
625620
if not self.is_children_impurity_proxy():
@@ -630,6 +625,16 @@ cdef class BestSplitter(Splitter):
630625
else:
631626
best.impurity_left_val = best.impurity_left
632627
best.impurity_right_val = best.impurity_right
628+
629+
# Calculate a more accurate version of impurity improvement using the input baseline impurity
630+
# passed here by the TreeBuilder. The TreeBuilder uses the proxy_node_impurity() to calculate
631+
# this baseline if self.is_children_impurity_proxy(), else uses the call to children_impurity()
632+
# on the parent node, when that node was split.
633+
best.improvement = self.criterion.impurity_improvement(impurity,
634+
best.impurity_left, best.impurity_right)
635+
636+
if self.honest:
637+
633638

634639
# Respect invariant for constant features: the original order of
635640
# element in features[:n_known_constants] must be preserved for sibling

0 commit comments

Comments
 (0)