Skip to content

Commit 5b5fed2

Browse files
author
[zebinyang]
committed
add get_sparsity; add get_roughness; version 0.1.6
1 parent 9824cf2 commit 5b5fed2

File tree

5 files changed

+136
-8
lines changed

5 files changed

+136
-8
lines changed

examples/demo.ipynb

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,34 @@
193193
{
194194
"cell_type": "code",
195195
"execution_count": null,
196+
"metadata": {},
197+
"outputs": [],
198+
"source": []
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": 7,
196203
"metadata": {
197204
"ExecuteTime": {
198-
"end_time": "2021-09-01T17:37:22.054020Z",
199-
"start_time": "2021-09-01T17:29:18.371Z"
205+
"end_time": "2021-09-01T17:49:13.840143Z",
206+
"start_time": "2021-09-01T17:45:39.855056Z"
200207
}
201208
},
202-
"outputs": [],
209+
"outputs": [
210+
{
211+
"ename": "TypeError",
212+
"evalue": "'>=' not supported between instances of 'numpy.ndarray' and 'str'",
213+
"output_type": "error",
214+
"traceback": [
215+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
216+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
217+
"\u001b[0;32m<ipython-input-7-8582c77c0a6d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n\u001b[1;32m 3\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mpred_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mpred_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mroc_auc_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_y\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroc_auc_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_y\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpred_test\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
218+
"\u001b[0;32m~/anaconda2_local/envs/py37/lib/python3.7/site-packages/simtree/mobtree.py\u001b[0m in \u001b[0;36mpredict_proba\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 652\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 653\u001b[0;31m \u001b[0mproba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecision_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 654\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mproba\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproba\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 655\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
219+
"\u001b[0;32m~/anaconda2_local/envs/py37/lib/python3.7/site-packages/simtree/mobtree.py\u001b[0m in \u001b[0;36mdecision_function\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0mnode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 504\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'is_leaf'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 505\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'feature'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'threshold'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 506\u001b[0m \u001b[0mnode\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'left_child_id'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 507\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
220+
"\u001b[0;31mTypeError\u001b[0m: '>=' not supported between instances of 'numpy.ndarray' and 'str'"
221+
]
222+
}
223+
],
203224
"source": [
204225
"clf = GLMTreeClassifier(max_depth=3, min_samples_leaf=50, reg_lambda=np.logspace(-5, 5, 10).tolist(),\n",
205226
" n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n",
@@ -211,9 +232,53 @@
211232
},
212233
{
213234
"cell_type": "code",
214-
"execution_count": null,
215-
"metadata": {},
235+
"execution_count": 13,
236+
"metadata": {
237+
"ExecuteTime": {
238+
"end_time": "2021-09-01T17:53:47.559641Z",
239+
"start_time": "2021-09-01T17:53:47.530587Z"
240+
}
241+
},
216242
"outputs": [],
243+
"source": [
244+
"leaf_idx = []\n",
245+
"for row in train_x.values:\n",
246+
" node = clf.tree[1]\n",
247+
" while not node['is_leaf']:\n",
248+
" if row[node['feature']] <= node['threshold']:\n",
249+
" node = clf.tree[node['left_child_id']]\n",
250+
" else:\n",
251+
" node = clf.tree[node['right_child_id']]\n",
252+
" leaf_idx.append(node['node_id'])"
253+
]
254+
},
255+
{
256+
"cell_type": "code",
257+
"execution_count": 14,
258+
"metadata": {
259+
"ExecuteTime": {
260+
"end_time": "2021-09-01T17:54:00.517384Z",
261+
"start_time": "2021-09-01T17:54:00.511025Z"
262+
}
263+
},
264+
"outputs": [
265+
{
266+
"data": {
267+
"text/plain": [
268+
"array([[ 57.5644, 20.2196, 2.683 , ..., 11.8193, 26.1085, 217.544 ],\n",
269+
" [ 27.7996, 14.0561, 2.6839, ..., -9.2172, 38.023 , 97.7341],\n",
270+
" [ 48.4661, 22.7264, 2.953 , ..., 24.8371, 4.825 , 266.665 ],\n",
271+
" ...,\n",
272+
" [ 35.8286, 16.8952, 2.8802, ..., 11.3048, 0.472 , 234.868 ],\n",
273+
" [ 20.0986, 12.8671, 2.4057, ..., 7.875 , 21.675 , 212.098 ],\n",
274+
" [ 27.2726, 12.6129, 2.7288, ..., -9.9008, 3.789 , 185.431 ]])"
275+
]
276+
},
277+
"execution_count": 14,
278+
"metadata": {},
279+
"output_type": "execute_result"
280+
}
281+
],
217282
"source": []
218283
}
219284
],

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22

33
setup(name='simtree',
4-
version='0.1.5',
4+
version='0.1.6',
55
description='Single-index model tree',
66
url='https://github.com/ZebinYang/SIMTree',
77
author='Zebin Yang',

simtree/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
"SIMTreeRegressor", "SIMTreeClassifier",
99
"CustomMobTreeRegressor", "CustomMobTreeClassifier"]
1010

11-
__version__ = '0.1.5'
11+
__version__ = '0.1.6'
1212
__author__ = 'Zebin Yang'

simtree/mobtree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def decision_function(self, x):
499499
check_is_fitted(self, "tree")
500500

501501
leaf_idx = []
502-
for row in x:
502+
for row in np.array(x):
503503
node = self.tree[1]
504504
while not node['is_leaf']:
505505
if row[node['feature']] <= node['threshold']:

simtree/simtree.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ def get_projection_index(self, node_id):
9595
node_id : int
9696
the id of leaf node
9797
"""
98+
99+
check_is_fitted(self, "tree")
100+
if node_id not in self.leaf_estimators_.keys():
101+
print("Invalid leaf node id.")
102+
return
103+
98104
return self.leaf_estimators_[node_id].beta_.flatten()
99105

100106
def get_feature_importance(self, node_id):
@@ -106,6 +112,12 @@ def get_feature_importance(self, node_id):
106112
node_id : int
107113
the id of leaf node
108114
"""
115+
116+
check_is_fitted(self, "tree")
117+
if node_id not in self.leaf_estimators_.keys():
118+
print("Invalid leaf node id.")
119+
return
120+
109121
importance = (self.x[self.decision_path_indice(self.x, node_id)] * self.leaf_estimators_[node_id].beta_.ravel()).std(0)
110122
return importance
111123

@@ -120,6 +132,12 @@ def get_projection_equation(self, node_id, precision=3):
120132
precision : int
121133
the precision of coefficients
122134
"""
135+
136+
check_is_fitted(self, "tree")
137+
if node_id not in self.leaf_estimators_.keys():
138+
print("Invalid leaf node id.")
139+
return
140+
123141
equation = ""
124142
importance = self.get_feature_importance(node_id)
125143
sortind = np.argsort(importance)[::-1]
@@ -135,6 +153,51 @@ def get_projection_equation(self, node_id, precision=3):
135153
equation += " - "
136154
equation += str(round(np.abs(est.beta_[sortind[i], 0]), 3)) + self.feature_names[sortind[i]]
137155
return equation
156+
157+
def get_sparsity(self, node_id, grid_size=100):

    """Return the sparsity of the projection index in one leaf node.

    Sparsity is the fraction (a value between 0 and 1, not a percentage
    scaled to 100) of entries in the leaf estimator's ``beta_`` that are
    exactly equal to zero.

    Parameters
    ----------
    node_id : int
        the id of leaf node
    grid_size : int
        not used by this method; kept in the signature so existing calls
        that pass it keep working

    Returns
    -------
    sparsity : float or None
        the fraction of zero coefficients; None (after printing a message)
        when ``node_id`` is not a key of ``leaf_estimators_`` or when the
        estimator stored for that leaf is None
    """

    check_is_fitted(self, "tree")
    if node_id not in self.leaf_estimators_:
        print("Invalid leaf node id.")
        return None

    est = self.leaf_estimators_[node_id]
    if est is None:
        # get_roughness guards against a None estimator (a constant leaf);
        # without the same guard here, est.beta_ raises AttributeError.
        print("This is a constant node, and SIM is not available.")
        return None

    return np.mean(est.beta_ == 0)
175+
176+
def get_roughness(self, node_id, grid_size=100):

    """Return the roughness of the ridge function fitted in one leaf node.

    Roughness is the root-mean-square of ``shape_fit_.diff(x, order=2)``
    (the second derivative of the ridge function) taken over ``grid_size``
    evenly spaced points lying strictly between ``shape_fit_.xmin`` and
    ``shape_fit_.xmax``.

    Parameters
    ----------
    node_id : int
        the id of leaf node
    grid_size : int
        the number of grid points for approximation

    Returns
    -------
    roughness : float or None
        the root-mean-square second derivative; None (after printing a
        message) when ``node_id`` is not a known leaf or when the leaf has
        no estimator
    """

    check_is_fitted(self, "tree")

    if node_id not in self.leaf_estimators_.keys():
        print("Invalid leaf node id.")
        return

    leaf_model = self.leaf_estimators_[node_id]
    if leaf_model is None:
        print("This is a constant node, and SIM is not available.")
        return

    shape = leaf_model.shape_fit_
    # Ask for two extra points, then drop both ends so that only interior
    # points of [xmin, xmax] are evaluated.
    interior_points = np.linspace(shape.xmin, shape.xmax, grid_size + 2)[1:-1]
    squared_curvature = [shape.diff(point, order=2) ** 2 for point in interior_points]
    return np.sqrt(np.mean(squared_curvature))
138201

139202
def visualize_one_leaf(self, node_id, folder="./results/", name="leaf_sim", save_png=False, save_eps=False):
140203

0 commit comments

Comments
 (0)