Skip to content

Commit 63206fb

Browse files
committed
simplify fields
1 parent 0e137b0 commit 63206fb

File tree

4 files changed

+7
-11
lines changed

4 files changed

+7
-11
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
from autointent.custom_types import ListOfGenericLabels, NodeType
1414
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
16+
from autointent.nodes.schemes import OptimizerConfig
1617
from autointent.utils import load_default_search_space, load_search_space
1718

1819
from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
19-
from ..nodes.schemes import OptimizerConfig
2020

2121
if TYPE_CHECKING:
2222
from autointent.modules.abc import DecisionModule, ScoringModule
@@ -77,8 +77,8 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
7777
"""
7878
if isinstance(search_space, Path | str):
7979
search_space = load_search_space(search_space)
80-
search_space = OptimizerConfig(search_space).model_dump()
81-
nodes = [NodeOptimizer(**node) for node in search_space]
80+
validated_search_space = OptimizerConfig(search_space).model_dump() # type: ignore[arg-type]
81+
nodes = [NodeOptimizer(**node) for node in validated_search_space]
8282
return cls(nodes=nodes, seed=seed)
8383

8484
@classmethod

autointent/modules/decision/_threshold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
8383
:param thresh: Threshold for the scores, shape (n_classes,) or float
8484
"""
85-
self.thresh = thresh
85+
self.thresh = thresh if isinstance(thresh, float) else np.array(thresh)
8686

8787
@classmethod
8888
def from_context(cls, context: Context, thresh: float | list[float] = 0.5) -> "ThresholdDecision":

autointent/nodes/schemes.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,9 @@ def generate_models_and_union_type_for_classes(
2929
continue
3030

3131
param_type: TypeAlias = type_hints.get(param_name, Any) # type: ignore[valid-type] # noqa: PYI042
32-
default = param.default if param.default is not inspect.Parameter.empty else None
32+
field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...)
3333

34-
# Ensure fields with defaults have the correct type
35-
if default is None:
36-
param_type = param_type | None
37-
38-
fields[param_name] = (list[param_type], Field(default=default)) # type: ignore[assignment]
34+
fields[param_name] = (list[param_type], field) # type: ignore[assignment]
3935

4036
model_name = f"{cls.__name__}InitModel"
4137
models[cls.__name__] = type(

tests/configs/test_combined_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_valid_optimizer_config(valid_optimizer_config):
5151
"task_type",
5252
["multiclass", "multilabel"],
5353
)
54-
def test_inference_config(task_type):
54+
def test_optimizer_config(task_type):
5555
search_space = get_search_space(task_type)
5656
config = OptimizerConfig(search_space)
5757
assert config

0 commit comments

Comments
 (0)