2929from numpy import random
3030
3131# The maximum number of features for which we can display the full attribution.
32- CAN_DISPLAY_FULL_ATTR = 5
32+ MAX_ATTR_DISP = 5
3333
3434# The maximum number of features for which we can feasibly compute the exact Shapley values.
3535MAX_FEAS_EXACT_FEATS = 9
@@ -52,15 +52,19 @@ def __repr__(self) -> str:
5252 attr_str = ""
5353 coefs_str = ""
5454
55- if len (self .attribution ) <= CAN_DISPLAY_FULL_ATTR :
55+ if len (self .attribution ) <= MAX_ATTR_DISP :
5656 attr_str = "(" + "" .join (f"{ a :.2f} , " for a in self .attribution .flatten ())[:- 2 ] + ")"
5757 coefs_str = "(" + "" .join (f"{ c :.2f} , " for c in self .theta .flatten ())[:- 2 ] + ")"
5858 else :
5959 attr_str = (
60- "(" + "" .join (f"{ a :.2f} , " for a in self .attribution .flatten ()[:5 ])[:- 2 ] + ", ...)"
60+ "("
61+ + "" .join (f"{ a :.2f} , " for a in self .attribution .flatten ()[:MAX_ATTR_DISP ])[:- 2 ]
62+ + ", ...)"
6163 )
6264 coefs_str = (
63- "(" + "" .join (f"{ c :.2f} , " for c in self .theta .flatten ()[:5 ])[:- 2 ] + ", ...)"
65+ "("
66+ + "" .join (f"{ c :.2f} , " for c in self .theta .flatten ()[:MAX_ATTR_DISP ])[:- 2 ]
67+ + ", ...)"
6468 )
6569
6670 return f"""
@@ -83,6 +87,11 @@ def __init__(self, message: str) -> None:
8387 super ().__init__ (self .message )
8488
8589
90+ # TODO(ndevanathan): remove this in the next major update
91+ # This is here for backwards compatibility
92+ SizeIncompatible = SizeIncompatibleError
93+
94+
8695def validate_data (
8796 X_train : np .ndarray ,
8897 X_test : np .ndarray ,
@@ -410,7 +419,7 @@ def error_estimates(rng: random.Generator, cov: np.ndarray) -> tuple[np.ndarray,
410419 p = cov .shape [0 ]
411420 try :
412421 sample_diffs = rng .multivariate_normal (np .zeros (p ), cov , size = 2 ** 10 , method = "cholesky" )
413- except : # noqa: E722
422+ except ( np . linalg . LinAlgError , ValueError ):
414423 sample_diffs = rng .multivariate_normal (np .zeros (p ), cov , size = 2 ** 10 , method = "svd" )
415424 abs_diffs = np .abs (sample_diffs )
416425 norms = np .linalg .norm (sample_diffs , axis = 1 )
0 commit comments