55import numpy as np
66import pandas as pd
77from scipy .special import expit
8- import sklearn
98from sklearn import datasets
109from sklearn import tree
1110from sklearn .base import BaseEstimator , ClassifierMixin , RegressorMixin
@@ -52,17 +51,13 @@ def setattrs(self, **kwargs):
5251 setattr (self , k , v )
5352
5453 def __str__ (self ):
55- try :
56- sklearn .utils .validation .check_is_fitted (self )
57- if self .is_root :
58- return f'X_{ self .feature } <= { self .threshold :0.3f} (Tree #{ self .tree_num } root)'
59- elif self .left is None and self .right is None :
60- return f'Val: { self .value [0 ][0 ]:0.3f} (leaf)'
61- else :
62- return f'X_{ self .feature } <= { self .threshold :0.3f} (split)'
63- except ValueError :
64- return self .__class__ .__name__
65-
54+ if self .is_root :
55+ return f'X_{ self .feature } <= { self .threshold :0.3f} (Tree #{ self .tree_num } root)'
56+ elif self .left is None and self .right is None :
57+ return f'Val: { self .value [0 ][0 ]:0.3f} (leaf)'
58+ else :
59+ return f'X_{ self .feature } <= { self .threshold :0.3f} (split)'
60+
6661 def print_root (self , y ):
6762 try :
6863 one_count = pd .Series (y ).value_counts ()[1.0 ]
@@ -77,6 +72,8 @@ def print_root(self, y):
7772 else :
7873 return f'X_{ self .feature } <= { self .threshold :0.3f} ' + one_proportion
7974
75+ def __repr__ (self ):
76+ return self .__str__ ()
8077
8178
8279class FIGS (BaseEstimator ):
@@ -414,21 +411,17 @@ def _tree_to_str_with_data(self, X, y, root: Node, prefix=''):
414411 self ._tree_to_str_with_data (X [~ left ], y [~ left ], root .right , pprefix ))
415412
416413 def __str__ (self ):
417- try :
418- sklearn .utils .validation .check_is_fitted (self )
419- s = '> ------------------------------\n '
420- s += '> FIGS-Fast Interpretable Greedy-Tree Sums:\n '
421- s += '> \t Predictions are made by summing the "Val" reached by traversing each tree.\n '
422- s += '> \t For classifiers, a sigmoid function is then applied to the sum.\n '
423- s += '> ------------------------------\n '
424- s += '\n \t +\n ' .join ([self ._tree_to_str (t ) for t in self .trees_ ])
425- if hasattr (self , 'feature_names_' ) and self .feature_names_ is not None :
426- for i in range (len (self .feature_names_ ))[::- 1 ]:
427- s = s .replace (f'X_{ i } ' , self .feature_names_ [i ])
428- return s
429- except ValueError :
430- return self .__class__ .__name__
431-
414+ s = '> ------------------------------\n '
415+ s += '> FIGS-Fast Interpretable Greedy-Tree Sums:\n '
416+ s += '> \t Predictions are made by summing the "Val" reached by traversing each tree.\n '
417+ s += '> \t For classifiers, a sigmoid function is then applied to the sum.\n '
418+ s += '> ------------------------------\n '
419+ s += '\n \t +\n ' .join ([self ._tree_to_str (t ) for t in self .trees_ ])
420+ if hasattr (self , 'feature_names_' ) and self .feature_names_ is not None :
421+ for i in range (len (self .feature_names_ ))[::- 1 ]:
422+ s = s .replace (f'X_{ i } ' , self .feature_names_ [i ])
423+ return s
424+
432425 def print_tree (self , X , y , feature_names = None ):
433426 s = '------------\n ' + \
434427 '\n \t +\n ' .join ([self ._tree_to_str_with_data (X , y , t )
0 commit comments