Skip to content

Commit 3fcf671

Browse files
committed
forgot something
1 parent 20cb72a commit 3fcf671

File tree

1 file changed

+69
-5
lines changed

1 file changed

+69
-5
lines changed

autointent/schemas/node_validation.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,97 @@
11
"""Schemes."""
22

33
import inspect
4+
from abc import ABC, abstractmethod
45
from collections.abc import Iterator
5-
from typing import Annotated, Any, Literal, TypeAlias, Union, get_args, get_origin, get_type_hints
6-
7-
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, RootModel, ValidationError, model_validator
6+
from typing import Annotated, Any, Literal, TypeAlias, TypeVar, Union, get_args, get_origin, get_type_hints
7+
8+
from pydantic import (
9+
BaseModel,
10+
ConfigDict,
11+
Field,
12+
PositiveInt,
13+
RootModel,
14+
ValidationError,
15+
ValidationInfo,
16+
field_validator,
17+
model_validator,
18+
)
819

920
from autointent.custom_types import NodeType
1021
from autointent.modules import BaseModule
1122
from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo
1223

1324

14-
class ParamSpaceInt(BaseModel):
25+
class ParamSpace(BaseModel, ABC):
26+
"""Base class for parameter search space configuration."""
27+
28+
@abstractmethod
29+
def n_possible_values(self) -> int | None:
30+
"""Calculate the number of possible values in the search space.
31+
32+
Returns:
33+
The number of possible values or None if search space is continuous.
34+
"""
35+
36+
37+
ParamSpaceT = TypeVar("ParamSpaceT", bound=ParamSpace)
38+
39+
40+
class ParamSpaceInt(ParamSpace):
1541
"""Integer parameter search space configuration."""
1642

1743
low: int = Field(..., description="Lower boundary of the search space.")
1844
high: int = Field(..., description="Upper boundary of the search space.")
1945
step: int = Field(1, description="Step size for the search space.")
2046
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
2147

48+
def n_possible_values(self) -> int:
49+
"""Calculate the number of possible values in the search space.
50+
51+
Returns:
52+
The number of possible values.
53+
"""
54+
return (self.high - self.low) // self.step + 1
55+
2256

23-
class ParamSpaceFloat(BaseModel):
57+
class ParamSpaceFloat(ParamSpace):
2458
"""Float parameter search space configuration."""
2559

2660
low: float = Field(..., description="Lower boundary of the search space.")
2761
high: float = Field(..., description="Upper boundary of the search space.")
2862
step: float | None = Field(None, description="Step size for the search space (if applicable).")
2963
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
3064

65+
@field_validator("step")
66+
@classmethod
67+
def validate_step_with_log(cls, v: float | None, info: ValidationInfo) -> float | None:
68+
"""Validate that step is not used when log is True.
69+
70+
Args:
71+
v: The step value to validate
72+
info: Validation info containing other field values
73+
74+
Returns:
75+
The validated step value
76+
77+
Raises:
78+
ValueError: If step is provided when log is True
79+
"""
80+
if info.data.get("log", False) and v is not None:
81+
msg = "Step cannot be used when log is True. See optuna docs on `suggest_float` (https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.suggest_float)."
82+
raise ValueError(msg)
83+
return v
84+
85+
def n_possible_values(self) -> int | None:
86+
"""Calculate the number of possible values in the search space.
87+
88+
Returns:
89+
The number of possible values or None if search space is continuous.
90+
"""
91+
if self.step is None:
92+
return None
93+
return int((self.high - self.low) // self.step) + 1
94+
3195

3296
def unwrap_annotated(tp: type) -> type:
3397
"""Unwrap the Annotated type to get the actual type.

0 commit comments

Comments
 (0)