Skip to content

Commit 1ba510c

Browse files
author
[zebinyang]
committed
add feature importance for each SIM; add get_projection_equation; add get_feature_importance; add get_projection_index; update visualize_one_leaf; update version to 0.1.5
1 parent a887007 commit 1ba510c

File tree

3 files changed

+76
-39
lines changed

3 files changed

+76
-39
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22

33
setup(name='simtree',
4-
version='0.1.4',
4+
version='0.1.5',
55
description='Single-index model tree',
66
url='https://github.com/ZebinYang/SIMTree',
77
author='Zebin Yang',

simtree/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
"SIMTreeRegressor", "SIMTreeClassifier",
99
"CustomMobTreeRegressor", "CustomMobTreeClassifier"]
1010

11-
__version__ = '0.1.4'
11+
__version__ = '0.1.5'
1212
__author__ = 'Zebin Yang'

simtree/simtree.py

Lines changed: 74 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,55 @@ def _validate_hyperparameters(self):
8585
self.reg_gamma = [self.reg_gamma]
8686
else:
8787
raise ValueError("Invalid reg_gamma")
88+
89+
def get_projection_index(self, node_id):
    """Return the projection index of one leaf node.

    Parameters
    ----------
    node_id : int
        the id of leaf node

    Returns
    -------
    The ``beta_`` attribute of the estimator stored under ``node_id`` in
    ``self.leaf_estimators_``, flattened to one dimension.
    """
    leaf_model = self.leaf_estimators_[node_id]
    # flatten() hands back a copy, so callers cannot mutate the fitted beta_.
    return leaf_model.beta_.flatten()
99+
100+
def get_feature_importance(self, node_id):
    """Return the feature importance of one leaf node.

    The importance is the standard deviation, taken along axis 0, of the
    selected rows of ``self.x`` multiplied elementwise by the leaf
    estimator's flattened ``beta_``.

    Parameters
    ----------
    node_id : int
        the id of leaf node

    Returns
    -------
    The result of ``.std(0)`` on the weighted rows (one value per column
    of ``self.x``).
    """
    # Row selector for this node; presumably the training samples routed to
    # the leaf -- confirm against decision_path_indice.
    in_leaf = self.decision_path_indice(self.x, node_id)
    leaf_rows = self.x[in_leaf]
    coefs = self.leaf_estimators_[node_id].beta_.ravel()
    weighted = leaf_rows * coefs
    return weighted.std(0)
111+
112+
def get_projection_equation(self, node_id, precision=3):
    """Return the projection equation of one leaf node in string format.

    Terms are ordered by decreasing feature importance (as returned by
    ``get_feature_importance``). Each term is the rounded absolute
    coefficient immediately followed by the feature name; terms are joined
    with " + " or " - " according to the sign of the coefficient, and a
    negative leading coefficient is prefixed with "-".

    Parameters
    ----------
    node_id : int
        the id of leaf node
    precision : int
        the precision of coefficients (number of decimals passed to
        ``round``)

    Returns
    -------
    str
        the equation, e.g. ``"-0.8x2 + 0.6x1"``
    """
    # The leaf estimator and the feature names live on the model itself;
    # the earlier code read them from undefined globals (`est`, `clf`).
    beta = np.asarray(self.leaf_estimators_[node_id].beta_).ravel()
    importance = self.get_feature_importance(node_id)
    order = np.argsort(importance)[::-1]

    pieces = []
    for rank, idx in enumerate(order):
        coef = float(beta[idx])
        # Honour the caller's `precision` instead of a hard-coded 3.
        term = str(round(abs(coef), precision)) + self.feature_names[idx]
        if rank == 0:
            # Keep the sign of the leading term so the equation is faithful.
            pieces.append(("-" if coef < 0 else "") + term)
        else:
            pieces.append((" - " if coef < 0 else " + ") + term)
    return "".join(pieces)
88137

89138
def visualize_one_leaf(self, node_id, folder="./results/", name="leaf_sim", save_png=False, save_eps=False):
90139

@@ -120,7 +169,7 @@ def visualize_one_leaf(self, node_id, folder="./results/", name="leaf_sim", save
120169

121170
fig = plt.figure(figsize=(10, 4))
122171
est = self.leaf_estimators_[node_id]
123-
outer = gridspec.GridSpec(1, 2, wspace=0.25)
172+
outer = gridspec.GridSpec(1, 2, wspace=0.25, width_ratios=[1.2, 1])
124173
inner = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[0], wspace=0.1, hspace=0.1, height_ratios=[6, 1])
125174
ax1_main = fig.add_subplot(inner[0])
126175
xgrid = np.linspace(est.shape_fit_.xmin, est.shape_fit_.xmax, 100).reshape([-1, 1])
@@ -139,48 +188,36 @@ def visualize_one_leaf(self, node_id, folder="./results/", name="leaf_sim", save
139188
ax1_density.set_xticks(np.linspace(est.shape_fit_.xmin, est.shape_fit_.xmax, 5))
140189
fig.add_subplot(ax1_density)
141190

142-
ax2 = fig.add_subplot(outer[1])
191+
inner = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer[1], wspace=0.2, hspace=0.1, width_ratios=[1, 1])
192+
ax2_coef = fig.add_subplot(inner[0])
143193
if len(est.beta_) <= 50:
144-
ax2.barh(np.arange(len(est.beta_)), [beta for beta in est.beta_.ravel()][::-1])
145-
ax2.set_yticks(np.arange(len(est.beta_)))
146-
ax2.set_yticklabels([self.feature_names[idx][:8] for idx in range(len(est.beta_.ravel()))][::-1])
147-
ax2.set_xlim(xlim_min, xlim_max)
148-
ax2.set_ylim(-1, len(est.beta_))
149-
ax2.axvline(0, linestyle="dotted", color="black")
194+
ax2_coef.barh(np.arange(len(est.beta_)), [beta for beta in est.beta_.ravel()][::-1])
195+
ax2_coef.set_yticks(np.arange(len(est.beta_)))
196+
ax2_coef.set_yticklabels([self.feature_names[idx] for idx in range(len(est.beta_.ravel()))][::-1])
197+
ax2_coef.set_xlim(xlim_min, xlim_max)
198+
ax2_coef.set_ylim(-1, len(est.beta_))
199+
ax2_coef.axvline(0, linestyle="dotted", color="black")
150200
else:
151201
right = np.round(np.linspace(0, np.round(len(est.beta_) * 0.45).astype(int), 5))
152202
left = len(est.beta_) - 1 - right
153203
input_ticks = np.unique(np.hstack([left, right])).astype(int)
154204

155-
ax2.barh(np.arange(len(est.beta_)), [beta for beta in est.beta_.ravel()][::-1])
156-
ax2.set_yticks(input_ticks)
157-
ax2.set_yticklabels([self.feature_names[idx][:8] for idx in input_ticks][::-1])
158-
ax2.set_xlim(xlim_min, xlim_max)
159-
ax2.set_ylim(-1, len(est.beta_))
160-
ax2.axvline(0, linestyle="dotted", color="black")
161-
162-
ax2title = ""
163-
sortind = np.argsort(np.abs(est.beta_).ravel())[::-1]
164-
for i in range(est.beta_.shape[0]):
165-
if i == 0:
166-
ax2title += str(round(np.abs(est.beta_[sortind[i], 0]), 3)) + self.feature_names[sortind[i]][:8]
167-
continue
168-
elif (i > 0) & (i < 3):
169-
if np.abs(est.beta_[sortind[i], 0]) > 0.001:
170-
if est.beta_[sortind[i], 0] > 0:
171-
ax2title += " + "
172-
else:
173-
ax2title += " - "
174-
ax2title += str(round(np.abs(est.beta_[sortind[i], 0]), 3)) + self.feature_names[sortind[i]][:8]
175-
else:
176-
break
177-
elif i == 3:
178-
if np.abs(est.beta_[sortind[3], 0]) > 0.001:
179-
ax2title += "+..."
180-
else:
181-
break
182-
ax2.set_title(ax2title)
183-
fig.add_subplot(ax2)
205+
ax2_coef.barh(np.arange(len(est.beta_)), [beta for beta in est.beta_.ravel()][::-1])
206+
ax2_coef.set_yticks(input_ticks)
207+
ax2_coef.set_yticklabels([self.feature_names[idx] for idx in input_ticks][::-1])
208+
ax2_coef.set_xlim(xlim_min, xlim_max)
209+
ax2_coef.set_ylim(-1, len(est.beta_))
210+
ax2_coef.axvline(0, linestyle="dotted", color="black")
211+
212+
ax2_coef.set_title("Projection Index")
213+
fig.add_subplot(ax2_coef)
214+
215+
ax2_importance = fig.add_subplot(inner[1])
216+
ax2_coef.get_shared_y_axes().join(ax2_coef, ax2_importance)
217+
ax2_importance.set_yticklabels([])
218+
ax2_importance.barh(self.feature_names, self.get_feature_importance(node_id))
219+
ax2_importance.set_title("Importance")
220+
fig.add_subplot(ax2_importance)
184221
plt.show()
185222
if save_png:
186223
if not os.path.exists(folder):

0 commit comments

Comments
 (0)