Skip to content

Commit eb90dc0

Browse files
committed
copilot fixes
1 parent: 986efb5 · commit: eb90dc0

File tree

3 files changed

+28
-25
lines changed

3 files changed

+28
-25
lines changed

octopus/modules/octo/training.py

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -217,7 +217,15 @@ def _relabel_processed_output(
217217
# In this case, column order is preserved (no ColumnTransformer reordering)
218218
output_cols = list(self.feature_cols)
219219

220+
n_cols = processed_data.shape[1]
220221
if set(output_cols) != set(self.feature_cols):
222+
# If column count also mismatches, raise a clear error
223+
if n_cols != len(self.feature_cols):
224+
raise ValueError(
225+
f"Pipeline output has {n_cols} columns but expected {len(self.feature_cols)}. "
226+
f"Pipeline columns: {output_cols}, expected: {list(self.feature_cols)}. "
227+
f"This may indicate extra/unexpected columns were passed to the transformer."
228+
)
221229
logger.warning(
222230
"Pipeline output columns %s do not match feature_cols %s. Falling back to positional labeling.",
223231
output_cols,
@@ -773,7 +781,7 @@ def calculate_fi_featuresused_shap(self, partition="dev", bg_max=200):
773781
else:
774782
feature_names = [f"f{i}" for i in range(n_features)]
775783

776-
# Build predict function that converts numpy to DataFrame for sklearn compatibility
784+
# Build predict function as fallback for sklearn compatibility
777785
_feature_cols = self.feature_cols
778786

779787
if getattr(self, "ml_type", None) in (MLType.BINARY, MLType.MULTICLASS) and hasattr(
@@ -787,14 +795,18 @@ def predict_fn(X):
787795
def predict_fn(X):
788796
return np.asarray(self.model.predict(pd.DataFrame(np.asarray(X), columns=_feature_cols)))
789797

790-
# Build explainer
798+
# Build explainer: try model directly first for fast Tree/Linear explainers,
799+
# fall back to callable wrapper if that fails
800+
X_bg_df = pd.DataFrame(X_bg, columns=_feature_cols)
801+
X_eval_df = pd.DataFrame(X_eval, columns=_feature_cols)
802+
791803
try:
792-
# Let SHAP auto-select the best explainer (Tree for tree models, Kernel otherwise)
793-
explainer = shap.Explainer(predict_fn, X_bg)
794-
sv = explainer(X_eval)
804+
# Try model directly — SHAP can auto-detect Tree/Linear explainers for speed
805+
explainer = shap.Explainer(self.model, X_bg_df)
806+
sv = explainer(X_eval_df)
795807
except Exception as e1:
796-
logger.debug(f"SHAP auto explainer failed: {e1}. Falling back to callable + Kernel.")
797-
# Use the generic constructor so SHAP picks Kernel with the given background
808+
logger.debug(f"SHAP auto explainer with model failed: {e1}. Falling back to callable wrapper.")
809+
# Fall back to callable approach (always works, but uses slower KernelExplainer)
798810
explainer = shap.Explainer(predict_fn, X_bg)
799811
sv = explainer(X_eval)
800812

@@ -911,8 +923,8 @@ def predict(self, x: pd.DataFrame) -> np.ndarray:
911923
if isinstance(x, np.ndarray):
912924
x = pd.DataFrame(x, columns=self.feature_cols)
913925
elif isinstance(x, pd.DataFrame):
914-
# Reset index to avoid sklearn ColumnTransformer issues
915-
x = x.reset_index(drop=True)
926+
# Subset to feature_cols to prevent extra columns flowing through ColumnTransformer
927+
x = x[self.feature_cols].reset_index(drop=True)
916928

917929
# Apply the same preprocessing pipeline used during training
918930
x_processed = self._transform_to_dataframe(x)
@@ -931,8 +943,8 @@ def predict_proba(self, x: pd.DataFrame) -> np.ndarray:
931943
if isinstance(x, np.ndarray):
932944
x = pd.DataFrame(x, columns=self.feature_cols)
933945
elif isinstance(x, pd.DataFrame):
934-
# Reset index to avoid sklearn ColumnTransformer issues
935-
x = x.reset_index(drop=True)
946+
# Subset to feature_cols to prevent extra columns flowing through ColumnTransformer
947+
x = x[self.feature_cols].reset_index(drop=True)
936948

937949
# Apply the same preprocessing pipeline used during training
938950
x_processed = self._transform_to_dataframe(x)

octopus/predict/notebook_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
import re
1011
from typing import TYPE_CHECKING, Any
1112

1213
import numpy as np
@@ -57,8 +58,9 @@ def find_latest_study(studies_root: str | UPath, prefix: str) -> str:
5758
"""
5859
root = UPath(studies_root)
5960
# Match timestamped directories: prefix-YYYYMMDD_HHMMSS
61+
_timestamp_pattern = re.compile(re.escape(prefix) + r"-\d{8}_\d{6}$")
6062
candidates = sorted(
61-
[d for d in root.glob(f"{prefix}-*") if d.is_dir()],
63+
[d for d in root.glob(f"{prefix}-*") if d.is_dir() and _timestamp_pattern.search(d.name)],
6264
key=lambda p: p.name,
6365
reverse=True,
6466
)

tests/modules/octo/test_column_ordering.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
pytest tests/modules/octo/test_column_ordering.py -v
99
"""
1010

11-
import warnings
12-
1311
import numpy as np
1412
import pandas as pd
13+
import pytest
1514

1615
from octopus.models import Models
1716
from octopus.models.hyperparameter import (
@@ -171,12 +170,12 @@ def _create_training(
171170
)
172171

173172

173+
@pytest.mark.filterwarnings("ignore")
174174
class TestColumnOrdering:
175175
"""Tests for ColumnTransformer column ordering with mixed types."""
176176

177177
def test_x_train_processed_columns_match_feature_cols(self):
178178
"""Verify x_train_processed has columns in feature_cols order after fit."""
179-
warnings.filterwarnings("ignore")
180179
data, feature_cols, feature_groups = _create_mixed_type_data()
181180
data_train, data_dev, data_test = _split_data(data)
182181

@@ -190,7 +189,6 @@ def test_x_train_processed_columns_match_feature_cols(self):
190189

191190
def test_x_dev_processed_columns_match_feature_cols(self):
192191
"""Verify x_dev_processed has columns in feature_cols order after fit."""
193-
warnings.filterwarnings("ignore")
194192
data, feature_cols, feature_groups = _create_mixed_type_data()
195193
data_train, data_dev, data_test = _split_data(data)
196194

@@ -201,7 +199,6 @@ def test_x_dev_processed_columns_match_feature_cols(self):
201199

202200
def test_x_test_processed_columns_match_feature_cols(self):
203201
"""Verify x_test_processed has columns in feature_cols order after fit."""
204-
warnings.filterwarnings("ignore")
205202
data, feature_cols, feature_groups = _create_mixed_type_data()
206203
data_train, data_dev, data_test = _split_data(data)
207204

@@ -212,7 +209,6 @@ def test_x_test_processed_columns_match_feature_cols(self):
212209

213210
def test_numerical_data_in_numerical_column(self):
214211
"""Verify that numerical columns in x_train_processed contain actual numerical data."""
215-
warnings.filterwarnings("ignore")
216212
data, feature_cols, feature_groups = _create_mixed_type_data()
217213
data_train, data_dev, data_test = _split_data(data)
218214

@@ -228,7 +224,6 @@ def test_numerical_data_in_numerical_column(self):
228224

229225
def test_categorical_data_in_categorical_column(self):
230226
"""Verify that categorical columns in x_train_processed contain actual categorical data."""
231-
warnings.filterwarnings("ignore")
232227
data, feature_cols, feature_groups = _create_mixed_type_data()
233228
data_train, data_dev, data_test = _split_data(data)
234229

@@ -247,7 +242,6 @@ def test_internal_fi_labels_correct_with_mixed_types(self):
247242
248243
Target is strongly correlated with num1, so num1 should have highest importance.
249244
"""
250-
warnings.filterwarnings("ignore")
251245
data, feature_cols, feature_groups = _create_mixed_type_data(n_samples=500)
252246
data_train, data_dev, data_test = _split_data(data)
253247

@@ -267,7 +261,6 @@ def test_internal_fi_labels_correct_with_mixed_types(self):
267261

268262
def test_permutation_fi_labels_correct_with_mixed_types(self):
269263
"""Verify permutation FI labels are correct when mixed column types exist."""
270-
warnings.filterwarnings("ignore")
271264
data, feature_cols, feature_groups = _create_mixed_type_data(n_samples=500)
272265
data_train, data_dev, data_test = _split_data(data)
273266

@@ -287,7 +280,6 @@ def test_permutation_fi_labels_correct_with_mixed_types(self):
287280

288281
def test_all_numerical_columns_no_regression(self):
289282
"""Verify all-numerical columns still work correctly (regression test)."""
290-
warnings.filterwarnings("ignore")
291283
data, feature_cols, feature_groups = _create_numerical_only_data()
292284
data_train, data_dev, data_test = _split_data(data)
293285

@@ -304,7 +296,6 @@ def test_all_numerical_columns_no_regression(self):
304296

305297
def test_predict_works_with_mixed_types(self):
306298
"""Verify predict() works correctly with mixed column types."""
307-
warnings.filterwarnings("ignore")
308299
data, feature_cols, feature_groups = _create_mixed_type_data()
309300
data_train, data_dev, data_test = _split_data(data)
310301

@@ -317,7 +308,6 @@ def test_predict_works_with_mixed_types(self):
317308

318309
def test_predict_classification_with_mixed_types(self):
319310
"""Verify predict_proba() works correctly with mixed column types for classification."""
320-
warnings.filterwarnings("ignore")
321311
data, feature_cols, feature_groups = _create_mixed_type_data()
322312
data_train, data_dev, data_test = _split_data(data)
323313

@@ -341,7 +331,6 @@ def test_predict_classification_with_mixed_types(self):
341331

342332
def test_relabel_fallback_when_get_feature_names_out_fails(self):
343333
"""Verify fallback when get_feature_names_out() is not available."""
344-
warnings.filterwarnings("ignore")
345334
data, feature_cols, feature_groups = _create_numerical_only_data()
346335
data_train, data_dev, data_test = _split_data(data)
347336

0 commit comments

Comments
 (0)