Skip to content

Commit fad689e

Browse files
committed
add node validators
1 parent 29de65d commit fad689e

File tree

2 files changed

+311
-21
lines changed

2 files changed

+311
-21
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,13 @@
1111
import optuna
1212
import torch
1313
from optuna.trial import Trial
14-
from pydantic import BaseModel, Field
1514
from typing_extensions import assert_never
1615

1716
from autointent import Dataset
1817
from autointent.context import Context
1918
from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode
2019
from autointent.nodes.info import NODES_INFO
21-
22-
23-
class ParamSpaceInt(BaseModel):
24-
"""Integer parameter search space configuration."""
25-
26-
low: int = Field(..., description="Lower boundary of the search space.")
27-
high: int = Field(..., description="Upper boundary of the search space.")
28-
step: int = Field(1, description="Step size for the search space.")
29-
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
30-
31-
32-
class ParamSpaceFloat(BaseModel):
33-
"""Float parameter search space configuration."""
34-
35-
low: float = Field(..., description="Lower boundary of the search space.")
36-
high: float = Field(..., description="Upper boundary of the search space.")
37-
step: float | None = Field(None, description="Step size for the search space (if applicable).")
38-
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
39-
20+
from autointent.schemas.node_validation import ParamSpaceFloat, ParamSpaceInt, SearchSpaceConfig
4021

4122
logger = logging.getLogger(__name__)
4223

@@ -270,7 +251,8 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
270251

271252
def validate_search_space(self, search_space: list[dict[str, Any]]) -> None:
272253
"""Check if search space is configured correctly."""
273-
for module_search_space in search_space:
254+
validated_search_space = SearchSpaceConfig(search_space).model_dump()
255+
for module_search_space in validated_search_space:
274256
module_search_space_no_optuna, module_name = self._reformat_search_space(deepcopy(module_search_space))
275257

276258
for params_combination in it.product(*module_search_space_no_optuna.values()):
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
"""Schemes."""
2+
3+
import inspect
4+
from collections.abc import Iterator
5+
from typing import Annotated, Any, Literal, TypeAlias, Union, get_args, get_origin, get_type_hints
6+
7+
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, RootModel, ValidationError, model_validator
8+
9+
from autointent.custom_types import NodeType
10+
from autointent.modules import BaseModule
11+
from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo
12+
13+
14+
def unwrap_annotated(tp: type) -> type:
15+
"""
16+
Unwrap the Annotated type to get the actual type.
17+
18+
:param tp: Type to unwrap
19+
:return: Unwrapped type
20+
"""
21+
return get_args(tp)[0] if get_origin(tp) is Annotated else tp
22+
23+
24+
def type_matches(target: type, tp: type) -> bool:
25+
"""
26+
Recursively check if the target type is present in the given type.
27+
28+
This function handles union types by unwrapping Annotated types where necessary.
29+
30+
:param target: Target type
31+
:param tp: Given type
32+
:return: If the target type is present in the given type
33+
"""
34+
origin = get_origin(tp)
35+
36+
if origin is Union: # float | list[float]
37+
return any(type_matches(target, arg) for arg in get_args(tp))
38+
return unwrap_annotated(tp) is target
39+
40+
41+
class ParamSpaceInt(BaseModel):
42+
"""Integer parameter search space configuration."""
43+
44+
low: int = Field(..., description="Lower boundary of the search space.")
45+
high: int = Field(..., description="Upper boundary of the search space.")
46+
step: int = Field(1, description="Step size for the search space.")
47+
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
48+
49+
50+
class ParamSpaceFloat(BaseModel):
51+
"""Float parameter search space configuration."""
52+
53+
low: float = Field(..., description="Lower boundary of the search space.")
54+
high: float = Field(..., description="Upper boundary of the search space.")
55+
step: float | None = Field(None, description="Step size for the search space (if applicable).")
56+
log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
57+
58+
59+
def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat] | None:
60+
"""
61+
Get the Optuna class for the given parameter type.
62+
63+
If the (possibly annotated or union) type includes int or float, this function
64+
returns the corresponding search space class.
65+
66+
:param param_type: Parameter type (could be a union, annotated type, or container)
67+
:return: ParamSpaceInt if the type matches int, ParamSpaceFloat if it matches float, else None.
68+
"""
69+
if type_matches(int, param_type):
70+
return ParamSpaceInt
71+
if type_matches(float, param_type):
72+
return ParamSpaceFloat
73+
return None
74+
75+
76+
def generate_models_and_union_type_for_classes(
77+
classes: list[type[BaseModule]],
78+
) -> tuple[type[BaseModel], dict[str, type[BaseModel]]]:
79+
"""Dynamically generates Pydantic models for class constructors and creates a union type."""
80+
models: dict[str, type[BaseModel]] = {}
81+
82+
for cls in classes:
83+
init_signature = inspect.signature(cls.from_context)
84+
globalns = getattr(cls.from_context, "__globals__", {})
85+
type_hints = get_type_hints(cls.from_context, globalns, None, include_extras=True) # Resolve forward refs
86+
87+
has_kwarg_arg = any(
88+
param.kind == inspect.Parameter.VAR_KEYWORD
89+
for param in init_signature.parameters.values()
90+
)
91+
92+
fields = {
93+
"module_name": (Literal[cls.name], Field(...)),
94+
"n_trials": (PositiveInt | None, Field(None, description="Number of trials")),
95+
"model_config": (ConfigDict, ConfigDict(extra="allow" if has_kwarg_arg else "forbid")),
96+
}
97+
98+
for param_name, param in init_signature.parameters.items():
99+
# skip self, cls, context, and **kwargs
100+
if param_name in ("self", "cls", "context") or param.kind == inspect.Parameter.VAR_KEYWORD:
101+
continue
102+
103+
param_type: TypeAlias = type_hints.get(param_name, Any) # type: ignore[valid-type] # noqa: PYI042
104+
field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...)
105+
search_type = get_optuna_class(param_type)
106+
if search_type is None:
107+
fields[param_name] = (list[param_type], field)
108+
else:
109+
fields[param_name] = (list[param_type] | search_type, field)
110+
111+
model_name = f"{cls.__name__}InitModel"
112+
models[cls.name] = type(
113+
model_name,
114+
(BaseModel,),
115+
{
116+
"__annotations__": {k: v[0] for k, v in fields.items()},
117+
**{k: v[1] for k, v in fields.items()},
118+
},
119+
)
120+
121+
return Union[tuple(models.values())], models
122+
123+
124+
DecisionSearchSpaceType, DecisionNodesBaseModels = generate_models_and_union_type_for_classes( # type: ignore[valid-type]
125+
list(DecisionNodeInfo.modules_available.values())
126+
)
127+
DecisionMetrics = Literal[tuple(DecisionNodeInfo.metrics_available.keys())] # type: ignore[valid-type]
128+
129+
130+
class DecisionNodeValidator(BaseModel):
131+
"""Search space configuration for the Decision node."""
132+
133+
node_type: NodeType = NodeType.decision
134+
target_metric: DecisionMetrics
135+
metrics: list[DecisionMetrics] | None = None
136+
search_space: list[DecisionSearchSpaceType]
137+
138+
139+
EmbeddingSearchSpaceType, EmbeddingBaseModels = generate_models_and_union_type_for_classes(
140+
list(EmbeddingNodeInfo.modules_available.values())
141+
)
142+
EmbeddingMetrics: TypeAlias = Literal[tuple(EmbeddingNodeInfo.metrics_available.keys())] # type: ignore[valid-type]
143+
144+
145+
class EmbeddingNodeValidator(BaseModel):
146+
"""Search space configuration for the Embedding node."""
147+
148+
node_type: NodeType = NodeType.embedding
149+
target_metric: EmbeddingMetrics
150+
metrics: list[EmbeddingMetrics] | None = None
151+
search_space: list[EmbeddingSearchSpaceType]
152+
153+
154+
ScoringSearchSpaceType, ScoringNodesBaseModels = generate_models_and_union_type_for_classes(
155+
list(ScoringNodeInfo.modules_available.values())
156+
)
157+
ScoringMetrics: TypeAlias = Literal[tuple(ScoringNodeInfo.metrics_available.keys())] # type: ignore[valid-type]
158+
159+
160+
class ScoringNodeValidator(BaseModel):
161+
"""Search space configuration for the Scoring node."""
162+
163+
node_type: NodeType = NodeType.scoring
164+
target_metric: ScoringMetrics
165+
metrics: list[ScoringMetrics] | None = None
166+
search_space: list[ScoringSearchSpaceType]
167+
168+
169+
RegexpSearchSpaceType, RegexNodesBaseModels = generate_models_and_union_type_for_classes(
170+
list(RegexNodeInfo.modules_available.values())
171+
)
172+
RegexpMetrics: TypeAlias = Literal[tuple(RegexNodeInfo.metrics_available.keys())] # type: ignore[valid-type]
173+
174+
175+
class RegexNodeValidator(BaseModel):
176+
"""Search space configuration for the Regexp node."""
177+
178+
node_type: NodeType = NodeType.regex
179+
target_metric: RegexpMetrics
180+
metrics: list[RegexpMetrics] | None = None
181+
search_space: list[RegexpSearchSpaceType]
182+
183+
184+
SearchSpaceTypes: TypeAlias = EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator
185+
186+
187+
class SearchSpaceConfig(RootModel[list[DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType]]):
188+
"""Search space configuration."""
189+
190+
def __iter__(
191+
self,
192+
) -> Iterator[DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType]:
193+
"""Iterate over the root."""
194+
return iter(self.root)
195+
196+
def __getitem__(
197+
self, item: int
198+
) -> DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType:
199+
"""
200+
To get item directly from the root.
201+
202+
:param item: Index
203+
204+
:return: Item
205+
"""
206+
return self.root[item]
207+
208+
@model_validator(mode='before')
209+
@classmethod
210+
def validate_nodes(cls, data: list[Any]) -> list[Any]:
211+
if not isinstance(data, list):
212+
raise TypeError("The root must be a list of search space configurations.")
213+
error_message = ""
214+
for i, item in enumerate(data):
215+
if isinstance(item, BaseModel):
216+
continue
217+
if not isinstance(item, dict):
218+
raise TypeError("Each search space configuration must be a dictionary.")
219+
node_name = item.get("module_name")
220+
if node_name is None:
221+
error_message += f"Search space configuration at index {i} is missing 'module_name'.\n"
222+
continue
223+
224+
if node_name in DecisionNodesBaseModels:
225+
node_class = DecisionNodesBaseModels[node_name]
226+
elif node_name in EmbeddingBaseModels:
227+
node_class = EmbeddingBaseModels[node_name]
228+
elif node_name in ScoringNodesBaseModels:
229+
node_class = ScoringNodesBaseModels[node_name]
230+
elif node_name in RegexNodesBaseModels:
231+
node_class = RegexNodesBaseModels[node_name]
232+
else:
233+
error_message += f"Unknown node type '{item['node_type']}' at index {i}.\n"
234+
break
235+
try:
236+
node_class(**item)
237+
except ValidationError as e:
238+
error_message += f"Search space configuration at index {i} {node_name} is invalid: {e}\n"
239+
continue
240+
if len(error_message) > 0:
241+
raise TypeError(error_message)
242+
return data
243+
244+
245+
class OptimizationSearchSpaceConfig(RootModel[list[SearchSpaceTypes]]):
246+
"""Optimizer configuration."""
247+
248+
def __iter__(
249+
self,
250+
) -> Iterator[SearchSpaceTypes]:
251+
"""Iterate over the root."""
252+
return iter(self.root)
253+
254+
def __getitem__(self, item: int) -> SearchSpaceTypes:
255+
"""
256+
To get item directly from the root.
257+
258+
:param item: Index
259+
260+
:return: Item
261+
"""
262+
return self.root[item]
263+
264+
265+
@model_validator(mode='before')
266+
@classmethod
267+
def validate_nodes(cls, data: list[Any]) -> list[Any]:
268+
if not isinstance(data, list):
269+
raise ValueError("The root must be a list of search space configurations.")
270+
error_message = ""
271+
for i, item in enumerate(data):
272+
if isinstance(item, BaseModel):
273+
continue
274+
if not isinstance(item, dict):
275+
raise ValueError("Each search space configuration must be a dictionary.")
276+
if "node_type" not in item:
277+
raise ValueError("Each search space configuration must have a 'node_type' key.")
278+
if not isinstance(item.get("search_space"), list):
279+
raise ValueError("Each search space configuration must have a 'search_space' key of type list.")
280+
for search_space in item["search_space"]:
281+
node_name = search_space.get("module_name")
282+
if node_name is None:
283+
error_message += f"Search space configuration at index {i} is missing 'module_name'.\n"
284+
continue
285+
if item["node_type"] == NodeType.decision.value:
286+
node_class = DecisionNodesBaseModels.get(node_name)
287+
elif item["node_type"] == NodeType.embedding.value:
288+
node_class = EmbeddingBaseModels.get(node_name)
289+
elif item["node_type"] == NodeType.scoring.value:
290+
node_class = ScoringNodesBaseModels.get(node_name)
291+
elif item["node_type"] == NodeType.regex.value:
292+
node_class = RegexNodesBaseModels.get(node_name)
293+
else:
294+
error_message += f"Unknown node type '{item['node_type']}' at index {i}.\n"
295+
break
296+
297+
try:
298+
node_class(**search_space)
299+
except ValidationError as e:
300+
error_message += f"Search space configuration at index {i} {node_name} is invalid: {e}\n"
301+
continue
302+
if len(error_message) > 0:
303+
raise ValueError(error_message)
304+
return data
305+
306+
307+
308+

0 commit comments

Comments
 (0)