-
Notifications
You must be signed in to change notification settings - Fork 0
enable mypy in examples and tests dirs #361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import json | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import pandas as pd | ||
| import ray | ||
|
|
@@ -14,6 +15,9 @@ | |
| from octopus.modules import ModuleResult, ResultType, StudyContext, Task | ||
| from octopus.utils import calculate_feature_groups, parquet_save | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Sequence | ||
dasmy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| logger = get_logger() | ||
|
|
||
|
|
||
|
|
@@ -33,7 +37,7 @@ class WorkflowTaskRunner: | |
| """ | ||
|
|
||
| study_context: StudyContext = field(validator=[validators.instance_of(StudyContext)]) | ||
| workflow: list[Task] = field(validator=[validators.instance_of(list)]) | ||
| workflow: Sequence[Task] = field(validator=[validators.instance_of(list)]) | ||
|
Comment on lines
-36
to
+40
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why we have this mismatch: Sequence -- list |
||
| cpus_per_outersplit: int = field(validator=[validators.instance_of(int)]) | ||
dasmy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def run(self, outersplit_id: int, outersplit: OuterSplit) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,11 @@ def decorator(factory: Callable[[], ModelConfig]) -> Callable[[], ModelConfig]: | |
|
|
||
| return decorator | ||
|
|
||
| @classmethod | ||
| def get_registered_models(cls) -> list[ModelName]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want this function? In the code you normally only want models that fit to your ml_type. Can be a potential error if used somewhere.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we make it private by adding a leading underscore, @nihaase ? |
||
| """Get a list of all registered model names.""" | ||
| return [ModelName(name) for name in cls._config_factories] | ||
|
|
||
| @classmethod | ||
| def get_config(cls, name: ModelName) -> ModelConfig: | ||
| """Get model configuration by name. | ||
|
|
@@ -185,6 +190,5 @@ def validate_model_compatibility(cls, model_name: ModelName, ml_type: MLType) -> | |
| config = cls.get_config(model_name) | ||
| if not config.supports_ml_type(ml_type): | ||
| raise ValueError( | ||
| f"Model '{model_name}' does not support ml_type '{ml_type.value}'. " | ||
| f"Supported types: {', '.join(t.value for t in config.ml_types)}" | ||
| f"Model '{model_name}' does not support ml_type '{ml_type.value}'. Supported types: {', '.join(t.value for t in config.ml_types)}" | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.