Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ on:
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ repos:
rev: "v1.19.1"
hooks:
- id: mypy
exclude: ^datasets/|examples/|tests/|studies/
exclude: ^datasets/|studies/
additional_dependencies:
- types-networkx
- pandas-stubs
Expand Down
3 changes: 2 additions & 1 deletion examples/basic_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import os

from sklearn.datasets import load_breast_cancer
from sklearn.utils import Bunch

from octopus.modules import Octo
from octopus.study import OctoClassification

### Load and Preprocess Data
breast_cancer = load_breast_cancer(as_frame=True)
breast_cancer: Bunch = load_breast_cancer(as_frame=True) # type: ignore[assignment]

df = breast_cancer["frame"].reset_index()
df.columns = df.columns.str.replace(" ", "_")
Expand Down
3 changes: 2 additions & 1 deletion examples/basic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import os

from sklearn.datasets import load_diabetes
from sklearn.utils import Bunch

from octopus.study import OctoRegression

### Load the diabetes dataset
diabetes = load_diabetes(as_frame=True)
diabetes: Bunch = load_diabetes(as_frame=True) # type: ignore[assignment]

### Create and run OctoRegression
study = OctoRegression(
Expand Down
3 changes: 2 additions & 1 deletion examples/multi_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import os

from sklearn.datasets import load_diabetes
from sklearn.utils import Bunch

from octopus.modules import Mrmr, Octo
from octopus.study import OctoRegression

### Load the diabetes dataset
diabetes = load_diabetes(as_frame=True)
diabetes: Bunch = load_diabetes(as_frame=True) # type: ignore[assignment]

### Create and run OctoRegression with multi-step workflow
study = OctoRegression(
Expand Down
3 changes: 2 additions & 1 deletion examples/use_own_hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
import os

from sklearn.datasets import load_diabetes
from sklearn.utils import Bunch

from octopus.models.hyperparameter import IntHyperparameter
from octopus.modules import Octo
from octopus.study import OctoRegression

### Load the diabetes dataset
diabetes = load_diabetes(as_frame=True)
diabetes: Bunch = load_diabetes(as_frame=True) # type: ignore[assignment]

### Create and run OctoRegression with custom hyperparameters
study = OctoRegression(
Expand Down
3 changes: 2 additions & 1 deletion examples/wf_multiclass_wine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import os

from sklearn.datasets import load_wine
from sklearn.utils import Bunch

from octopus.modules import Octo
from octopus.study import OctoClassification

### Load and Preprocess Data
wine = load_wine(as_frame=True)
wine: Bunch = load_wine(as_frame=True) # type: ignore[assignment]

df = wine["frame"].reset_index()
df.columns = df.columns.str.replace(" ", "_")
Expand Down
3 changes: 2 additions & 1 deletion examples/wf_roc_octo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os

from sklearn.datasets import load_breast_cancer
from sklearn.utils import Bunch

from octopus.modules import Octo, Roc
from octopus.study import OctoClassification
Expand All @@ -19,7 +20,7 @@
# This is a binary classification dataset with 30 features
# Target: 0 = malignant, 1 = benign

breast_cancer = load_breast_cancer(as_frame=True)
breast_cancer: Bunch = load_breast_cancer(as_frame=True) # type: ignore[assignment]

df = breast_cancer["frame"].reset_index()
df.columns = df.columns.str.replace(" ", "_")
Expand Down
5 changes: 3 additions & 2 deletions octopus/manager/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
import os
from collections.abc import Sequence

from attrs import define, field, validators

Expand Down Expand Up @@ -133,8 +134,8 @@ class OctoManager:
study_context: StudyContext = field(validator=[validators.instance_of(StudyContext)])
"""Frozen runtime context containing study configuration."""

workflow: list[Task] = field(validator=[validators.instance_of(list)])
"""List of workflow tasks to execute."""
workflow: Sequence[Task] = field(validator=[validators.instance_of(list)])
"""Workflow tasks to execute."""

outer_parallelization: bool = field(validator=[validators.instance_of(bool)])
"""Whether to run outersplits in parallel."""
Expand Down
6 changes: 5 additions & 1 deletion octopus/manager/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING

import pandas as pd
import ray
Expand All @@ -14,6 +15,9 @@
from octopus.modules import ModuleResult, ResultType, StudyContext, Task
from octopus.utils import calculate_feature_groups, parquet_save

if TYPE_CHECKING:
from collections.abc import Sequence

logger = get_logger()


Expand All @@ -33,7 +37,7 @@ class WorkflowTaskRunner:
"""

study_context: StudyContext = field(validator=[validators.instance_of(StudyContext)])
workflow: list[Task] = field(validator=[validators.instance_of(list)])
workflow: Sequence[Task] = field(validator=[validators.instance_of(list)])
Comment on lines -36 to +40
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we have this mismatch: Sequence -- list

cpus_per_outersplit: int = field(validator=[validators.instance_of(int)])

def run(self, outersplit_id: int, outersplit: OuterSplit) -> None:
Expand Down
8 changes: 6 additions & 2 deletions octopus/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def decorator(factory: Callable[[], ModelConfig]) -> Callable[[], ModelConfig]:

return decorator

@classmethod
def get_registered_models(cls) -> list[ModelName]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this function? In the code you normally only want models that fit to your ml_type. Can be a potential error if used somewhere.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make it private by adding a leading underscore, @nihaase ?

"""Get a list of all registered model names."""
return [ModelName(name) for name in cls._config_factories]

@classmethod
def get_config(cls, name: ModelName) -> ModelConfig:
"""Get model configuration by name.
Expand Down Expand Up @@ -185,6 +190,5 @@ def validate_model_compatibility(cls, model_name: ModelName, ml_type: MLType) ->
config = cls.get_config(model_name)
if not config.supports_ml_type(ml_type):
raise ValueError(
f"Model '{model_name}' does not support ml_type '{ml_type.value}'. "
f"Supported types: {', '.join(t.value for t in config.ml_types)}"
f"Model '{model_name}' does not support ml_type '{ml_type.value}'. Supported types: {', '.join(t.value for t in config.ml_types)}"
)
2 changes: 1 addition & 1 deletion octopus/modules/octo/enssel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Ensemble selection."""

# TOBEDONE
# TODO
# - issue: ACC and BALACC need integer pooling values!
# - potential issue: check start_n, +1 or not
# - get FI and counts
Expand Down
4 changes: 2 additions & 2 deletions octopus/modules/octo/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
from octopus.models import ModelName, Models
from octopus.types import MLType

# # TOBEDONE pipeline
# # TODO pipeline
# - implement cat encoding on module level
# - how to provide categorical info to catboost and other models?


logger = get_logger()


class TrainingConfig(TypedDict):
class TrainingConfig(TypedDict, total=False):
"""Training configuration type."""

outl_reduction: int
Expand Down
3 changes: 2 additions & 1 deletion octopus/study/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import platform
from abc import ABC, abstractmethod
from collections.abc import Sequence
from datetime import UTC

import pandas as pd
Expand Down Expand Up @@ -77,7 +78,7 @@ class OctoStudy(ABC):
run_single_outersplit_num: int = field(default=Factory(lambda: -1), validator=[validators.instance_of(int)])
"""Select a single outersplit to execute. Defaults to -1 to run all outersplits"""

workflow: list[Task] = field(
workflow: Sequence[Task] = field(
default=Factory(lambda: [Octo(task_id=0)]),
validator=[validators.instance_of(list), validate_workflow],
)
Expand Down
7 changes: 3 additions & 4 deletions octopus/study/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Validation functions for OctoStudy attributes."""

from collections.abc import Sequence
from typing import TYPE_CHECKING

from attrs import Attribute
Expand All @@ -10,7 +11,7 @@
from octopus.study.core import OctoStudy


def validate_workflow(_instance: "OctoStudy", attribute: Attribute, value: list[Task]) -> None:
def validate_workflow(_instance: "OctoStudy", attribute: Attribute, value: Sequence[Task]) -> None:
"""Validate the `workflow` attribute.

Ensures that the value is a non-empty list where each item is an
Expand Down Expand Up @@ -109,9 +110,7 @@ def validate_workflow(_instance: "OctoStudy", attribute: Attribute, value: list[
if depends_on is not None:
if depends_on not in task_id_to_index:
raise ValueError(
f"Item '{item.description}' (position {idx + 1}) has "
f"'depends_on={depends_on}', which does not"
" correspond to any 'task_id' in the workflow."
f"Item '{item.description}' (position {idx + 1}) has 'depends_on={depends_on}', which does not correspond to any 'task_id' in the workflow."
)
depends_on_idx = task_id_to_index[depends_on]
if depends_on_idx >= idx:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ show_error_codes = true
no_implicit_optional = true
warn_return_any = true
warn_unused_ignores = true
exclude = ["examples", "docs", "tests", "studies"]
exclude = ["docs", "studies"]


[[tool.mypy.overrides]]
Expand Down
2 changes: 1 addition & 1 deletion tests/infrastructure/test_file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
def test_parquet_dtype_roundtrip(tmp_path, data, dtype):
"""Test that saving and loading a DataFrame with parquet_save and parquet_load works correctly."""
if dtype == "CategoricalDtype":
dtype = pd.CategoricalDtype(set(data))
dtype = pd.CategoricalDtype(sorted(set(data)))
elif dtype == "StringDtype":
dtype = pd.StringDtype()

Expand Down
2 changes: 1 addition & 1 deletion tests/manager/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_frozen(self):
num_outersplits=4,
)
with pytest.raises(attrs.exceptions.FrozenInstanceError):
config.num_cpus = 8
config.num_cpus = 8 # type: ignore[misc]

def test_create_single_outersplit_gets_all_cpus(self):
"""Test that when running a single outersplit, it gets all CPUs.
Expand Down
31 changes: 11 additions & 20 deletions tests/metrics/test_metrics_uniqueness.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from octopus.metrics import Metrics
from octopus.metrics.config import Metric
from octopus.types import ML_TYPES, MLType


Expand All @@ -28,8 +29,7 @@ def test_registry_keys_are_unique(self):
unique_keys = set(registry_keys)

assert len(registry_keys) == len(unique_keys), (
f"Registry keys are not unique. Found {len(registry_keys)} keys but only "
f"{len(unique_keys)} unique keys. Keys: {sorted(registry_keys)}"
f"Registry keys are not unique. Found {len(registry_keys)} keys but only {len(unique_keys)} unique keys. Keys: {sorted(registry_keys)}"
)

def test_metric_config_names_are_unique(self):
Expand All @@ -38,19 +38,14 @@ def test_metric_config_names_are_unique(self):
This is critical for the utils functions that deduce ml_type from metrics.
"""
config_names = []
config_name_to_registry_key = {}
config_name_to_registry_key = defaultdict(list)

for registry_key in self.all_metrics:
try:
config = Metrics.get_instance(registry_key)
config_name = config.name
config_names.append(config_name)

if config_name in config_name_to_registry_key:
config_name_to_registry_key[config_name].append(registry_key)
else:
config_name_to_registry_key[config_name] = [registry_key]

config_name_to_registry_key[config_name].append(registry_key)
except Exception as e:
pytest.fail(f"Failed to get config for metric '{registry_key}': {e}")

Expand Down Expand Up @@ -120,9 +115,9 @@ def test_all_metrics_have_valid_ml_types(self):

# Print distribution for documentation
print("\n=== ML Type Distribution ===")
for ml_type in sorted(ml_type_distribution):
metrics = sorted(ml_type_distribution[ml_type])
print(f"{ml_type} ({len(metrics)}): {metrics}")
for ml_type_s in sorted(ml_type_distribution):
metrics = sorted(ml_type_distribution[ml_type_s])
print(f"{ml_type_s} ({len(metrics)}): {metrics}")

def test_all_metrics_have_valid_prediction_types(self):
"""Test that all metrics have valid prediction_type values."""
Expand Down Expand Up @@ -164,14 +159,12 @@ def test_metrics_loaded_dynamically(self):
missing_types = expected_ml_types - ml_types

assert not missing_types, (
f"Missing expected ML types: {missing_types}. "
f"Found ML types: {sorted(ml_types)}. "
f"This suggests some metric modules may not be imported properly."
f"Missing expected ML types: {missing_types}. Found ML types: {sorted(ml_types)}. This suggests some metric modules may not be imported properly."
)

def test_no_metric_config_attribute_conflicts(self):
"""Test that metric configs don't have conflicting attributes for same names."""
configs_by_name = {}
configs_by_name: dict[str, Metric] = {}
conflicts = []

for registry_key in self.all_metrics:
Expand All @@ -190,14 +183,12 @@ def test_no_metric_config_attribute_conflicts(self):

if existing_config.prediction_type != config.prediction_type:
conflicts.append(
f"'{config_name}': prediction_type conflict - "
f"'{existing_config.prediction_type}' vs '{config.prediction_type}'"
f"'{config_name}': prediction_type conflict - '{existing_config.prediction_type}' vs '{config.prediction_type}'"
)

if existing_config.higher_is_better != config.higher_is_better:
conflicts.append(
f"'{config_name}': higher_is_better conflict - "
f"'{existing_config.higher_is_better}' vs '{config.higher_is_better}'"
f"'{config_name}': higher_is_better conflict - '{existing_config.higher_is_better}' vs '{config.higher_is_better}'"
)
else:
configs_by_name[config_name] = config
Expand Down
Loading
Loading