@@ -99,14 +99,29 @@ def screen_features(self, sample_indice):
9999 if feature_range < self .EPSILON :
100100 continue
101101
102- best_impurity = np .inf
103- split_points = np .quantile (sortted_feature , np .linspace (0 , 1 , self .n_screen_grid + 2 )[1 :- 1 ], interpolation = 'lower' )
104- for split_point in split_points :
105-
106- i = abs (sortted_feature - split_point ).argmin ()
102+ split_point = 0
103+ for i , _ in enumerate (sortted_indice ):
104+
105+ if i == (n_samples - 1 ):
106+ continue
107+
107108 if ((i + 1 ) < self .min_samples_leaf ) or ((n_samples - i - 1 ) < self .min_samples_leaf ):
108109 continue
110+
111+ if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
112+ continue
113+
114+ percentage = (split_point + 1 ) / (self .n_screen_grid + 1 )
115+ if n_samples > self .min_samples_leaf * (self .n_screen_grid + 1 ):
116+ if (i + 1 ) / n_samples < percentage :
117+ continue
118+ elif n_samples > 2 * self .min_samples_leaf :
119+ if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < percentage :
120+ continue
121+ elif (i + 1 ) != self .min_samples_leaf :
122+ continue
109123
124+ split_point += 1
110125 left_indice = sortted_indice [:(i + 1 )]
111126 if node_y [left_indice ].std () == 0 :
112127 left_impurity = 0
@@ -152,14 +167,29 @@ def node_split(self, sample_indice):
152167 if feature_range < self .EPSILON :
153168 continue
154169
155- best_impurity = np .inf
156- split_points = np .quantile (sortted_feature , np .linspace (0 , 1 , self .n_split_grid + 2 )[1 :- 1 ], interpolation = 'lower' )
157- for split_point in split_points :
170+ split_point = 0
171+ for i , _ in enumerate (sortted_indice ):
172+
173+ if i == (n_samples - 1 ):
174+ continue
158175
159- i = abs (sortted_feature - split_point ).argmin ()
160176 if ((i + 1 ) < self .min_samples_leaf ) or ((n_samples - i - 1 ) < self .min_samples_leaf ):
161177 continue
178+
179+ if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
180+ continue
181+
182+ percentage = (split_point + 1 ) / (self .n_screen_grid + 1 )
183+ if n_samples > self .min_samples_leaf * (self .n_split_grid + 1 ):
184+ if (i + 1 ) / n_samples < percentage :
185+ continue
186+ elif n_samples > 2 * self .min_samples_leaf :
187+ if (i + 1 - self .min_samples_leaf ) / (n_samples - 2 * self .min_samples_leaf ) < percentage :
188+ continue
189+ elif (i + 1 ) != self .min_samples_leaf :
190+ continue
162191
192+ split_point += 1
163193 left_indice = sortted_indice [:(i + 1 )]
164194 if node_y [left_indice ].std () == 0 :
165195 left_impurity = 0
0 commit comments