|
| 1 | +"""Schemes.""" |
| 2 | + |
| 3 | +import inspect |
| 4 | +from collections.abc import Iterator |
| 5 | +from typing import Annotated, Any, Literal, TypeAlias, Union, get_args, get_origin, get_type_hints |
| 6 | + |
| 7 | +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, RootModel, ValidationError, model_validator |
| 8 | + |
| 9 | +from autointent.custom_types import NodeType |
| 10 | +from autointent.modules import BaseModule |
| 11 | +from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo |
| 12 | + |
| 13 | + |
def unwrap_annotated(tp: type) -> type:
    """
    Strip an ``Annotated[...]`` wrapper from a type, if there is one.

    :param tp: Type that may be wrapped in ``Annotated``
    :return: The underlying type (first ``Annotated`` argument), or ``tp`` itself when it is not annotated
    """
    if get_origin(tp) is not Annotated:
        return tp
    # get_args yields (underlying_type, *metadata); only the underlying type is of interest.
    underlying, *_metadata = get_args(tp)
    return underlying
| 22 | + |
| 23 | + |
| 24 | +def type_matches(target: type, tp: type) -> bool: |
| 25 | + """ |
| 26 | + Recursively check if the target type is present in the given type. |
| 27 | +
|
| 28 | + This function handles union types by unwrapping Annotated types where necessary. |
| 29 | +
|
| 30 | + :param target: Target type |
| 31 | + :param tp: Given type |
| 32 | + :return: If the target type is present in the given type |
| 33 | + """ |
| 34 | + origin = get_origin(tp) |
| 35 | + |
| 36 | + if origin is Union: # float | list[float] |
| 37 | + return any(type_matches(target, arg) for arg in get_args(tp)) |
| 38 | + return unwrap_annotated(tp) is target |
| 39 | + |
| 40 | + |
class ParamSpaceInt(BaseModel):
    """Search space description for an integer-valued parameter."""

    # ``low`` and ``high`` carry no default, so both must be supplied.
    low: int = Field(description="Lower boundary of the search space.")
    high: int = Field(description="Upper boundary of the search space.")
    step: int = Field(default=1, description="Step size for the search space.")
    log: bool = Field(default=False, description="Indicates whether to use a logarithmic scale.")
| 48 | + |
| 49 | + |
class ParamSpaceFloat(BaseModel):
    """Search space description for a float-valued parameter."""

    # ``low`` and ``high`` carry no default, so both must be supplied.
    low: float = Field(description="Lower boundary of the search space.")
    high: float = Field(description="Upper boundary of the search space.")
    # Unlike the integer space, the step is optional and defaults to None.
    step: float | None = Field(default=None, description="Step size for the search space (if applicable).")
    log: bool = Field(default=False, description="Indicates whether to use a logarithmic scale.")
| 57 | + |
| 58 | + |
def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat] | None:
    """
    Pick the search space model that fits a parameter type.

    The decision is delegated to ``type_matches``, which is asked whether ``int`` and then
    ``float`` is present in ``param_type``.

    :param param_type: Parameter type (plain, annotated, or a union)
    :return: ParamSpaceInt if the type matches int, ParamSpaceFloat if it matches float, else None.
    """
    # int is probed before float, so a type admitting both maps to ParamSpaceInt.
    candidates = (
        (int, ParamSpaceInt),
        (float, ParamSpaceFloat),
    )
    for scalar_type, space_cls in candidates:
        if type_matches(scalar_type, param_type):
            return space_cls
    return None
| 74 | + |
| 75 | + |
def generate_models_and_union_type_for_classes(
    classes: list[type[BaseModule]],
) -> tuple[type[BaseModel], dict[str, type[BaseModel]]]:
    """
    Dynamically generate a Pydantic model per module class and a union of all generated models.

    For every class the signature of its ``from_context`` constructor is inspected and turned into
    a model named ``<ClassName>InitModel``:

    * ``module_name`` is a required ``Literal`` pinned to ``cls.name``;
    * ``n_trials`` is an optional positive integer;
    * every remaining constructor parameter becomes a field holding a *list* of candidate values;
      when the parameter type includes ``int`` or ``float`` the field alternatively accepts a
      ``ParamSpaceInt`` / ``ParamSpaceFloat``;
    * a parameter default ``d`` becomes the field default ``[d]``; parameters without a default are required;
    * extra keys are allowed only when ``from_context`` accepts ``**kwargs``, otherwise they are forbidden.

    :param classes: Module classes exposing a ``name`` attribute and a ``from_context`` constructor
    :return: Tuple of (union of the generated models, mapping from ``cls.name`` to its generated model)
    """
    # Parameters that must never become config fields: the bound argument, the context,
    # and the variadic catch-alls (``*args`` cannot be expressed as a named field at all).
    skipped_names = ("self", "cls", "context")
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    models: dict[str, type[BaseModel]] = {}

    for cls in classes:
        init_signature = inspect.signature(cls.from_context)
        globalns = getattr(cls.from_context, "__globals__", {})
        type_hints = get_type_hints(cls.from_context, globalns, None, include_extras=True)  # Resolve forward refs

        has_kwarg_arg = any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in init_signature.parameters.values()
        )

        fields: dict[str, tuple[Any, Any]] = {
            "module_name": (Literal[cls.name], Field(...)),
            "n_trials": (PositiveInt | None, Field(None, description="Number of trials")),
            "model_config": (ConfigDict, ConfigDict(extra="allow" if has_kwarg_arg else "forbid")),
        }

        for param_name, param in init_signature.parameters.items():
            if param_name in skipped_names or param.kind in variadic_kinds:
                continue

            # Unannotated parameters fall back to Any.
            param_type: TypeAlias = type_hints.get(param_name, Any)  # type: ignore[valid-type]  # noqa: PYI042
            field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...)
            search_type = get_optuna_class(param_type)
            if search_type is None:
                fields[param_name] = (list[param_type], field)
            else:
                # Numeric parameters may be given either as explicit candidates or as a range.
                fields[param_name] = (list[param_type] | search_type, field)

        model_name = f"{cls.__name__}InitModel"
        # Build the model class by hand: annotations carry the types, the namespace carries the defaults.
        models[cls.name] = type(
            model_name,
            (BaseModel,),
            {
                "__annotations__": {k: v[0] for k, v in fields.items()},
                **{k: v[1] for k, v in fields.items()},
            },
        )

    return Union[tuple(models.values())], models
| 122 | + |
| 123 | + |
# Generated models for the Decision modules: their union, and the same models keyed by module name.
DecisionSearchSpaceType, DecisionNodesBaseModels = generate_models_and_union_type_for_classes(  # type: ignore[valid-type]
    list(DecisionNodeInfo.modules_available.values())
)
# Literal over the keys of ``DecisionNodeInfo.metrics_available``; declared as a TypeAlias
# like the Embedding/Scoring/Regex metric aliases.
DecisionMetrics: TypeAlias = Literal[tuple(DecisionNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]
| 128 | + |
| 129 | + |
class DecisionNodeValidator(BaseModel):
    """Schema of a Decision node entry: target metric, optional extra metrics and module search space."""

    node_type: NodeType = Field(default=NodeType.decision)
    target_metric: DecisionMetrics
    metrics: list[DecisionMetrics] | None = Field(default=None)
    search_space: list[DecisionSearchSpaceType]
| 137 | + |
| 138 | + |
# Generated models for the Embedding modules: their union, and the same models keyed by module name.
EmbeddingSearchSpaceType, EmbeddingBaseModels = generate_models_and_union_type_for_classes(
    [*EmbeddingNodeInfo.modules_available.values()]
)
# Literal over the keys of ``EmbeddingNodeInfo.metrics_available``.
EmbeddingMetrics: TypeAlias = Literal[tuple(EmbeddingNodeInfo.metrics_available)]  # type: ignore[valid-type]
| 143 | + |
| 144 | + |
class EmbeddingNodeValidator(BaseModel):
    """Schema of an Embedding node entry: target metric, optional extra metrics and module search space."""

    node_type: NodeType = Field(default=NodeType.embedding)
    target_metric: EmbeddingMetrics
    metrics: list[EmbeddingMetrics] | None = Field(default=None)
    search_space: list[EmbeddingSearchSpaceType]
| 152 | + |
| 153 | + |
# Generated models for the Scoring modules: their union, and the same models keyed by module name.
ScoringSearchSpaceType, ScoringNodesBaseModels = generate_models_and_union_type_for_classes(
    [*ScoringNodeInfo.modules_available.values()]
)
# Literal over the keys of ``ScoringNodeInfo.metrics_available``.
ScoringMetrics: TypeAlias = Literal[tuple(ScoringNodeInfo.metrics_available)]  # type: ignore[valid-type]
| 158 | + |
| 159 | + |
class ScoringNodeValidator(BaseModel):
    """Schema of a Scoring node entry: target metric, optional extra metrics and module search space."""

    node_type: NodeType = Field(default=NodeType.scoring)
    target_metric: ScoringMetrics
    metrics: list[ScoringMetrics] | None = Field(default=None)
    search_space: list[ScoringSearchSpaceType]
| 167 | + |
| 168 | + |
# Generated models for the Regex modules: their union, and the same models keyed by module name.
RegexpSearchSpaceType, RegexNodesBaseModels = generate_models_and_union_type_for_classes(
    [*RegexNodeInfo.modules_available.values()]
)
# Literal over the keys of ``RegexNodeInfo.metrics_available``.
RegexpMetrics: TypeAlias = Literal[tuple(RegexNodeInfo.metrics_available)]  # type: ignore[valid-type]
| 173 | + |
| 174 | + |
class RegexNodeValidator(BaseModel):
    """Schema of a Regexp node entry: target metric, optional extra metrics and module search space."""

    node_type: NodeType = Field(default=NodeType.regex)
    target_metric: RegexpMetrics
    metrics: list[RegexpMetrics] | None = Field(default=None)
    search_space: list[RegexpSearchSpaceType]
| 182 | + |
| 183 | + |
# Any one of the per-node validators defined in this module.
SearchSpaceTypes: TypeAlias = (
    EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator
)
| 185 | + |
| 186 | + |
class SearchSpaceConfig(
    RootModel[
        list[DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType]
    ]
):
    """Search space configuration: a flat list of module configurations of any node type."""

    def __iter__(
        self,
    ) -> Iterator[DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType]:
        """Iterate over the root."""
        return iter(self.root)

    def __getitem__(
        self, item: int
    ) -> DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType:
        """
        To get item directly from the root.

        :param item: Index

        :return: Item
        """
        return self.root[item]

    @model_validator(mode="before")
    @classmethod
    def validate_nodes(cls, data: list[Any]) -> list[Any]:
        """
        Pre-validate every raw entry against the model generated for its ``module_name``.

        Entries that are already model instances are skipped. For dictionaries the model is looked
        up by ``module_name`` in the Decision, Embedding, Scoring and Regex tables (in that order)
        and instantiated; every problem found is collected so that one error reports all of them.

        :param data: Raw root value
        :return: ``data`` unchanged when every entry passed
        :raises TypeError: If ``data`` is not a list, an entry is neither a model nor a dictionary,
            or at least one entry is missing ``module_name``, names an unknown module, or fails validation
        """
        # NOTE(review): pydantic turns ValueError/AssertionError raised by validators into
        # ValidationError; TypeError is kept here so callers keep seeing the same exception type —
        # confirm that propagating it unwrapped is intended.
        if not isinstance(data, list):
            raise TypeError("The root must be a list of search space configurations.")

        registries = (DecisionNodesBaseModels, EmbeddingBaseModels, ScoringNodesBaseModels, RegexNodesBaseModels)
        errors: list[str] = []
        for i, item in enumerate(data):
            if isinstance(item, BaseModel):
                continue
            if not isinstance(item, dict):
                raise TypeError("Each search space configuration must be a dictionary.")
            node_name = item.get("module_name")
            if node_name is None:
                errors.append(f"Search space configuration at index {i} is missing 'module_name'.\n")
                continue

            node_class = next((registry[node_name] for registry in registries if node_name in registry), None)
            if node_class is None:
                # These entries are module configs and carry no 'node_type' key, so report the
                # offending module name and keep checking the remaining entries.
                errors.append(f"Unknown module name '{node_name}' at index {i}.\n")
                continue

            try:
                node_class(**item)
            except ValidationError as e:
                errors.append(f"Search space configuration at index {i} {node_name} is invalid: {e}\n")

        if errors:
            raise TypeError("".join(errors))
        return data
| 243 | + |
| 244 | + |
class OptimizationSearchSpaceConfig(RootModel[list[SearchSpaceTypes]]):
    """Optimizer configuration: a list of per-node search space configurations."""

    def __iter__(
        self,
    ) -> Iterator[SearchSpaceTypes]:
        """Iterate over the root."""
        return iter(self.root)

    def __getitem__(self, item: int) -> SearchSpaceTypes:
        """
        To get item directly from the root.

        :param item: Index

        :return: Item
        """
        return self.root[item]

    @model_validator(mode="before")
    @classmethod
    def validate_nodes(cls, data: list[Any]) -> list[Any]:
        """
        Pre-validate the modules of every raw node entry against their generated models.

        Node entries that are already model instances are skipped. For dictionaries, ``node_type``
        selects the table of generated models and each element of ``search_space`` is instantiated
        with the model registered under its ``module_name``. Problems with individual modules are
        collected so that one error reports all of them.

        :param data: Raw root value
        :return: ``data`` unchanged when every entry passed
        :raises ValueError: If ``data`` is not a list, a node entry is malformed (not a dictionary,
            no ``node_type``, ``search_space`` not a list), or at least one module entry is invalid
        """
        if not isinstance(data, list):
            raise ValueError("The root must be a list of search space configurations.")

        registries = (
            (NodeType.decision, DecisionNodesBaseModels),
            (NodeType.embedding, EmbeddingBaseModels),
            (NodeType.scoring, ScoringNodesBaseModels),
            (NodeType.regex, RegexNodesBaseModels),
        )
        errors: list[str] = []
        for i, item in enumerate(data):
            if isinstance(item, BaseModel):
                continue
            if not isinstance(item, dict):
                raise ValueError("Each search space configuration must be a dictionary.")
            if "node_type" not in item:
                raise ValueError("Each search space configuration must have a 'node_type' key.")
            if not isinstance(item.get("search_space"), list):
                raise ValueError("Each search space configuration must have a 'search_space' key of type list.")

            node_type = item["node_type"]
            # Compared by equality against the enum values, exactly as a raw config would spell them.
            models = next((table for kind, table in registries if node_type == kind.value), None)
            if models is None:
                errors.append(f"Unknown node type '{node_type}' at index {i}.\n")
                continue

            for search_space in item["search_space"]:
                if isinstance(search_space, BaseModel):
                    # Already a model instance; nothing to pre-validate here.
                    continue
                if not isinstance(search_space, dict):
                    errors.append(f"Search space configuration at index {i} contains a module that is not a dictionary.\n")
                    continue
                node_name = search_space.get("module_name")
                if node_name is None:
                    errors.append(f"Search space configuration at index {i} is missing 'module_name'.\n")
                    continue
                node_class = models.get(node_name)
                if node_class is None:
                    # Without this guard the lookup result None would be called below.
                    errors.append(f"Unknown module name '{node_name}' for node type '{node_type}' at index {i}.\n")
                    continue

                try:
                    node_class(**search_space)
                except ValidationError as e:
                    errors.append(f"Search space configuration at index {i} {node_name} is invalid: {e}\n")

        if errors:
            raise ValueError("".join(errors))
        return data
| 305 | + |
| 306 | + |
| 307 | + |
| 308 | + |
0 commit comments