Skip to content

Commit 90eeefc

Browse files
committed
copilot fixes
1 parent 526a140 commit 90eeefc

File tree

5 files changed

+11
-23
lines changed

5 files changed

+11
-23
lines changed

octopus/predict/feature_importance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,8 @@ def calculate_fi_permutation(
442442
test_data: Dict mapping outersplit_id to test DataFrame.
443443
train_data: Dict mapping outersplit_id to train DataFrame.
444444
target_assignments: Dict mapping semantic target roles to column
445-
names. For single-target tasks: ``{"default": "target_col"}``.
446-
For time-to-event: ``{"duration": "...", "event": "..."}``.
445+
names. For single-target tasks: ``{"default": "y"}``.
446+
For time-to-event: ``{"duration": "time_col", "event": "event_col"}``.
447447
target_metric: Metric name for scoring.
448448
positive_class: Positive class label for classification.
449449
n_repeats: Number of permutation repeats per feature.

tests/modules/octo/test_classes_attribute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_classification_model_has_classes_attribute(sample_data, model_name):
9696
training = Training(
9797
training_id=f"test_{model_name}",
9898
ml_type=MLType.BINARY,
99-
target_assignments={"target": "target"},
99+
target_assignments={"default": "target"},
100100
feature_cols=["x1", "x2"],
101101
row_id_col="row_id",
102102
data_train=train,

tests/modules/octo/test_column_ordering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,10 @@ def _create_training(
139139
) -> Training:
140140
"""Create a Training instance."""
141141
if ml_type == MLType.REGRESSION:
142-
target_assignments = {"target": "target_reg"}
142+
target_assignments = {"default": "target_reg"}
143143
target_metric = "R2"
144144
else:
145-
target_assignments = {"target": "target_class"}
145+
target_assignments = {"default": "target_class"}
146146
target_metric = "AUCROC"
147147

148148
ml_model_params = _get_default_model_params(model_name)

tests/modules/octo/test_model_fitted_validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ def get_model_configs():
8989
return {
9090
MLType.BINARY: {
9191
"models": available_models[MLType.BINARY],
92-
"target_assignments": {"target": "target_class"},
92+
"target_assignments": {"default": "target_class"},
9393
"target_metric": "accuracy",
9494
},
9595
MLType.REGRESSION: {
9696
"models": available_models[MLType.REGRESSION],
97-
"target_assignments": {"target": "target_reg"},
97+
"target_assignments": {"default": "target_reg"},
9898
"target_metric": "mse",
9999
},
100100
MLType.TIMETOEVENT: {

tests/modules/test_training_feature_importances.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,39 +72,27 @@
7272

7373
ML_TYPE_CONFIGS = {
7474
MLType.BINARY: {
75-
"target_assignments": {"target": "target_class"},
75+
"target_assignments": {"default": "target_class"},
7676
"target_metric": "AUCROC",
7777
},
7878
MLType.REGRESSION: {
79-
"target_assignments": {"target": "target_reg"},
79+
"target_assignments": {"default": "target_reg"},
8080
"target_metric": "R2",
8181
},
8282
MLType.TIMETOEVENT: {
8383
"target_assignments": {"duration": "duration", "event": "event"},
8484
"target_metric": "CI",
8585
},
8686
MLType.MULTICLASS: {
87-
"target_assignments": {"target": "target_multiclass"},
87+
"target_assignments": {"default": "target_multiclass"},
8888
"target_metric": "ACCBAL_MC",
8989
},
9090
}
9191

9292

9393
def _get_available_models_by_type():
9494
"""Get all available models dynamically from the registry, grouped by ML type."""
95-
all_models = Models._config_factories.keys()
96-
models_by_type = {ml_type: [] for ml_type in MLType}
97-
98-
for model_name in all_models:
99-
try:
100-
model_config = Models.get_config(model_name)
101-
for ml_type in MLType:
102-
if model_config.supports_ml_type(ml_type):
103-
models_by_type[ml_type].append(model_name)
104-
except Exception:
105-
continue
106-
107-
return models_by_type
95+
return {ml_type: list(Models.get_models_for_type(ml_type)) for ml_type in MLType}
10896

10997

11098
def _generate_model_fi_params():

0 commit comments

Comments
 (0)