@@ -99,29 +99,14 @@ def screen_features(self, sample_indice):
9999 if feature_range < self .EPSILON :
100100 continue
101101
102- split_point = 0
103102 best_impurity = np .inf
104- for i , _ in enumerate (sortted_indice ):
105-
106- if i == (n_samples - 1 ):
107- continue
108-
103+ split_points = np .quantile (sortted_feature , np .linspace (0 , 1 , self .n_screen_grid + 2 )[1 :- 1 ], interpolation = 'lower' )
104+ for split_point in split_points :
105+
106+ i = abs (sortted_feature - split_point ).argmin ()
109107 if ((i + 1 ) < self .min_samples_leaf ) or ((n_samples - i - 1 ) < self .min_samples_leaf ):
110108 continue
111109
112- if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
113- continue
114-
115- if self .min_samples_leaf < n_samples / max ((self .n_screen_grid - 1 ), 2 ):
116- if (i + 1 ) / n_samples < (split_point + 1 ) / (self .n_screen_grid + 1 ):
117- continue
118- elif n_samples > 2 * self .min_samples_leaf :
119- if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < split_point / (self .n_screen_grid - 1 ):
120- continue
121- elif (i + 1 ) != self .min_samples_leaf :
122- continue
123-
124- split_point += 1
125110 left_indice = sortted_indice [:(i + 1 )]
126111 if node_y [left_indice ].std () == 0 :
127112 left_impurity = 0
@@ -167,28 +152,14 @@ def node_split(self, sample_indice):
167152 if feature_range < self .EPSILON :
168153 continue
169154
170- split_point = 0
171- for i , _ in enumerate (sortted_indice ):
172-
173- if i == (n_samples - 1 ):
174- continue
155+ best_impurity = np .inf
156+ split_points = np .quantile (sortted_feature , np .linspace (0 , 1 , self .n_split_grid + 2 )[1 :- 1 ], interpolation = 'lower' )
157+ for split_point in split_points :
175158
159+ i = abs (sortted_feature - split_point ).argmin ()
176160 if ((i + 1 ) < self .min_samples_leaf ) or ((n_samples - i - 1 ) < self .min_samples_leaf ):
177161 continue
178-
179- if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
180- continue
181-
182- if self .min_samples_leaf < n_samples / max ((self .n_split_grid - 1 ), 2 ):
183- if (i + 1 ) / n_samples < (split_point + 1 ) / (self .n_split_grid + 1 ):
184- continue
185- elif n_samples > 2 * self .min_samples_leaf :
186- if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < split_point / (self .n_split_grid - 1 ):
187- continue
188- elif (i + 1 ) != self .min_samples_leaf :
189- continue
190162
191- split_point += 1
192163 left_indice = sortted_indice [:(i + 1 )]
193164 if node_y [left_indice ].std () == 0 :
194165 left_impurity = 0
0 commit comments