@@ -111,12 +111,11 @@ def screen_features(self, sample_indice):
111111 if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
112112 continue
113113
114- percentage = (split_point + 1 ) / (self .n_screen_grid + 1 )
115- if n_samples > self .min_samples_leaf * (self .n_screen_grid + 1 ):
116- if (i + 1 ) / n_samples < percentage :
114+ if n_samples > (self .n_screen_grid + 1 ) * self .min_samples_leaf :
115+ if (i + 1 ) / n_samples < (split_point + 1 ) / (self .n_split_grid + 1 ):
117116 continue
118117 elif n_samples > 2 * self .min_samples_leaf :
119- if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < percentage :
118+ if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < split_point / ( self . n_screen_grid - 1 ) :
120119 continue
121120 elif (i + 1 ) != self .min_samples_leaf :
122121 continue
@@ -179,12 +178,11 @@ def node_split(self, sample_indice):
179178 if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
180179 continue
181180
182- percentage = (split_point + 1 ) / (self .n_screen_grid + 1 )
183- if n_samples > self .min_samples_leaf * (self .n_split_grid + 1 ):
184- if (i + 1 ) / n_samples < percentage :
181+ if n_samples > (self .n_split_grid + 1 ) * self .min_samples_leaf :
182+ if (i + 1 ) / n_samples < (split_point + 1 ) / (self .n_split_grid + 1 ):
185183 continue
186184 elif n_samples > 2 * self .min_samples_leaf :
187- if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < percentage :
185+ if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < split_point / ( self . n_split_grid - 1 ) :
188186 continue
189187 elif (i + 1 ) != self .min_samples_leaf :
190188 continue
0 commit comments