11"""Schemes."""
22
3+ import functools
34import inspect
5+ import operator
46from collections .abc import Iterator
7+ from types import NoneType , UnionType
58from typing import Annotated , Any , Literal , TypeAlias , Union , get_args , get_origin , get_type_hints
69
7- from pydantic import BaseModel , Field , PositiveInt , RootModel
10+ from pydantic import BaseModel , ConfigDict , Field , PositiveInt , RootModel
811
9- from autointent .custom_types import NodeType
12+ from autointent .custom_types import NodeType , ParamSpaceFloat , ParamSpaceInt
1013from autointent .modules .abc import BaseModule
11- from autointent .nodes ._optimization ._node_optimizer import ParamSpaceFloat , ParamSpaceInt
1214from autointent .nodes .info import DecisionNodeInfo , EmbeddingNodeInfo , RegexNodeInfo , ScoringNodeInfo
1315
1416
@@ -26,17 +28,26 @@ def type_matches(target: type, tp: type) -> bool:
2628 """
2729 Recursively check if the target type is present in the given type.
2830
29- This function handles union types by unwrapping Annotated types where necessary.
31+ This function handles union types and generic types (e.g. dict[...] by checking
32+ their origin) after unwrapping Annotated types.
3033
31- :param target: Target type
32- :param tp: Given type
33- :return: If the target type is present in the given type
34+ :param target: Target type to check for.
35+ :param tp: Given type which may be a union, generic, or annotated type.
36+ :return: True if the target type is present in the given type.
3437 """
3538 origin = get_origin (tp )
36-
37- if origin is Union : # float | list[float]
39+ if origin is Union :
3840 return any (type_matches (target , arg ) for arg in get_args (tp ))
39- return unwrap_annotated (tp ) is target
41+
42+ # Unwrap Annotated types, if any.
43+ unwrapped = unwrap_annotated (tp )
44+
45+ # If the unwrapped type is a generic type, check its origin.
46+ generic_origin = get_origin (unwrapped )
47+ if generic_origin is not None :
48+ return generic_origin is target
49+
50+ return unwrapped is target
4051
4152
4253def get_optuna_class (param_type : type ) -> type [ParamSpaceInt | ParamSpaceFloat ] | None :
@@ -56,7 +67,12 @@ def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat]
5667 return None
5768
5869
59- def generate_models_and_union_type_for_classes (
70+ def to_union (types : list [type ]) -> type :
71+ """Convert a tuple of types into a union type."""
72+ return functools .reduce (operator .or_ , types )
73+
74+
75+ def generate_models_and_union_type_for_classes ( # noqa: PLR0912, C901
6076 classes : list [type [BaseModule ]],
6177) -> type [BaseModel ]:
6278 """Dynamically generates Pydantic models for class constructors and creates a union type."""
@@ -70,6 +86,7 @@ def generate_models_and_union_type_for_classes(
7086 fields = {
7187 "module_name" : (Literal [cls .name ], Field (...)),
7288 "n_trials" : (PositiveInt | None , Field (None , description = "Number of trials" )),
89+ "model_config" : (ConfigDict , ConfigDict (extra = "forbid" )),
7390 }
7491
7592 for param_name , param in init_signature .parameters .items ():
@@ -78,11 +95,33 @@ def generate_models_and_union_type_for_classes(
7895
7996 param_type : TypeAlias = type_hints .get (param_name , Any ) # type: ignore[valid-type] # noqa: PYI042
8097 field = Field (default = [param .default ]) if param .default is not inspect .Parameter .empty else Field (...)
81- search_type = get_optuna_class (param_type )
82- if search_type is None :
83- fields [param_name ] = (list [param_type ], field )
98+ if not type_matches (dict , param_type ):
99+ search_type = get_optuna_class (param_type )
100+ if search_type is None :
101+ fields [param_name ] = (list [param_type ], field )
102+ else :
103+ fields [param_name ] = (list [param_type ] | search_type , field )
84104 else :
85- fields [param_name ] = (list [param_type ] | search_type , field )
105+ dict_key_type , dict_values_types = get_args (param_type )
106+ is_optional = False
107+ if dict_values_types is NoneType : # if dict is optional
108+ is_optional = True
109+ dict_key_type , dict_values_types = get_args (dict_key_type )
110+ if get_origin (dict_values_types ) is UnionType :
111+ filed_types : list [type [Any ]] = []
112+ for value in get_args (dict_values_types ):
113+ filed_types .append (list [value ]) # type: ignore[valid-type]
114+ search_type = get_optuna_class (value )
115+ if search_type is not None :
116+ filed_types .append (search_type )
117+ filed_type = to_union (filed_types )
118+ else :
119+ filed_type = dict_values_types
120+
121+ if is_optional :
122+ fields [param_name ] = (dict [dict_key_type , filed_type ] | None , field ) # type: ignore[valid-type]
123+ else :
124+ fields [param_name ] = (dict [dict_key_type , filed_type ], field ) # type: ignore[valid-type]
86125
87126 model_name = f"{ cls .__name__ } InitModel"
88127 models [cls .__name__ ] = type (
0 commit comments