Skip to content

Commit d2d41f6

Browse files
author
[zebinyang]
committed
change the way of spreading grid points; version 0.2.0
1 parent e52f722 commit d2d41f6

File tree

1 file changed

+39
-9
lines changed

1 file changed

+39
-9
lines changed

simtree/mobtree.py

Lines changed: 39 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -99,14 +99,29 @@ def screen_features(self, sample_indice):
9999
if feature_range < self.EPSILON:
100100
continue
101101

102-
best_impurity = np.inf
103-
split_points = np.quantile(sortted_feature, np.linspace(0, 1, self.n_screen_grid + 2)[1:-1], interpolation='lower')
104-
for split_point in split_points:
105-
106-
i = abs(sortted_feature - split_point).argmin()
102+
split_point = 0
103+
for i, _ in enumerate(sortted_indice):
104+
105+
if i == (n_samples - 1):
106+
continue
107+
107108
if ((i + 1) < self.min_samples_leaf) or ((n_samples - i - 1) < self.min_samples_leaf):
108109
continue
110+
111+
if sortted_feature[i + 1] <= sortted_feature[i] + self.EPSILON:
112+
continue
113+
114+
percentage = (split_point + 1) / (self.n_screen_grid + 1)
115+
if n_samples > self.min_samples_leaf * (self.n_screen_grid + 1):
116+
if (i + 1) / n_samples < percentage:
117+
continue
118+
elif n_samples > 2 * self.min_samples_leaf:
119+
if (i + 1 - self.min_samples_leaf) / (n_samples - 2 * self.min_samples_leaf) < percentage:
120+
continue
121+
elif (i + 1) != self.min_samples_leaf:
122+
continue
109123

124+
split_point += 1
110125
left_indice = sortted_indice[:(i + 1)]
111126
if node_y[left_indice].std() == 0:
112127
left_impurity = 0
@@ -152,14 +167,29 @@ def node_split(self, sample_indice):
152167
if feature_range < self.EPSILON:
153168
continue
154169

155-
best_impurity = np.inf
156-
split_points = np.quantile(sortted_feature, np.linspace(0, 1, self.n_split_grid + 2)[1:-1], interpolation='lower')
157-
for split_point in split_points:
170+
split_point = 0
171+
for i, _ in enumerate(sortted_indice):
172+
173+
if i == (n_samples - 1):
174+
continue
158175

159-
i = abs(sortted_feature - split_point).argmin()
160176
if ((i + 1) < self.min_samples_leaf) or ((n_samples - i - 1) < self.min_samples_leaf):
161177
continue
178+
179+
if sortted_feature[i + 1] <= sortted_feature[i] + self.EPSILON:
180+
continue
181+
182+
percentage = (split_point + 1) / (self.n_screen_grid + 1)
183+
if n_samples > self.min_samples_leaf * (self.n_split_grid + 1):
184+
if (i + 1) / n_samples < percentage:
185+
continue
186+
elif n_samples > 2 * self.min_samples_leaf:
187+
if (i + 1 - self.min_samples_leaf) / (n_samples - 2 * self.min_samples_leaf) < percentage:
188+
continue
189+
elif (i + 1) != self.min_samples_leaf:
190+
continue
162191

192+
split_point += 1
163193
left_indice = sortted_indice[:(i + 1)]
164194
if node_y[left_indice].std() == 0:
165195
left_impurity = 0

0 commit comments

Comments
 (0)