Skip to content

Commit e58aace

Browse files
Fix segfault
1 parent 1b39337 commit e58aace

File tree

1 file changed

+17
-29
lines changed

1 file changed

+17
-29
lines changed

bluecast/evaluation/shap_values.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,32 @@
1717
def _is_shap_tree_compatible(model) -> bool:
1818
"""Check if the model is compatible with shap.TreeExplainer.
1919
20-
Known incompatibility: shap's TreeExplainer can segfault on xgb.Booster
21-
objects produced by the native xgb.train() API (as opposed to the sklearn
22-
wrapper). Since a C-level segfault cannot be caught by Python's try/except,
23-
we must detect and skip this case proactively.
20+
TreeExplainer can segfault on certain model types (xgb.Booster from native
21+
xgb.train() API, and potentially other backends). Since a C-level segfault
22+
cannot be caught by Python, we only allow TreeExplainer for CatBoost models
23+
where it is known to be stable. All other models use KernelExplainer.
2424
"""
25+
try:
26+
import catboost
27+
28+
if isinstance(model, catboost.CatBoost):
29+
return True
30+
except ImportError:
31+
pass
32+
2533
try:
2634
import xgboost
2735

28-
# xgb.Booster from xgb.train() is known to segfault in TreeExplainer.
29-
# The sklearn wrapper (XGBClassifier/XGBRegressor) is safe because
30-
# it has get_booster() and shap handles it differently.
36+
# xgb.Booster from xgb.train() segfaults in TreeExplainer
3137
if isinstance(model, xgboost.Booster):
32-
logger.warning(
33-
"Skipping TreeExplainer for xgb.Booster (native API). "
34-
"TreeExplainer can segfault on Booster objects. "
35-
"Falling back to KernelExplainer."
36-
)
3738
return False
38-
39-
xgb_version = tuple(int(x) for x in xgboost.__version__.split(".")[:2])
40-
shap_version = tuple(int(x) for x in shap.__version__.split(".")[:2])
41-
42-
if xgb_version >= (2, 1) and shap_version < (0, 47):
43-
logger.warning(
44-
f"Skipping TreeExplainer: shap {shap.__version__} is known to "
45-
f"segfault with xgboost {xgboost.__version__}. "
46-
f"Upgrade shap to >= 0.47 or use KernelExplainer."
47-
)
39+
# XGBClassifier/XGBRegressor (sklearn API) — also skip to be safe
40+
if isinstance(model, xgboost.XGBModel):
4841
return False
49-
50-
if hasattr(model, "get_booster"):
51-
return True
52-
if hasattr(model, "get_all_params"):
53-
return True
54-
except Exception:
42+
except ImportError:
5543
pass
5644

57-
return True
45+
return False
5846

5947

6048
def shap_explanations(model, df: pd.DataFrame) -> Tuple[np.ndarray, shap.Explainer]:

0 commit comments

Comments
 (0)