Skip to content

Commit 526a140

Browse files
committed
ml_type fix
1 parent add5291 commit 526a140

File tree

3 files changed

+6
-0
lines changed

3 files changed

+6
-0
lines changed

octopus/predict/feature_importance.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ def calculate_fi_shap(
561561
shap_type: str = "kernel",
562562
max_samples: int = 100,
563563
background_size: int = 200,
564+
ml_type: MLType | None = None,
564565
) -> pd.DataFrame:
565566
"""Calculate SHAP feature importance across all outer splits.
566567
@@ -575,6 +576,8 @@ def calculate_fi_shap(
575576
``'exact'``).
576577
max_samples: Maximum number of evaluation samples per split.
577578
background_size: Maximum background dataset size for kernel explainer.
579+
ml_type: ML task type. Passed through to ``compute_shap_single``
580+
to correctly choose ``predict`` vs ``predict_proba``.
578581
579582
Returns:
580583
DataFrame with columns: fi_source, feature, importance_mean,
@@ -606,6 +609,7 @@ def calculate_fi_shap(
606609
X_background=X_background,
607610
max_samples=max_samples,
608611
threshold_ratio=None, # Keep all for predict
612+
ml_type=ml_type,
609613
)
610614

611615
for _, row in fi_df.iterrows():

octopus/predict/task_predictor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def calculate_fi(
438438
models=self._models,
439439
selected_features=self._selected_features,
440440
test_data=test_data,
441+
ml_type=self.ml_type,
441442
**kwargs,
442443
)
443444
else:

octopus/predict/task_predictor_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def calculate_fi( # type: ignore[override]
272272
models=self._models,
273273
selected_features=self._selected_features,
274274
test_data=self._test_data,
275+
ml_type=self.ml_type,
275276
**kwargs,
276277
)
277278
else:

0 commit comments

Comments
 (0)