Skip to content

Commit 70ed53a

Browse files
authored
update mypy config (#145)
* update mypy config * fix lint * remove any * lint * add to node optimizer * try unreachable * fix mypy
1 parent 865f0ed commit 70ed53a

File tree

9 files changed

+23
-31
lines changed

9 files changed

+23
-31
lines changed

autointent/_callbacks/tensorboard.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,12 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
5757
for key, value in module_kwargs.items():
5858
self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call]
5959

60-
def log_value(self, **kwargs: dict[str, Any]) -> None:
60+
def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
6161
"""
6262
Log data.
6363
6464
:param kwargs: Data to log.
6565
"""
66-
if self.module_writer is None:
67-
msg = "start_run must be called before log_value."
68-
raise RuntimeError(msg)
69-
7066
for key, value in kwargs.items():
7167
if isinstance(value, int | float):
7268
self.module_writer.add_scalar(key, value)
@@ -79,10 +75,6 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:
7975
8076
:param metrics: Metrics to log.
8177
"""
82-
if self.module_writer is None:
83-
msg = "start_run must be called before log_value."
84-
raise RuntimeError(msg)
85-
8678
for key, value in metrics.items():
8779
if isinstance(value, int | float):
8880
self.module_writer.add_scalar(key, value) # type: ignore[no-untyped-call]

autointent/_pipeline/_pipeline.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import yaml
10+
from typing_extensions import assert_never
1011

1112
from autointent import Context, Dataset
1213
from autointent.configs import DataConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
@@ -52,8 +53,7 @@ def __init__(
5253
self.vector_index_config = VectorIndexConfig()
5354
self.data_config = DataConfig()
5455
elif not isinstance(nodes[0], InferenceNode):
55-
msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
56-
raise TypeError(msg)
56+
assert_never(nodes)
5757

5858
def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
5959
"""
@@ -68,8 +68,7 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) ->
6868
elif isinstance(config, DataConfig):
6969
self.data_config = config
7070
else:
71-
msg = "unknown config type"
72-
raise TypeError(msg)
71+
assert_never(config)
7372

7473
@classmethod
7574
def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed: int = 42) -> "Pipeline":
@@ -180,7 +179,7 @@ def fit(
180179
)
181180

182181
if sampler is None:
183-
sampler = self.sampler or "brute"
182+
sampler = self.sampler
184183

185184
self._fit(context, sampler)
186185

autointent/configs/_inference_node.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,3 @@ class InferenceNodeConfig:
1818
"""Configuration of the module"""
1919
load_path: str | None = None
2020
"""Path to the module dump. If None, the module will be trained from scratch"""
21-
22-
def __post_init__(self) -> None:
23-
if not isinstance(self.node_type, NodeType):
24-
self.node_type = NodeType(self.node_type)

autointent/context/_context.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ def dump(self) -> None:
9494
optimization_results = self.optimization_info.dump_evaluation_results()
9595

9696
logs_dir = self.logging_config.dirpath
97-
if logs_dir is None:
98-
msg = "something's wrong with LoggingConfig"
99-
raise ValueError(msg)
100-
10197
logs_dir.mkdir(parents=True, exist_ok=True)
10298

10399
logs_path = logs_dir / "logs.json"

autointent/metrics/_converter.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ def transform(
2222
:param y_pred: Y_pred values
2323
:return:
2424
"""
25-
if isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray):
26-
return y_true, y_pred
2725
y_pred_ = np.array(y_pred)
2826
y_true_ = np.array(y_true)
2927
return y_true_, y_pred_

autointent/modules/abc/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ def score(self, context: Context, metrics: list[str]) -> dict[str, float]:
4040
Calculate metric on test set and return metric value.
4141
4242
:param context: Context to score
43-
:param split: Split to score on
43+
:param metrics: Metrics to score
4444
:return: Computed metrics value for the test set or error code of metrics
4545
"""
4646
if context.data_handler.config.scheme == "ho":
4747
return self.score_ho(context, metrics)
4848
if context.data_handler.config.scheme == "cv":
4949
return self.score_cv(context, metrics)
50-
msg = "Something's wrong with validation schemas"
51-
raise RuntimeError(msg)
50+
msg = f"Unknown scheme: {context.data_handler.config.scheme}"
51+
raise ValueError(msg)
5252

5353
@abstractmethod
5454
def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: ...

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
from optuna.trial import Trial
1313
from pydantic import BaseModel, Field
14+
from typing_extensions import assert_never
1415

1516
from autointent import Dataset
1617
from autointent.context import Context
@@ -65,11 +66,12 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
6566
Fit the node optimizer.
6667
6768
:param context: Context
69+
:param sampler: Sampler to use for optimization
6870
"""
6971
self._logger.info("starting %s node optimization...", self.node_info.node_type)
7072

7173
for search_space in deepcopy(self.modules_search_spaces):
72-
self._counter = 0
74+
self._counter: int = 0
7375
module_name = search_space.pop("module_name")
7476
n_trials = None
7577
if "n_trials" in search_space:
@@ -84,8 +86,7 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
8486
sampler_instance = optuna.samplers.RandomSampler(seed=context.seed) # type: ignore[assignment]
8587
n_trials = n_trials or 10
8688
else:
87-
msg = f"Unexpected sampler: {sampler}"
88-
raise ValueError(msg)
89+
assert_never(sampler)
8990
study = optuna.create_study(direction="maximize", sampler=sampler_instance)
9091
optuna.logging.set_verbosity(optuna.logging.WARNING)
9192
obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)

autointent/schemas/_schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def get_prompt_config(self) -> dict[str, str] | None:
197197
prompts[TaskTypeEnum.sts.value] = self.sts_prompt
198198
return prompts if len(prompts) > 0 else None
199199

200-
def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # noqa: PLR0911
200+
def get_prompt_type(self, prompt_type: TaskTypeEnum | str | None) -> str | None: # noqa: PLR0911
201201
"""Get the prompt type for the given task type.
202202
203203
:param prompt_type: Task type for which to get the prompt.

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ skip_empty = true
166166
python_version = "3.10"
167167
strict = true
168168
warn_redundant_casts = true
169+
# align with mypy 2.0 release
170+
warn_unreachable = true
171+
local_partial_types = true
169172
plugins = [
170173
"pydantic.mypy",
171174
"numpy.typing.mypy_plugin",
@@ -193,3 +196,10 @@ module = [
193196
"wandb",
194197
]
195198
ignore_missing_imports = true
199+
200+
[[tool.mypy.overrides]]
201+
module = [
202+
"autointent._callbacks.*",
203+
"autointent.modules.abc.*",
204+
]
205+
warn_unreachable = false

0 commit comments

Comments
 (0)