Skip to content
10 changes: 8 additions & 2 deletions examples/analyse_study_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"from octopus.predict import TaskPredictorTest\n",
"from octopus.predict.notebook_utils import (\n",
" display_table,\n",
" find_latest_study,\n",
" show_aucroc_plots,\n",
" show_confusionmatrix,\n",
" show_overall_fi_plot,\n",
Expand Down Expand Up @@ -92,8 +93,10 @@
"metadata": {},
"outputs": [],
"source": [
"# INPUT: Select study\n",
"study_directory = \"../studies/wf_octo_mrmr_octo/\""
"# INPUT: Select study by name prefix (automatically picks the latest timestamped run)\n",
"study_name_prefix = \"wf_octo_mrmr_octo\"\n",
"study_directory = find_latest_study(\"../studies\", study_name_prefix)\n",
"print(f\"Using study: {study_directory}\")"
]
},
{
Expand Down Expand Up @@ -374,6 +377,9 @@
"cell_metadata_filter": "vscode,-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
Expand Down
7 changes: 6 additions & 1 deletion octopus/models/classification_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ def catboost_classifier() -> ModelConfig:
"""CatBoost classification model config."""
return ModelConfig(
model_class=CatBoostClassifier,
ml_types=[MLType.BINARY, MLType.MULTICLASS],
# KNOWN ISSUE: CatBoost multiclass is disabled because shap.Explainer(model, bg)
# segfaults in SHAP <=0.51 when using TreeExplainer's interventional mode with
# CatBoost multiclass models. Re-enable once SHAP fixes this upstream.
# See datasets_local/specifications_refactorfi/03_shap_catboost_segfault_proposal.md
# Original: ml_types=[MLType.BINARY, MLType.MULTICLASS],
ml_types=[MLType.BINARY],
Comment on lines +155 to +160
Comment on lines +155 to +160
feature_method="internal",
chpo_compatible=True,
scaler=None,
Expand Down
4 changes: 2 additions & 2 deletions octopus/modules/octo/bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def fit(self):
elif self._fi_type == "shap":
self._training.calculate_fi_shap(partition=self._partition)
elif self._fi_type == "permutation":
self._training.calculate_fi_group_permutation(partition=self._partition)
self._training.calculate_fi_permutation(partition=self._partition)
elif self._fi_type == "lofo":
self._training.calculate_fi_lofo()
elif self._fi_type == "constant":
Expand Down Expand Up @@ -576,7 +576,7 @@ def _calculate_fi_sequential(self, fi_type="internal", partition="dev"):
elif fi_type == "shap":
training.calculate_fi_shap(partition=partition)
elif fi_type == "permutation":
training.calculate_fi_group_permutation(partition=partition)
training.calculate_fi_permutation(partition=partition)
elif fi_type == "lofo":
training.calculate_fi_lofo()
elif fi_type == "constant":
Expand Down
18 changes: 11 additions & 7 deletions octopus/modules/octo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import copy
import warnings
from typing import TYPE_CHECKING

import optuna
Expand Down Expand Up @@ -353,13 +354,16 @@ def _run_globalhp_optimization(
)

# multivariate sampler with group option
sampler = optuna.samplers.TPESampler(
multivariate=True,
group=True,
constant_liar=True,
seed=self.config.optuna_seed,
n_startup_trials=self.config.n_optuna_startup_trials,
)
# Suppress ExperimentalWarning — these parameters have been stable for 3+ years
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
sampler = optuna.samplers.TPESampler(
multivariate=True,
group=True,
constant_liar=True,
seed=self.config.optuna_seed,
n_startup_trials=self.config.n_optuna_startup_trials,
)

# create study with in-memory storage
study = optuna.create_study(
Expand Down
Loading
Loading