Skip to content

Commit f221ba5

Browse files
authored
Merge pull request #184 from csinva/revert-183-master
Revert "Update hierarchical_shrinkage, fix bugs, change attribute name"
2 parents 7ab6510 + d96c3f2 commit f221ba5

File tree

8 files changed

+267
-280
lines changed

8 files changed

+267
-280
lines changed

imodels/experimental/figs_ensembles.py

Lines changed: 18 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
from matplotlib import pyplot as plt
5-
import sklearn
65
from sklearn import datasets
76
from sklearn import tree
87
from sklearn.base import BaseEstimator
@@ -73,22 +72,18 @@ def setattrs(self, **kwargs):
7372
setattr(self, k, v)
7473

7574
def __str__(self):
76-
try:
77-
sklearn.utils.validation.check_is_fitted(self)
78-
if self.split_or_linear == 'linear':
79-
if self.is_root:
80-
return f'X_{self.feature} * {self.value:0.3f} (Tree #{self.tree_num} linear root)'
81-
else:
82-
return f'X_{self.feature} * {self.value:0.3f} (linear)'
75+
if self.split_or_linear == 'linear':
76+
if self.is_root:
77+
return f'X_{self.feature} * {self.value:0.3f} (Tree #{self.tree_num} linear root)'
8378
else:
84-
if self.is_root:
85-
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
86-
elif self.left is None and self.right is None:
87-
return f'Val: {self.value[0][0]:0.3f} (leaf)'
88-
else:
89-
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'
90-
except ValueError:
91-
return self.__class__.__name__
79+
return f'X_{self.feature} * {self.value:0.3f} (linear)'
80+
else:
81+
if self.is_root:
82+
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
83+
elif self.left is None and self.right is None:
84+
return f'Val: {self.value[0][0]:0.3f} (leaf)'
85+
else:
86+
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'
9287

9388
def __repr__(self):
9489
return self.__str__()
@@ -422,17 +417,13 @@ def _tree_to_str(self, root: Node, prefix=''):
422417
pprefix)
423418

424419
def __str__(self):
425-
try:
426-
sklearn.utils.validation.check_is_fitted(self)
427-
s = '------------\n' + \
428-
'\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
429-
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
430-
for i in range(len(self.feature_names_))[::-1]:
431-
s = s.replace(f'X_{i}', self.feature_names_[i])
432-
return s
433-
except ValueError:
434-
return self.__class__.__name__
435-
420+
s = '------------\n' + \
421+
'\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
422+
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
423+
for i in range(len(self.feature_names_))[::-1]:
424+
s = s.replace(f'X_{i}', self.feature_names_[i])
425+
return s
426+
436427
def predict(self, X):
437428
if self.posthoc_ridge and self.weighted_model_: # note, during fitting don't use the weighted model
438429
X_feats = self._extract_tree_predictions(X)

imodels/rule_list/corels_wrapper.py

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import pandas as pd
7-
import sklearn
87
from sklearn.preprocessing import KBinsDiscretizer
98

109
from imodels.rule_list.greedy_rule_list import GreedyRuleListClassifier
@@ -234,18 +233,14 @@ def _traverse_rule(self, X: np.ndarray, y: np.ndarray, feature_names: List[str],
234233
self.str_print = str_print
235234

236235
def __str__(self):
237-
try:
238-
sklearn.utils.validation.check_is_fitted(self)
239-
if corels_supported:
240-
if self.str_print is not None:
241-
return 'OptimalRuleList:\n\n' + self.str_print
242-
else:
243-
return 'OptimalRuleList:\n\n' + self.rl_.__str__()
236+
if corels_supported:
237+
if self.str_print is not None:
238+
return 'OptimalRuleList:\n\n' + self.str_print
244239
else:
245-
return super().__str__()
246-
except ValueError:
247-
return self.__class__.__name__
248-
240+
return 'OptimalRuleList:\n\n' + self.rl_.__str__()
241+
else:
242+
return super().__str__()
243+
249244
def _get_complexity(self):
250245
return sum([len(corule['antecedents']) for corule in self.rl_.rules])
251246

imodels/rule_list/greedy_rule_list.py

Lines changed: 37 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
from copy import deepcopy
99

1010
import numpy as np
11-
import sklearn
1211
from sklearn.base import BaseEstimator, ClassifierMixin
1312
from sklearn.utils.multiclass import unique_labels
1413
from sklearn.utils.validation import check_array, check_is_fitted
@@ -141,43 +140,48 @@ def predict(self, X):
141140
X = check_array(X)
142141
return np.argmax(self.predict_proba(X), axis=1)
143142

143+
"""
144+
def __str__(self):
145+
# s = ''
146+
# for rule in self.rules_:
147+
# s += f"mean {rule['val'].round(3)} ({rule['num_pts']} pts)\n"
148+
# if 'col' in rule:
149+
# s += f"if {rule['col']} >= {rule['cutoff']} then {rule['val_right'].round(3)} ({rule['num_pts_right']} pts)\n"
150+
# return s
151+
"""
144152

145153
def __str__(self):
146154
'''Print out the list in a nice way
147155
'''
148-
try:
149-
sklearn.utils.validation.check_is_fitted(self)
150-
s = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'
151-
152-
def red(s):
153-
# return f"\033[91m{s}\033[00m"
154-
return s
155-
156-
def cyan(s):
157-
# return f"\033[96m{s}\033[00m"
158-
return s
159-
160-
def rule_name(rule):
161-
if rule['flip']:
162-
return '~' + rule['col']
163-
return rule['col']
164-
165-
# rule = self.rules_[0]
166-
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
167-
for rule in self.rules_:
168-
s += u'\u2193\n' + f"{cyan((100 * rule['val']).round(2))}% risk ({rule['num_pts']} pts)\n"
169-
# s += f"\t{'Else':>45} => {cyan((100 * rule['val']).round(2)):>6}% IwI ({rule['val'] * rule['num_pts']:.0f}/{rule['num_pts']} pts)\n"
170-
if 'col' in rule:
171-
# prefix = f"if {rule['col']} >= {rule['cutoff']}"
172-
prefix = f"if {rule_name(rule)}"
173-
val = f"{100 * rule['val_right'].round(3)}"
174-
s += f"\t{prefix} ==> {red(val)}% risk ({rule['num_pts_right']} pts)\n"
175-
# rule = self.rules_[-1]
176-
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
156+
s = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'
157+
158+
def red(s):
159+
# return f"\033[91m{s}\033[00m"
160+
return s
161+
162+
def cyan(s):
163+
# return f"\033[96m{s}\033[00m"
177164
return s
178-
except ValueError:
179-
return self.__class__.__name__
180-
165+
166+
def rule_name(rule):
167+
if rule['flip']:
168+
return '~' + rule['col']
169+
return rule['col']
170+
171+
# rule = self.rules_[0]
172+
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
173+
for rule in self.rules_:
174+
s += u'\u2193\n' + f"{cyan((100 * rule['val']).round(2))}% risk ({rule['num_pts']} pts)\n"
175+
# s += f"\t{'Else':>45} => {cyan((100 * rule['val']).round(2)):>6}% IwI ({rule['val'] * rule['num_pts']:.0f}/{rule['num_pts']} pts)\n"
176+
if 'col' in rule:
177+
# prefix = f"if {rule['col']} >= {rule['cutoff']}"
178+
prefix = f"if {rule_name(rule)}"
179+
val = f"{100 * rule['val_right'].round(3)}"
180+
s += f"\t{prefix} ==> {red(val)}% risk ({rule['num_pts_right']} pts)\n"
181+
# rule = self.rules_[-1]
182+
# s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
183+
return s
184+
181185
######## HERE ONWARDS CUSTOM SPLITTING (DEPRECATED IN FAVOR OF SKLEARN STUMP) ########
182186
######################################################################################
183187
def _find_best_split(self, x, y):

imodels/rule_set/brs.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
from numpy.random import random
1919
from pandas import read_csv
2020
from scipy.sparse import csc_matrix
21-
import sklearn
2221
from sklearn.base import BaseEstimator, ClassifierMixin
2322
from sklearn.ensemble import RandomForestClassifier
2423
from sklearn.utils.multiclass import check_classification_targets
@@ -193,12 +192,8 @@ def fit(self, X, y, feature_names: list = None, init=[], verbose=False):
193192
return self
194193

195194
def __str__(self):
196-
try:
197-
sklearn.utils.validation.check_is_fitted(self)
198-
return ' '.join(str(r) for r in self.rules_)
199-
except ValueError:
200-
return self.__class__.__name__
201-
195+
return ' '.join(str(r) for r in self.rules_)
196+
202197
def predict(self, X):
203198
check_is_fitted(self)
204199
if isinstance(X, np.ndarray):

imodels/rule_set/rule_fit.py

Lines changed: 6 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
import pandas as pd
1414
import scipy
1515
from scipy.special import softmax
16-
import sklearn
1716
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
1817
from sklearn.base import TransformerMixin
1918
from sklearn.utils.multiclass import unique_labels
@@ -243,16 +242,12 @@ def visualize(self, decimals=2):
243242
return rules[['rule', 'coef']].round(decimals)
244243

245244
def __str__(self):
246-
try:
247-
sklearn.utils.validation.check_is_fitted(self)
248-
s = '> ------------------------------\n'
249-
s += '> RuleFit:\n'
250-
s += '> \tPredictions are made by summing the coefficients of each rule\n'
251-
s += '> ------------------------------\n'
252-
return s + self.visualize().to_string(index=False) + '\n'
253-
except ValueError:
254-
return self.__class__.__name__
255-
245+
s = '> ------------------------------\n'
246+
s += '> RuleFit:\n'
247+
s += '> \tPredictions are made by summing the coefficients of each rule\n'
248+
s += '> ------------------------------\n'
249+
return s + self.visualize().to_string(index=False) + '\n'
250+
256251
def _extract_rules(self, X, y) -> List[str]:
257252
return extract_rulefit(X, y,
258253
feature_names=self.feature_placeholders,

imodels/tree/cart_wrapper.py

Lines changed: 13 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
# This is just a simple wrapper around sklearn decisiontree
22
# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
33

4-
import sklearn
54
from sklearn.tree import DecisionTreeClassifier, export_text, DecisionTreeRegressor
65
from imodels.util.arguments import check_fit_arguments
76

@@ -49,18 +48,15 @@ def _set_complexity(self):
4948
self.complexity_ = compute_tree_complexity(self.tree_)
5049

5150
def __str__(self):
52-
try:
53-
sklearn.utils.validation.check_is_fitted(self)
54-
s = '> ------------------------------\n'
55-
s += '> Greedy CART Tree:\n'
56-
s += '> \tPrediction is made by looking at the value in the appropriate leaf of the tree\n'
57-
s += '> ------------------------------' + '\n'
58-
if hasattr(self, 'feature_names') and self.feature_names is not None:
59-
return s + export_text(self, feature_names=self.feature_names, show_weights=True)
60-
else:
61-
return s + export_text(self, show_weights=True)
62-
except ValueError:
63-
return self.__class__.__name__
51+
s = '> ------------------------------\n'
52+
s += '> Greedy CART Tree:\n'
53+
s += '> \tPrediction is made by looking at the value in the appropriate leaf of the tree\n'
54+
s += '> ------------------------------' + '\n'
55+
if hasattr(self, 'feature_names') and self.feature_names is not None:
56+
return s + export_text(self, feature_names=self.feature_names, show_weights=True)
57+
else:
58+
return s + export_text(self, show_weights=True)
59+
6460

6561
class GreedyTreeRegressor(DecisionTreeRegressor):
6662
"""Wrapper around sklearn greedy tree regressor
@@ -102,11 +98,7 @@ def _set_complexity(self):
10298
self.complexity_ = compute_tree_complexity(self.tree_)
10399

104100
def __str__(self):
105-
try:
106-
sklearn.utils.validation.check_is_fitted(self)
107-
if hasattr(self, 'feature_names') and self.feature_names is not None:
108-
return 'GreedyTree:\n' + export_text(self, feature_names=self.feature_names, show_weights=True)
109-
else:
110-
return 'GreedyTree:\n' + export_text(self, show_weights=True)
111-
except ValueError:
112-
return self.__class__.__name__
101+
if hasattr(self, 'feature_names') and self.feature_names is not None:
102+
return 'GreedyTree:\n' + export_text(self, feature_names=self.feature_names, show_weights=True)
103+
else:
104+
return 'GreedyTree:\n' + export_text(self, show_weights=True)

imodels/tree/figs.py

Lines changed: 20 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import pandas as pd
77
from scipy.special import expit
8-
import sklearn
98
from sklearn import datasets
109
from sklearn import tree
1110
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
@@ -52,17 +51,13 @@ def setattrs(self, **kwargs):
5251
setattr(self, k, v)
5352

5453
def __str__(self):
55-
try:
56-
sklearn.utils.validation.check_is_fitted(self)
57-
if self.is_root:
58-
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
59-
elif self.left is None and self.right is None:
60-
return f'Val: {self.value[0][0]:0.3f} (leaf)'
61-
else:
62-
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'
63-
except ValueError:
64-
return self.__class__.__name__
65-
54+
if self.is_root:
55+
return f'X_{self.feature} <= {self.threshold:0.3f} (Tree #{self.tree_num} root)'
56+
elif self.left is None and self.right is None:
57+
return f'Val: {self.value[0][0]:0.3f} (leaf)'
58+
else:
59+
return f'X_{self.feature} <= {self.threshold:0.3f} (split)'
60+
6661
def print_root(self, y):
6762
try:
6863
one_count = pd.Series(y).value_counts()[1.0]
@@ -77,6 +72,8 @@ def print_root(self, y):
7772
else:
7873
return f'X_{self.feature} <= {self.threshold:0.3f}' + one_proportion
7974

75+
def __repr__(self):
76+
return self.__str__()
8077

8178

8279
class FIGS(BaseEstimator):
@@ -414,21 +411,17 @@ def _tree_to_str_with_data(self, X, y, root: Node, prefix=''):
414411
self._tree_to_str_with_data(X[~left], y[~left], root.right, pprefix))
415412

416413
def __str__(self):
417-
try:
418-
sklearn.utils.validation.check_is_fitted(self)
419-
s = '> ------------------------------\n'
420-
s += '> FIGS-Fast Interpretable Greedy-Tree Sums:\n'
421-
s += '> \tPredictions are made by summing the "Val" reached by traversing each tree.\n'
422-
s += '> \tFor classifiers, a sigmoid function is then applied to the sum.\n'
423-
s += '> ------------------------------\n'
424-
s += '\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
425-
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
426-
for i in range(len(self.feature_names_))[::-1]:
427-
s = s.replace(f'X_{i}', self.feature_names_[i])
428-
return s
429-
except ValueError:
430-
return self.__class__.__name__
431-
414+
s = '> ------------------------------\n'
415+
s += '> FIGS-Fast Interpretable Greedy-Tree Sums:\n'
416+
s += '> \tPredictions are made by summing the "Val" reached by traversing each tree.\n'
417+
s += '> \tFor classifiers, a sigmoid function is then applied to the sum.\n'
418+
s += '> ------------------------------\n'
419+
s += '\n\t+\n'.join([self._tree_to_str(t) for t in self.trees_])
420+
if hasattr(self, 'feature_names_') and self.feature_names_ is not None:
421+
for i in range(len(self.feature_names_))[::-1]:
422+
s = s.replace(f'X_{i}', self.feature_names_[i])
423+
return s
424+
432425
def print_tree(self, X, y, feature_names=None):
433426
s = '------------\n' + \
434427
'\n\t+\n'.join([self._tree_to_str_with_data(X, y, t)

0 commit comments

Comments
 (0)