|
17 | 17 | def _is_shap_tree_compatible(model) -> bool: |
18 | 18 | """Check if the model is compatible with shap.TreeExplainer. |
19 | 19 |
|
20 | | - Known incompatibility: shap's TreeExplainer can segfault on xgb.Booster |
21 | | - objects produced by the native xgb.train() API (as opposed to the sklearn |
22 | | - wrapper). Since a C-level segfault cannot be caught by Python's try/except, |
23 | | - we must detect and skip this case proactively. |
| 20 | + TreeExplainer can segfault on certain model types (xgb.Booster from native |
| 21 | + xgb.train() API, and potentially other backends). Since a C-level segfault |
| 22 | + cannot be caught by Python, we only allow TreeExplainer for CatBoost models |
| 23 | + where it is known to be stable. All other models use KernelExplainer. |
24 | 24 | """ |
| 25 | + try: |
| 26 | + import catboost |
| 27 | + |
| 28 | + if isinstance(model, catboost.CatBoost): |
| 29 | + return True |
| 30 | + except ImportError: |
| 31 | + pass |
| 32 | + |
25 | 33 | try: |
26 | 34 | import xgboost |
27 | 35 |
|
28 | | - # xgb.Booster from xgb.train() is known to segfault in TreeExplainer. |
29 | | - # The sklearn wrapper (XGBClassifier/XGBRegressor) is safe because |
30 | | - # it has get_booster() and shap handles it differently. |
| 36 | + # xgb.Booster from xgb.train() segfaults in TreeExplainer |
31 | 37 | if isinstance(model, xgboost.Booster): |
32 | | - logger.warning( |
33 | | - "Skipping TreeExplainer for xgb.Booster (native API). " |
34 | | - "TreeExplainer can segfault on Booster objects. " |
35 | | - "Falling back to KernelExplainer." |
36 | | - ) |
37 | 38 | return False |
38 | | - |
39 | | - xgb_version = tuple(int(x) for x in xgboost.__version__.split(".")[:2]) |
40 | | - shap_version = tuple(int(x) for x in shap.__version__.split(".")[:2]) |
41 | | - |
42 | | - if xgb_version >= (2, 1) and shap_version < (0, 47): |
43 | | - logger.warning( |
44 | | - f"Skipping TreeExplainer: shap {shap.__version__} is known to " |
45 | | - f"segfault with xgboost {xgboost.__version__}. " |
46 | | - f"Upgrade shap to >= 0.47 or use KernelExplainer." |
47 | | - ) |
| 39 | + # XGBClassifier/XGBRegressor (sklearn API) — also skip to be safe |
| 40 | + if isinstance(model, xgboost.XGBModel): |
48 | 41 | return False |
49 | | - |
50 | | - if hasattr(model, "get_booster"): |
51 | | - return True |
52 | | - if hasattr(model, "get_all_params"): |
53 | | - return True |
54 | | - except Exception: |
| 42 | + except ImportError: |
55 | 43 | pass |
56 | 44 |
|
57 | | - return True |
| 45 | + return False |
58 | 46 |
|
59 | 47 |
|
60 | 48 | def shap_explanations(model, df: pd.DataFrame) -> Tuple[np.ndarray, shap.Explainer]: |
|
0 commit comments