Skip to content

Commit dc94ab1

Browse files
author
[zebinyang]
committed
change the way of spreading grid points; version 0.2.0
1 parent 69e2b80 commit dc94ab1

File tree

4 files changed

+16
-68
lines changed

4 files changed

+16
-68
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22

33
setup(name='simtree',
4-
version='0.1.8',
4+
version='0.2.0',
55
description='Single-index model tree',
66
url='https://github.com/ZebinYang/SIMTree',
77
author='Zebin Yang',

simtree/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,5 +8,5 @@
88
"SIMTreeRegressor", "SIMTreeClassifier",
99
"CustomMobTreeRegressor", "CustomMobTreeClassifier"]
1010

11-
__version__ = '0.1.8'
11+
__version__ = '0.2.0'
1212
__author__ = 'Zebin Yang'

simtree/cart.py

Lines changed: 6 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -55,24 +55,12 @@ def node_split(self, sample_indice):
5555
sq_sum_total = np.sum(node_y ** 2)
5656
for i, _ in enumerate(sortted_indice):
5757

58+
if ((i + 1) < self.min_samples_leaf) or ((n_samples - i - 1) < self.min_samples_leaf):
59+
continue
60+
5861
n_left = i + 1
5962
n_right = n_samples - i - 1
6063
sum_left += node_y[sortted_indice[i]]
61-
if i == (n_samples - 1):
62-
continue
63-
64-
if sortted_feature[i + 1] <= sortted_feature[i] + self.EPSILON:
65-
continue
66-
67-
if self.min_samples_leaf < n_samples / (self.n_split_grid - 1):
68-
if (i + 1) / n_samples < (split_point + 1) / (self.n_split_grid + 1):
69-
continue
70-
elif n_samples > 2 * self.min_samples_leaf:
71-
if (i + 1 - self.min_samples_leaf) / (n_samples - 2 * self.min_samples_leaf) < split_point / (self.n_split_grid - 1):
72-
continue
73-
elif (i + 1) != self.min_samples_leaf:
74-
continue
75-
7664
current_impurity = (sq_sum_total / n_samples - (sum_left / n_left) ** 2 * n_left / n_samples -
7765
((sum_total - sum_left) / n_right) ** 2 * n_right / n_samples)
7866

@@ -144,23 +132,12 @@ def node_split(self, sample_indice):
144132
sum_total = np.sum(node_y)
145133
for i, _ in enumerate(sortted_indice):
146134

135+
if ((i + 1) < self.min_samples_leaf) or ((n_samples - i - 1) < self.min_samples_leaf):
136+
continue
137+
147138
n_left = i + 1
148139
n_right = n_samples - i - 1
149140
sum_left += node_y[sortted_indice[i]]
150-
if i == (n_samples - 1):
151-
continue
152-
153-
if sortted_feature[i + 1] <= sortted_feature[i] + self.EPSILON:
154-
continue
155-
156-
if self.min_samples_leaf < n_samples / (self.n_split_grid - 1):
157-
if (i + 1) / n_samples < (split_point + 1) / (self.n_split_grid + 1):
158-
continue
159-
elif n_samples > 2 * self.min_samples_leaf:
160-
if (i + 1 - self.min_samples_leaf) / (n_samples - 2 * self.min_samples_leaf) < split_point / (self.n_split_grid - 1):
161-
continue
162-
elif (i + 1) != self.min_samples_leaf:
163-
continue
164141

165142
left_impurity = 0
166143
right_impurity = 0

simtree/mobtree.py

Lines changed: 8 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -99,29 +99,14 @@ def screen_features(self, sample_indice):
9999
if feature_range < self.EPSILON:
100100
continue
101101

102-
split_point = 0
103102
best_impurity = np.inf
104-
for i, _ in enumerate(sortted_indice):
105-
106-
if i == (n_samples - 1):
107-
continue
108-
103+
split_points = np.quantile(sortted_feature, np.linspace(0, 1, self.n_screen_grid + 2)[1:-1], interpolation='lower')
104+
for split_point in split_points:
105+
106+
i = abs(sortted_feature - split_point).argmin()
109107
if ((i + 1) < self.min_samples_leaf) or ((n_samples - i - 1) < self.min_samples_leaf):
110108
continue
111109

112-
if sortted_feature[i + 1] <= sortted_feature[i] + self.EPSILON:
113-
continue
114-
115-
if self.min_samples_leaf < n_samples / max((self.n_screen_grid - 1), 2):
116-
if (i + 1) / n_samples < (split_point + 1) / (self.n_screen_grid + 1):
117-
continue
118-
elif n_samples > 2 * self.min_samples_leaf:
119-
if (i + 1 - self.min_samples_leaf) / (n_samples - 2 * self.min_samples_leaf) < split_point / (self.n_screen_grid - 1):
120-
continue
121-
elif (i + 1) != self.min_samples_leaf:
122-
continue
123-
124-
split_point += 1
125110
left_indice = sortted_indice[:(i + 1)]
126111
if node_y[left_indice].std() == 0:
127112
left_impurity = 0
@@ -167,28 +152,14 @@ def node_split(self, sample_indice):
167152
if feature_range < self.EPSILON:
168153
continue
169154

170-
split_point = 0
171-
for i, _ in enumerate(sortted_indice):
172-
173-
if i == (n_samples - 1):
174-
continue
155+
best_impurity = np.inf
156+
split_points = np.quantile(sortted_feature, np.linspace(0, 1, self.n_split_grid + 2)[1:-1], interpolation='lower')
157+
for split_point in split_points:
175158

159+
i = abs(sortted_feature - split_point).argmin()
176160
if ((i + 1) < self.min_samples_leaf) or ((n_samples - i - 1) < self.min_samples_leaf):
177161
continue
178-
179-
if sortted_feature[i + 1] <= sortted_feature[i] + self.EPSILON:
180-
continue
181-
182-
if self.min_samples_leaf < n_samples / max((self.n_split_grid - 1), 2):
183-
if (i + 1) / n_samples < (split_point + 1) / (self.n_split_grid + 1):
184-
continue
185-
elif n_samples > 2 * self.min_samples_leaf:
186-
if (i + 1 - self.min_samples_leaf) / (n_samples - 2 * self.min_samples_leaf) < split_point / (self.n_split_grid - 1):
187-
continue
188-
elif (i + 1) != self.min_samples_leaf:
189-
continue
190162

191-
split_point += 1
192163
left_indice = sortted_indice[:(i + 1)]
193164
if node_y[left_indice].std() == 0:
194165
left_impurity = 0

0 commit comments

Comments
 (0)