
Commit c9f1e23

update mypy config

1 parent 0694ebd

9 files changed (+16, -29 lines)

autointent/_callbacks/tensorboard.py

Lines changed: 1 addition & 9 deletions
@@ -57,16 +57,12 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
         for key, value in module_kwargs.items():
             self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call]

-    def log_value(self, **kwargs: dict[str, Any]) -> None:
+    def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
         """
         Log data.

         :param kwargs: Data to log.
         """
-        if self.module_writer is None:
-            msg = "start_run must be called before log_value."
-            raise RuntimeError(msg)
-
         for key, value in kwargs.items():
             if isinstance(value, int | float):
                 self.module_writer.add_scalar(key, value)
@@ -79,10 +75,6 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:

         :param metrics: Metrics to log.
         """
-        if self.module_writer is None:
-            msg = "start_run must be called before log_value."
-            raise RuntimeError(msg)
-
         for key, value in metrics.items():
             if isinstance(value, int | float):
                 self.module_writer.add_scalar(key, value) # type: ignore[no-untyped-call]
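Note: the dropped `module_writer is None` guards are the kind of branch that mypy's `warn_unreachable` option (added in pyproject.toml below) reports when the attribute is annotated as non-Optional. A minimal illustrative sketch, not repository code, assuming such an annotation:

class Writer:
    def add_scalar(self, key: str, value: float) -> None:
        print(key, value)


class Callback:
    def __init__(self) -> None:
        # Annotated as plain `Writer`, not `Writer | None`.
        self.writer: Writer = Writer()

    def log_value(self, value: float) -> None:
        if self.writer is None:  # with warn_unreachable, mypy flags this body as unreachable
            raise RuntimeError("start_run must be called before log_value.")
        self.writer.add_scalar("value", value)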

autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ class Pipeline:

     def __init__(
         self,
-        nodes: list[NodeOptimizer] | list[InferenceNode],
+        nodes: list[NodeOptimizer] | list[InferenceNode] | list[Any],
         seed: int = 42,
     ) -> None:
         """
@@ -48,7 +48,7 @@ def __init__(
             msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
             raise TypeError(msg)

-    def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig) -> None:
+    def set_config(self, config: LoggingConfig | VectorIndexConfig | DataConfig | Any) -> None:
         """
         Set configuration for the optimizer.

autointent/configs/_inference_node.py

Lines changed: 0 additions & 4 deletions
@@ -18,7 +18,3 @@ class InferenceNodeConfig:
     """Configuration of the module"""
     load_path: str | None = None
     """Path to the module dump. If None, the module will be trained from scratch"""
-
-    def __post_init__(self) -> None:
-        if not isinstance(self.node_type, NodeType):
-            self.node_type = NodeType(self.node_type)

autointent/context/_context.py

Lines changed: 0 additions & 4 deletions
@@ -94,10 +94,6 @@ def dump(self) -> None:
         optimization_results = self.optimization_info.dump_evaluation_results()

         logs_dir = self.logging_config.dirpath
-        if logs_dir is None:
-            msg = "something's wrong with LoggingConfig"
-            raise ValueError(msg)
-
         logs_dir.mkdir(parents=True, exist_ok=True)

         logs_path = logs_dir / "logs.json"

autointent/metrics/_converter.py

Lines changed: 0 additions & 2 deletions
@@ -22,8 +22,6 @@ def transform(
     :param y_pred: Y_pred values
     :return:
     """
-    if isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray):
-        return y_true, y_pred
     y_pred_ = np.array(y_pred)
     y_true_ = np.array(y_true)
     return y_true_, y_pred_

autointent/modules/abc/_base.py

Lines changed: 1 addition & 3 deletions
@@ -40,15 +40,13 @@ def score(self, context: Context, metrics: list[str]) -> dict[str, float]:
         Calculate metric on test set and return metric value.

         :param context: Context to score
-        :param split: Split to score on
+        :param metrics: Metrics to score
         :return: Computed metrics value for the test set or error code of metrics
         """
         if context.data_handler.config.scheme == "ho":
             return self.score_ho(context, metrics)
         if context.data_handler.config.scheme == "cv":
             return self.score_cv(context, metrics)
-        msg = "Something's wrong with validation schemas"
-        raise RuntimeError(msg)

     @abstractmethod
     def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]: ...
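Note: with the fallthrough raise removed, `score` relies on the two scheme checks being exhaustive. A minimal sketch of how mypy's `warn_unreachable` treats such a fallthrough, assuming `scheme` is typed as a `Literal` of exactly the handled values (an assumption, not confirmed by this diff):

from typing import Literal


def score(scheme: Literal["ho", "cv"]) -> str:
    if scheme == "ho":
        return "hold-out scoring"
    if scheme == "cv":
        return "cross-validation scoring"
    # Both literal values are already handled above, so under warn_unreachable
    # mypy reports this statement as unreachable.
    raise RuntimeError("unexpected scheme")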

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 2 additions & 4 deletions
@@ -65,11 +65,12 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
         Fit the node optimizer.

         :param context: Context
+        :param sampler: Sampler to use for optimization
         """
         self._logger.info("starting %s node optimization...", self.node_info.node_type)

         for search_space in deepcopy(self.modules_search_spaces):
-            self._counter = 0
+            self._counter: int = 0
             module_name = search_space.pop("module_name")
             n_trials = None
             if "n_trials" in search_space:
@@ -83,9 +84,6 @@ def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
             elif sampler == "random":
                 sampler_instance = optuna.samplers.RandomSampler(seed=context.seed) # type: ignore[assignment]
                 n_trials = n_trials or 10
-            else:
-                msg = f"Unexpected sampler: {sampler}"
-                raise ValueError(msg)
             study = optuna.create_study(direction="maximize", sampler=sampler_instance)
             optuna.logging.set_verbosity(optuna.logging.WARNING)
             obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)

autointent/schemas/_schemas.py

Lines changed: 0 additions & 1 deletion
@@ -218,7 +218,6 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
             return self.sts_prompt
         if prompt_type == TaskTypeEnum.default:
             return self.default_prompt
-        return None

     use_cache: bool = Field(False, description="Whether to use embeddings caching.")

pyproject.toml

Lines changed: 10 additions & 0 deletions
@@ -166,6 +166,9 @@ skip_empty = true
 python_version = "3.10"
 strict = true
 warn_redundant_casts = true
+# align with mypy 2.0 release
+warn_unreachable = true
+local_partial_types = true
 plugins = [
     "pydantic.mypy",
     "numpy.typing.mypy_plugin",
@@ -193,3 +196,10 @@ module = [
     "wandb",
 ]
 ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = [
+    "autointent._callbacks.*",
+    "autointent.modules.abc.*",
+]
+warn_unreachable = false
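For context on the two new flags: `warn_unreachable` makes mypy report statements it can prove will never execute, and `local_partial_types` stops mypy from inferring a variable's type from a bare `None` assignment in one scope plus an assignment in another scope, so such attributes need explicit annotations. A minimal sketch of the latter, illustrative only and not taken from the repository:

class Example:
    # Under local_partial_types this needs an explicit annotation such as
    # `counter: int | None = None`; mypy otherwise infers the type as `None`.
    counter = None

    def bump(self) -> None:
        self.counter = 1  # and this assignment is then rejected

The trailing `[[tool.mypy.overrides]]` block turns `warn_unreachable` back off for the `autointent._callbacks.*` and `autointent.modules.abc.*` packages only.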
