Skip to content

Commit ea4b8ba

Browse files
committed
copilot fixes
1 parent fb3d4f7 commit ea4b8ba

File tree

3 files changed

+89
-29
lines changed

3 files changed

+89
-29
lines changed

octopus/modules/octo/training.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,10 @@ def calculate_fi_permutation(
476476

477477
# Build training pool for draw-from-pool permutation.
478478
# Use x_train_processed (larger, more representative) as the sampling pool.
479-
train_pool = pd.concat([self.x_train_processed, self.data_train[target_cols]], axis=1)
479+
# Align targets to x_train_processed index to handle outl_reduction > 0,
480+
# where self.data_train retains all rows but x_train_processed has outliers removed.
481+
train_targets = self.data_train.loc[self.x_train_processed.index, target_cols]
482+
train_pool = pd.concat([self.x_train_processed, train_targets], axis=1)
480483

481484
feature_groups = self.feature_groups if use_groups else None
482485

@@ -635,7 +638,8 @@ def calculate_fi_shap(self, partition: str = "dev", shap_type: str = "kernel", b
635638
# Construct background set based on `background_size` for kernel SHAP
636639
if shap_type == "kernel" and background_size is not None:
637640
if len(data) > background_size:
638-
indices = np.random.choice(len(data), size=background_size, replace=False)
641+
rng = np.random.default_rng(42)
642+
indices = rng.choice(len(data), size=background_size, replace=False)
639643
X_background = data.iloc[indices]
640644
else:
641645
X_background = data

octopus/predict/feature_importance.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -318,16 +318,19 @@ def compute_permutation_single(
318318
"""Compute permutation feature importance for a single model.
319319
320320
Uses the custom draw-from-pool algorithm: replacement values are drawn
321-
from *X_train* (larger, more representative pool) rather than shuffling
322-
within the test set.
321+
from the combined pool of *X_train* and *X_test* feature values,
322+
providing a more representative approximation of the marginal
323+
distribution than either partition alone.
323324
324325
When ``feature_groups`` is provided, computes importance for **both**
325326
individual features and feature groups.
326327
327328
Args:
328329
model: A fitted model.
329330
X_test: Test data (must contain both the feature columns and the target columns).
330-
X_train: Training data used as the sampling pool.
331+
X_train: Training data. Feature values from both *X_train* and
332+
*X_test* are combined to form the sampling pool for replacement
333+
values.
331334
feature_cols: Feature column names used by the model.
332335
target_metric: Metric name for scoring.
333336
target_assignments: Dict mapping target roles to column names.
@@ -342,39 +345,56 @@ def compute_permutation_single(
342345
Column ``"importance"`` holds the mean across repeats (compatible
343346
with Bag aggregation).
344347
"""
345-
rng = np.random.RandomState(random_state)
348+
rng = np.random.default_rng(random_state)
346349

347350
# Baseline score
348351
baseline = get_score_from_model(
349352
model, X_test, feature_cols, target_metric, target_assignments, positive_class=positive_class
350353
)
351354

355+
# Build a set once so that membership checks against feature_cols are O(1).
356+
feature_cols_set = set(feature_cols)
357+
352358
# Build items to permute: individual features + groups
353359
items_to_permute: list[tuple[str, list[str]]] = [(f, [f]) for f in feature_cols]
354360
if feature_groups is not None:
355361
for group_name, group_features in feature_groups.items():
356-
if any(f in feature_cols for f in group_features):
362+
if any(f in feature_cols_set for f in group_features):
357363
items_to_permute.append((group_name, group_features))
358364

359365
results: list[dict[str, Any]] = []
366+
n_test = len(X_test)
360367

361368
for item_name, cols_to_permute in items_to_permute:
362-
active_cols = [c for c in cols_to_permute if c in feature_cols and c in X_test.columns]
369+
active_cols = [c for c in cols_to_permute if c in feature_cols_set and c in X_test.columns]
363370
if not active_cols:
364371
continue
365372

373+
# Precompute combined pool values (X_train + X_test) per column once
374+
# to avoid repeated array conversions inside the permutation repeat loop.
375+
pool_values_per_col = {
376+
col: np.concatenate([np.asarray(X_train[col].values), np.asarray(X_test[col].values)])
377+
for col in active_cols
378+
}
379+
380+
# Copy X_test once per permuted item (not once per repeat); the permuted columns are restored after each repeat.
381+
test_shuffled = X_test.copy()
382+
originals = {col: test_shuffled[col].values.copy() for col in active_cols}
383+
366384
repeat_scores: list[float] = []
367385
for _ in range(n_repeats):
368-
test_shuffled = X_test.copy()
369386
for col in active_cols:
370-
train_values = np.asarray(X_train[col].values)
371-
test_shuffled[col] = rng.choice(train_values, size=len(test_shuffled), replace=True)
387+
test_shuffled[col] = rng.choice(pool_values_per_col[col], size=n_test, replace=True)
372388

373389
perm_score = get_score_from_model(
374390
model, test_shuffled, feature_cols, target_metric, target_assignments, positive_class=positive_class
375391
)
376392
repeat_scores.append(baseline - perm_score)
377393

394+
# Restore original column values for the next repeat
395+
for col in active_cols:
396+
test_shuffled[col] = originals[col]
397+
378398
stats = compute_per_repeat_stats(repeat_scores)
379399
results.append({"feature": item_name, **stats})
380400

@@ -411,12 +431,21 @@ def calculate_fi_permutation(
411431
receive zero importance for that split, ensuring the result covers the
412432
union of all input features.
413433
434+
.. note::
435+
436+
This Layer 2 orchestrator only supports **single-target** tasks
437+
(binary classification, multiclass, regression). For time-to-event
438+
tasks (which require multi-key ``target_assignments`` like
439+
``{"duration": ..., "event": ...}``), use the Layer 1 primitive
440+
``compute_permutation_single`` directly with full
441+
``target_assignments``.
442+
414443
Args:
415444
models: Dict mapping outersplit_id to fitted model.
416445
selected_features: Dict mapping outersplit_id to feature list.
417446
test_data: Dict mapping outersplit_id to test DataFrame.
418447
train_data: Dict mapping outersplit_id to train DataFrame.
419-
target_col: Target column name.
448+
target_col: Target column name (single-target tasks only).
420449
target_metric: Metric name for scoring.
421450
positive_class: Positive class label for classification.
422451
n_repeats: Number of permutation repeats per feature.

uv.lock

Lines changed: 44 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)