1212
1313from imodels .tree .viz_utils import extract_sklearn_tree_from_figs
1414
15- plt .rcParams ['figure.dpi' ] = 300
16-
1715
1816class Node :
1917 def __init__ (self , feature : int = None , threshold : int = None ,
@@ -45,7 +43,8 @@ def __init__(self, feature: int = None, threshold: int = None,
4543 def update_values (self , X , y ):
4644 self .value = y .mean ()
4745 if self .threshold is not None :
48- right_indicator = np .apply_along_axis (lambda x : x [self .feature ] > self .threshold , 1 , X )
46+ right_indicator = np .apply_along_axis (
47+ lambda x : x [self .feature ] > self .threshold , 1 , X )
4948 X_right = X [right_indicator , :]
5049 X_left = X [~ right_indicator , :]
5150 y_right = y [right_indicator ]
@@ -61,9 +60,11 @@ def shrink(self, reg_param, cum_sum=0):
6160 if self .left is None : # if leaf node, change prediction
6261 self .value = cum_sum
6362 else :
64- shrunk_diff = (self .left .value - self .value ) / (1 + reg_param / self .n_samples )
63+ shrunk_diff = (self .left .value - self .value ) / \
64+ (1 + reg_param / self .n_samples )
6565 self .left .shrink (reg_param , cum_sum + shrunk_diff )
66- shrunk_diff = (self .right .value - self .value ) / (1 + reg_param / self .n_samples )
66+ shrunk_diff = (self .right .value - self .value ) / \
67+ (1 + reg_param / self .n_samples )
6768 self .right .shrink (reg_param , cum_sum + shrunk_diff )
6869
6970 def setattrs (self , ** kwargs ):
@@ -132,7 +133,7 @@ def _init_decision_function(self):
132133 """
133134 # used by sklearn GridSearchCV, BaggingClassifier
134135 if self .prediction_task == 'classification' :
135- decision_function = lambda x : self .predict_proba (x )[:, 1 ]
136+ def decision_function ( x ): return self .predict_proba (x )[:, 1 ]
136137 elif self .prediction_task == 'regression' :
137138 decision_function = self .predict
138139
@@ -166,7 +167,8 @@ def _construct_node_linear(self, X, y, idxs, tree_num=0, sample_weight=None):
166167 feature = None , threshold = None ,
167168 impurity_reduction = - 1 , split_or_linear = 'split' ) # leaf node that just returns its value
168169 else :
169- assert isinstance (best_linear_coef , float ), 'coef should be a float'
170+ assert isinstance (best_linear_coef ,
171+ float ), 'coef should be a float'
170172 return Node (idxs = idxs , value = best_linear_coef , tree_num = tree_num ,
171173 feature = best_feature , threshold = None ,
172174 impurity_reduction = impurity_reduction , split_or_linear = 'linear' )
@@ -178,7 +180,8 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, m
178180 RIGHT = 2
179181
180182 # fit stump
181- stump = tree .DecisionTreeRegressor (max_depth = 1 , max_features = max_features )
183+ stump = tree .DecisionTreeRegressor (
184+ max_depth = 1 , max_features = max_features )
182185 if sample_weight is not None :
183186 sample_weight = sample_weight [idxs ]
184187 stump .fit (X [idxs ], y [idxs ], sample_weight = sample_weight )
@@ -201,10 +204,10 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, m
201204
202205 # split node
203206 impurity_reduction = (
204- impurity [SPLIT ] -
205- impurity [LEFT ] * n_node_samples [LEFT ] / n_node_samples [SPLIT ] -
206- impurity [RIGHT ] * n_node_samples [RIGHT ] / n_node_samples [SPLIT ]
207- ) * idxs .sum ()
207+ impurity [SPLIT ] -
208+ impurity [LEFT ] * n_node_samples [LEFT ] / n_node_samples [SPLIT ] -
209+ impurity [RIGHT ] * n_node_samples [RIGHT ] / n_node_samples [SPLIT ]
210+ ) * idxs .sum ()
208211
209212 node_split = Node (idxs = idxs , value = value [SPLIT ], tree_num = tree_num ,
210213 feature = feature [SPLIT ], threshold = threshold [SPLIT ],
@@ -216,7 +219,8 @@ def _construct_node_with_stump(self, X, y, idxs, tree_num, sample_weight=None, m
216219 idxs_left = idxs_split & idxs
217220 idxs_right = ~ idxs_split & idxs
218221 node_left = Node (idxs = idxs_left , value = value [LEFT ], tree_num = tree_num )
219- node_right = Node (idxs = idxs_right , value = value [RIGHT ], tree_num = tree_num )
222+ node_right = Node (
223+ idxs = idxs_right , value = value [RIGHT ], tree_num = tree_num )
220224 node_split .setattrs (left_temp = node_left , right_temp = node_right , )
221225 return node_split
222226
@@ -231,7 +235,8 @@ def fit(self, X, y=None, feature_names=None, verbose=False, sample_weight=None):
231235 """
232236
233237 if self .prediction_task == 'classification' :
234- self .classes_ , y = np .unique (y , return_inverse = True ) # deals with str inputs
238+ self .classes_ , y = np .unique (
239+ y , return_inverse = True ) # deals with str inputs
235240 X , y = check_X_y (X , y )
236241 y = y .astype (float )
237242 if feature_names is not None :
@@ -252,7 +257,8 @@ def _update_tree_preds(n_iter):
252257 if not tree_num_2_ == tree_num_ :
253258 y_residuals_per_tree [tree_num_ ] -= y_predictions_per_tree [tree_num_2_ ]
254259 tree_ .update_values (X , y_residuals_per_tree [tree_num_ ])
255- y_predictions_per_tree [tree_num_ ] = self ._predict_tree (self .trees_ [tree_num_ ], X )
260+ y_predictions_per_tree [tree_num_ ] = self ._predict_tree (self .trees_ [
261+ tree_num_ ], X )
256262
257263 # set up initial potential_splits
258264 # everything in potential_splits either is_root (so it can be added directly to self.trees_)
@@ -267,13 +273,15 @@ def _update_tree_preds(n_iter):
267273 potential_splits .append (node_init_linear )
268274 for node in potential_splits :
269275 node .setattrs (is_root = True )
270- potential_splits = sorted (potential_splits , key = lambda x : x .impurity_reduction )
276+ potential_splits = sorted (
277+ potential_splits , key = lambda x : x .impurity_reduction )
271278
272279 # start the greedy fitting algorithm
273280 finished = False
274281 while len (potential_splits ) > 0 and not finished :
275282 # print('potential_splits', [str(s) for s in potential_splits])
276- split_node = potential_splits .pop () # get node with max impurity_reduction (since it's sorted)
283+ # get node with max impurity_reduction (since it's sorted)
284+ split_node = potential_splits .pop ()
277285
278286 # don't split on node
279287 if split_node .impurity_reduction < self .min_impurity_decrease :
@@ -304,16 +312,19 @@ def _update_tree_preds(n_iter):
304312 if split_node .split_or_linear == 'split' :
305313 # assign left_temp, right_temp to be proper children
306314 # (basically adds them to tree in predict method)
307- split_node .setattrs (left = split_node .left_temp , right = split_node .right_temp )
315+ split_node .setattrs (left = split_node .left_temp ,
316+ right = split_node .right_temp )
308317
309318 # add children to potential_splits
310319 potential_splits .append (split_node .left )
311320 potential_splits .append (split_node .right )
312321
313322 # update predictions for altered tree
314323 for tree_num_ in range (len (self .trees_ )):
315- y_predictions_per_tree [tree_num_ ] = self ._predict_tree (self .trees_ [tree_num_ ], X )
316- y_predictions_per_tree [- 1 ] = np .zeros (X .shape [0 ]) # dummy 0 preds for possible new trees
324+ y_predictions_per_tree [tree_num_ ] = self ._predict_tree (self .trees_ [
325+ tree_num_ ], X )
326+ # dummy 0 preds for possible new trees
327+ y_predictions_per_tree [- 1 ] = np .zeros (X .shape [0 ])
317328
318329 # update residuals for each tree
319330 # -1 is key for potential new tree
@@ -352,7 +363,8 @@ def _update_tree_preds(n_iter):
352363 )
353364 elif potential_split .split_or_linear == 'linear' :
354365 assert potential_split .is_root , 'Currently, linear node only supported as root'
355- assert potential_split .idxs .sum () == X .shape [0 ], 'Currently, linear node only supported as root'
366+ assert potential_split .idxs .sum (
367+ ) == X .shape [0 ], 'Currently, linear node only supported as root'
356368 potential_split_updated = self ._construct_node_linear (idxs = potential_split .idxs ,
357369 X = X ,
358370 y = y_target ,
@@ -371,7 +383,8 @@ def _update_tree_preds(n_iter):
371383 potential_splits_new .append (potential_split )
372384
 373385 # sort so largest impurity reduction comes last (should probably make this a heap later)
374- potential_splits = sorted (potential_splits_new , key = lambda x : x .impurity_reduction )
386+ potential_splits = sorted (
387+ potential_splits_new , key = lambda x : x .impurity_reduction )
375388 if verbose :
376389 print (self )
377390 if self .max_rules is not None and self .complexity_ >= self .max_rules :
@@ -383,9 +396,11 @@ def _update_tree_preds(n_iter):
383396 # potentially fit linear model on the tree preds
384397 if self .posthoc_ridge :
385398 if self .prediction_task == 'regression' :
386- self .weighted_model_ = RidgeCV (alphas = (0.01 , 0.1 , 0.5 , 1.0 , 5 , 10 ))
399+ self .weighted_model_ = RidgeCV (
400+ alphas = (0.01 , 0.1 , 0.5 , 1.0 , 5 , 10 ))
387401 elif self .prediction_task == 'classification' :
388- self .weighted_model_ = RidgeClassifierCV (alphas = (0.01 , 0.1 , 0.5 , 1.0 , 5 , 10 ))
402+ self .weighted_model_ = RidgeClassifierCV (
403+ alphas = (0.01 , 0.1 , 0.5 , 1.0 , 5 , 10 ))
389404 X_feats = self ._extract_tree_predictions (X )
390405 self .weighted_model_ .fit (X_feats , y )
391406 return self
@@ -402,7 +417,8 @@ def _tree_to_str(self, root: Node, prefix=''):
402417 pprefix )
403418
404419 def __str__ (self ):
405- s = '------------\n ' + '\n \t +\n ' .join ([self ._tree_to_str (t ) for t in self .trees_ ])
420+ s = '------------\n ' + \
421+ '\n \t +\n ' .join ([self ._tree_to_str (t ) for t in self .trees_ ])
406422 if hasattr (self , 'feature_names_' ) and self .feature_names_ is not None :
407423 for i in range (len (self .feature_names_ ))[::- 1 ]:
408424 s = s .replace (f'X_{ i } ' , self .feature_names_ [i ])
@@ -425,14 +441,16 @@ def predict_proba(self, X):
425441 return NotImplemented
 426442 elif self .posthoc_ridge and self .weighted_model_ : # note, during fitting don't use the weighted model
427443 X_feats = self ._extract_tree_predictions (X )
428- d = self .weighted_model_ .decision_function (X_feats ) # for 2 classes, this (n_samples,)
444+ d = self .weighted_model_ .decision_function (
445+ X_feats ) # for 2 classes, this (n_samples,)
429446 probs = np .exp (d ) / (1 + np .exp (d ))
430447 return np .vstack ((1 - probs , probs )).transpose ()
431448 else :
432449 preds = np .zeros (X .shape [0 ])
433450 for tree in self .trees_ :
434451 preds += self ._predict_tree (tree , X )
435- preds = np .clip (preds , a_min = 0. , a_max = 1. ) # constrain to range of probabilities
452+ # constrain to range of probabilities
453+ preds = np .clip (preds , a_min = 0. , a_max = 1. )
436454 return np .vstack ((1 - preds , preds )).transpose ()
437455
438456 def _extract_tree_predictions (self , X ):
@@ -473,7 +491,7 @@ def _predict_tree_single_point(root: Node, x):
473491
474492 def plot (self , cols = 2 , feature_names = None , filename = None , label = "all" ,
475493 impurity = False , tree_number = None , dpi = 150 , fig_size = None ):
476- is_single_tree = len (self .trees_ ) < 2 or tree_number is not None
494+ is_single_tree = len (self .trees_ ) < 2 or tree_number is not None
477495 n_cols = int (cols )
478496 n_rows = int (np .ceil (len (self .trees_ ) / n_cols ))
479497 # if is_single_tree:
@@ -486,7 +504,7 @@ def plot(self, cols=2, feature_names=None, filename=None, label="all",
486504 fig .set_size_inches (fig_size , fig_size )
487505 criterion = "squared_error" if self .prediction_task == "regression" else "gini"
488506 n_classes = 1 if self .prediction_task == 'regression' else 2
489- ax_size = int (len (self .trees_ ))# n_cols * n_rows
507+ ax_size = int (len (self .trees_ )) # n_cols * n_rows
490508 for i in range (n_plots ):
491509 r = i // n_cols
492510 c = i % n_cols
@@ -496,8 +514,10 @@ def plot(self, cols=2, feature_names=None, filename=None, label="all",
496514 else :
497515 ax = axs
498516 try :
499- dt = extract_sklearn_tree_from_figs (self , i if tree_number is None else tree_number , n_classes )
500- plot_tree (dt , ax = ax , feature_names = feature_names , label = label , impurity = impurity )
517+ dt = extract_sklearn_tree_from_figs (
518+ self , i if tree_number is None else tree_number , n_classes )
519+ plot_tree (dt , ax = ax , feature_names = feature_names ,
520+ label = label , impurity = impurity )
501521 except IndexError :
502522 ax .axis ('off' )
503523 continue
0 commit comments