Skip to content

Commit dba8ba1

Browse files
committed
remove fig resizing from FIGS
1 parent 82cdb46 commit dba8ba1

File tree

2 files changed

+99
-65
lines changed

2 files changed

+99
-65
lines changed

imodels/experimental/figs_ensembles.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212

1313
from imodels.tree.viz_utils import extract_sklearn_tree_from_figs
1414

15-
plt.rcParams['figure.dpi'] = 300
16-
1715

1816
class Node:
1917
def __init__(self, feature: int = None, threshold: int = None,
@@ -45,7 +43,8 @@ def __init__(self, feature: int = None, threshold: int = None,
4543
def update_values(self, X, y):
4644
self.value = y.mean()
4745
if self.threshold is not None:
48-
right_indicator = np.apply_along_axis(lambda x: x[self.feature] > self.threshold, 1, X)
46+
right_indicator = np.apply_along_axis(
47+
lambda x: x[self.feature] > self.threshold, 1, X)
4948
X_right = X[right_indicator, :]
5049
X_left = X[~right_indicator, :]
5150
y_right = y[right_indicator]
@@ -61,9 +60,11 @@ def shrink(self, reg_param, cum_sum=0):
6160
if self.left is None: # if leaf node, change prediction
6261
self.value = cum_sum
6362
else:
64-
shrunk_diff = (self.left.value - self.value) / (1 + reg_param / self.n_samples)
63+
shrunk_diff = (self.left.value - self.value) / \
64+
(1 + reg_param / self.n_samples)
6565
self.left.shrink(reg_param, cum_sum + shrunk_diff)
66-
shrunk_diff = (self.right.value - self.value) / (1 + reg_param / self.n_samples)
66+
shrunk_diff = (self.right.value - self.value) / \
67+
(1 + reg_param / self.n_samples)
6768
self.right.shrink(reg_param, cum_sum + shrunk_diff)
6869

6970
def setattrs(self, **kwargs):
@@ -132,7 +133,7 @@ def _init_decision_function(self):
132133
"""
133134
# used by sklearn GridSearchCV, BaggingClassifier
134135
if self.prediction_task == 'classification':
135-
decision_function = lambda x: self.predict_proba(x)[:, 1]
136+
def decision_function(x): return self.predict_proba(x)[:, 1]
136137
elif self.prediction_task == 'regression':
137138
decision_function = self.predict
138139

@@ -166,7 +167,8 @@ def _construct_node_linear(self, X, y, idxs, tree_num=0, sample_weight=None):
166167
feature=None, threshold=None,
167168
impurity_reduction=-1, split_or_linear='split') # leaf node that just returns its value
168169
else:
169-
assert isinstance(best_linear_coef, float), 'coef should be a float'
170+
assert isinstance(best_linear_coef,
171+
float), 'coef should be a float'
170172
return Node(idxs=idxs, value=best_linear_coef, tree_num=tree_num,
171173
feature=best_feature, threshold=None,
172174
impurity_reduction=impurity_reduction, split_or_linear='linear')
@@ -178,7 +180,8 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, m
178180
RIGHT = 2
179181

180182
# fit stump
181-
stump = tree.DecisionTreeRegressor(max_depth=1, max_features=max_features)
183+
stump = tree.DecisionTreeRegressor(
184+
max_depth=1, max_features=max_features)
182185
if sample_weight is not None:
183186
sample_weight = sample_weight[idxs]
184187
stump.fit(X[idxs], y[idxs], sample_weight=sample_weight)
@@ -201,10 +204,10 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, m
201204

202205
# split node
203206
impurity_reduction = (
204-
impurity[SPLIT] -
205-
impurity[LEFT] * n_node_samples[LEFT] / n_node_samples[SPLIT] -
206-
impurity[RIGHT] * n_node_samples[RIGHT] / n_node_samples[SPLIT]
207-
) * idxs.sum()
207+
impurity[SPLIT] -
208+
impurity[LEFT] * n_node_samples[LEFT] / n_node_samples[SPLIT] -
209+
impurity[RIGHT] * n_node_samples[RIGHT] / n_node_samples[SPLIT]
210+
) * idxs.sum()
208211

209212
node_split = Node(idxs=idxs, value=value[SPLIT], tree_num=tree_num,
210213
feature=feature[SPLIT], threshold=threshold[SPLIT],
@@ -216,7 +219,8 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, m
216219
idxs_left = idxs_split & idxs
217220
idxs_right = ~idxs_split & idxs
218221
node_left = Node(idxs=idxs_left, value=value[LEFT], tree_num=tree_num)
219-
node_right = Node(idxs=idxs_right, value=value[RIGHT], tree_num=tree_num)
222+
node_right = Node(
223+
idxs=idxs_right, value=value[RIGHT], tree_num=tree_num)
220224
node_split.setattrs(left_temp=node_left, right_temp=node_right, )
221225
return node_split
222226

@@ -231,7 +235,8 @@ def fit(self, X, y=None, feature_names=None, verbose=False, sample_weight=None):
231235
"""
232236

233237
if self.prediction_task == 'classification':
234-
self.classes_, y = np.unique(y, return_inverse=True) # deals with str inputs
238+
self.classes_, y = np.unique(
239+
y, return_inverse=True) # deals with str inputs
235240
X, y = check_X_y(X, y)
236241
y = y.astype(float)
237242
if feature_names is not None:
@@ -252,7 +257,8 @@ def _update_tree_preds(n_iter):
252257
if not tree_num_2_ == tree_num_:
253258
y_residuals_per_tree[tree_num_] -= y_predictions_per_tree[tree_num_2_]
254259
tree_.update_values(X, y_residuals_per_tree[tree_num_])
255-
y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[tree_num_], X)
260+
y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[
261+
tree_num_], X)
256262

257263
# set up initial potential_splits
258264
# everything in potential_splits either is_root (so it can be added directly to self.trees_)
@@ -267,13 +273,15 @@ def _update_tree_preds(n_iter):
267273
potential_splits.append(node_init_linear)
268274
for node in potential_splits:
269275
node.setattrs(is_root=True)
270-
potential_splits = sorted(potential_splits, key=lambda x: x.impurity_reduction)
276+
potential_splits = sorted(
277+
potential_splits, key=lambda x: x.impurity_reduction)
271278

272279
# start the greedy fitting algorithm
273280
finished = False
274281
while len(potential_splits) > 0 and not finished:
275282
# print('potential_splits', [str(s) for s in potential_splits])
276-
split_node = potential_splits.pop() # get node with max impurity_reduction (since it's sorted)
283+
# get node with max impurity_reduction (since it's sorted)
284+
split_node = potential_splits.pop()
277285

278286
# don't split on node
279287
if split_node.impurity_reduction < self.min_impurity_decrease:
@@ -304,16 +312,19 @@ def _update_tree_preds(n_iter):
304312
if split_node.split_or_linear == 'split':
305313
# assign left_temp, right_temp to be proper children
306314
# (basically adds them to tree in predict method)
307-
split_node.setattrs(left=split_node.left_temp, right=split_node.right_temp)
315+
split_node.setattrs(left=split_node.left_temp,
316+
right=split_node.right_temp)
308317

309318
# add children to potential_splits
310319
potential_splits.append(split_node.left)
311320
potential_splits.append(split_node.right)
312321

313322
# update predictions for altered tree
314323
for tree_num_ in range(len(self.trees_)):
315-
y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[tree_num_], X)
316-
y_predictions_per_tree[-1] = np.zeros(X.shape[0]) # dummy 0 preds for possible new trees
324+
y_predictions_per_tree[tree_num_] = self._predict_tree(self.trees_[
325+
tree_num_], X)
326+
# dummy 0 preds for possible new trees
327+
y_predictions_per_tree[-1] = np.zeros(X.shape[0])
317328

318329
# update residuals for each tree
319330
# -1 is key for potential new tree
@@ -352,7 +363,8 @@ def _update_tree_preds(n_iter):
352363
)
353364
elif potential_split.split_or_linear == 'linear':
354365
assert potential_split.is_root, 'Currently, linear node only supported as root'
355-
assert potential_split.idxs.sum() == X.shape[0], 'Currently, linear node only supported as root'
366+
assert potential_split.idxs.sum(
367+
) == X.shape[0], 'Currently, linear node only supported as root'
356368
potential_split_updated = self._construct_node_linear(idxs=potential_split.idxs,
357369
X=X,
358370
y=y_target,
@@ -371,7 +383,8 @@ def _update_tree_preds(n_iter):
371383
potential_splits_new.append(potential_split)
372384

373385
# sort so largest impurity reduction comes last (should probs make this a heap later)
374-
potential_splits = sorted(potential_splits_new, key=lambda x: x.impurity_reduction)
386+
potential_splits = sorted(
387+
potential_splits_new, key=lambda x: x.impurity_reduction)
375388
if verbose:
376389
print(self)
377390
if self.max_rules is not None and self.complexity_ >= self.max_rules:
@@ -383,9 +396,11 @@ def _update_tree_preds(n_iter):
383396
# potentially fit linear model on the tree preds
384397
if self.posthoc_ridge:
385398
if self.prediction_task == 'regression':
386-
self.weighted_model_ = RidgeCV(alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
399+
self.weighted_model_ = RidgeCV(
400+
alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
387401
elif self.prediction_task == 'classification':
388-
self.weighted_model_ = RidgeClassifierCV(alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
402+
self.weighted_model_ = RidgeClassifierCV(
403+
alphas=(0.01, 0.1, 0.5, 1.0, 5, 10))
389404
X_feats = self._extract_tree_predictions(X)
390405
self.weighted_model_.fit(X_feats, y)
391406
return self
@@ -402,7 +417,8 @@ def _tree_to_str(self, root: Node, prefix=''):
402417
pprefix)
403418

404419
def __str__(self):
405-
s = '------------\n' + '\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
420+
s = '------------\n' + \
421+
'\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
406422
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
407423
for i in range(len(self.feature_names_))[::-1]:
408424
s = s.replace(f'X_{i}', self.feature_names_[i])
@@ -425,14 +441,16 @@ def predict_proba(self, X):
425441
return NotImplemented
426442
elif self.posthoc_ridge and self.weighted_model_:  # note, during fitting don't use the weighted model
427443
X_feats = self._extract_tree_predictions(X)
428-
d = self.weighted_model_.decision_function(X_feats) # for 2 classes, this (n_samples,)
444+
d = self.weighted_model_.decision_function(
445+
X_feats)  # for 2 classes, this is (n_samples,)
429446
probs = np.exp(d) / (1 + np.exp(d))
430447
return np.vstack((1 - probs, probs)).transpose()
431448
else:
432449
preds = np.zeros(X.shape[0])
433450
for tree in self.trees_:
434451
preds += self._predict_tree(tree, X)
435-
preds = np.clip(preds, a_min=0., a_max=1.) # constrain to range of probabilities
452+
# constrain to range of probabilities
453+
preds = np.clip(preds, a_min=0., a_max=1.)
436454
return np.vstack((1 - preds, preds)).transpose()
437455

438456
def _extract_tree_predictions(self, X):
@@ -473,7 +491,7 @@ def _predict_tree_single_point(root: Node, x):
473491

474492
def plot(self, cols=2, feature_names=None, filename=None, label="all",
475493
impurity=False, tree_number=None, dpi=150, fig_size=None):
476-
is_single_tree = len(self.trees_) < 2 or tree_number is not None
494+
is_single_tree = len(self.trees_) < 2 or tree_number is not None
477495
n_cols = int(cols)
478496
n_rows = int(np.ceil(len(self.trees_) / n_cols))
479497
# if is_single_tree:
@@ -486,7 +504,7 @@ def plot(self, cols=2, feature_names=None, filename=None, label="all",
486504
fig.set_size_inches(fig_size, fig_size)
487505
criterion = "squared_error" if self.prediction_task == "regression" else "gini"
488506
n_classes = 1 if self.prediction_task == 'regression' else 2
489-
ax_size = int(len(self.trees_))#n_cols * n_rows
507+
ax_size = int(len(self.trees_)) # n_cols * n_rows
490508
for i in range(n_plots):
491509
r = i // n_cols
492510
c = i % n_cols
@@ -496,8 +514,10 @@ def plot(self, cols=2, feature_names=None, filename=None, label="all",
496514
else:
497515
ax = axs
498516
try:
499-
dt = extract_sklearn_tree_from_figs(self, i if tree_number is None else tree_number, n_classes)
500-
plot_tree(dt, ax=ax, feature_names=feature_names, label=label, impurity=impurity)
517+
dt = extract_sklearn_tree_from_figs(
518+
self, i if tree_number is None else tree_number, n_classes)
519+
plot_tree(dt, ax=ax, feature_names=feature_names,
520+
label=label, impurity=impurity)
501521
except IndexError:
502522
ax.axis('off')
503523
continue

0 commit comments

Comments
 (0)