Skip to content
7 changes: 6 additions & 1 deletion octopus/models/classification_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ def catboost_classifier() -> ModelConfig:
"""CatBoost classification model config."""
return ModelConfig(
model_class=CatBoostClassifier,
ml_types=[MLType.BINARY, MLType.MULTICLASS],
# KNOWN ISSUE: CatBoost multiclass is disabled because shap.Explainer(model, bg)
# segfaults in SHAP <=0.51 when using TreeExplainer's interventional mode with
# CatBoost multiclass models. Re-enable once SHAP fixes this upstream.
# See datasets_local/specifications_refactorfi/03_shap_catboost_segfault_proposal.md
# Original: ml_types=[MLType.BINARY, MLType.MULTICLASS],
ml_types=[MLType.BINARY],
Comment on lines +155 to +160
Comment on lines +155 to +160
feature_method="internal",
chpo_compatible=True,
scaler=None,
Expand Down
4 changes: 2 additions & 2 deletions octopus/modules/octo/bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def fit(self):
elif self._fi_type == "shap":
self._training.calculate_fi_shap(partition=self._partition)
elif self._fi_type == "permutation":
self._training.calculate_fi_group_permutation(partition=self._partition)
self._training.calculate_fi_permutation(partition=self._partition)
elif self._fi_type == "lofo":
self._training.calculate_fi_lofo()
elif self._fi_type == "constant":
Expand Down Expand Up @@ -576,7 +576,7 @@ def _calculate_fi_sequential(self, fi_type="internal", partition="dev"):
elif fi_type == "shap":
training.calculate_fi_shap(partition=partition)
elif fi_type == "permutation":
training.calculate_fi_group_permutation(partition=partition)
training.calculate_fi_permutation(partition=partition)
elif fi_type == "lofo":
training.calculate_fi_lofo()
elif fi_type == "constant":
Expand Down
397 changes: 110 additions & 287 deletions octopus/modules/octo/training.py

Large diffs are not rendered by default.

Loading
Loading