22
33import enum
44from enum import Enum
5- from typing import TYPE_CHECKING , Any
5+ from typing import TYPE_CHECKING , Any , Self
66
7- from pydantic import field_validator , model_validator
7+ from pydantic import Field , ValidationInfo , field_validator , model_validator
88
99from infrahub import config
10+ from infrahub .core .constants .schema import UpdateSupport
1011from infrahub .core .enums import generate_python_enum
1112from infrahub .core .query .attribute import default_attribute_query_filter
1213from infrahub .types import ATTRIBUTE_KIND_LABELS , ATTRIBUTE_TYPES
1314
15+ from .attribute_parameters import AttributeParameters , TextAttributeParameters , get_attribute_parameters_class_for_kind
1416from .generated .attribute_schema import GeneratedAttributeSchema
1517
1618if TYPE_CHECKING :
2123 from infrahub .database import InfrahubDatabase
2224
2325
26+ def get_attribute_schema_class_for_kind (kind : str ) -> type [AttributeSchema ]:
27+ attribute_schema_class_by_kind : dict [str , type [AttributeSchema ]] = {
28+ "Text" : TextAttributeSchema ,
29+ "TextArea" : TextAttributeSchema ,
30+ }
31+ return attribute_schema_class_by_kind .get (kind , AttributeSchema )
32+
33+
2434class AttributeSchema (GeneratedAttributeSchema ):
2535 _sort_by : list [str ] = ["name" ]
2636 _enum_class : type [enum .Enum ] | None = None
@@ -53,16 +63,36 @@ def kind_options(cls, v: str) -> str:
5363
5464 @model_validator (mode = "before" )
5565 @classmethod
56- def validate_dropdown_choices (cls , values : dict [ str , Any ] ) -> dict [ str , Any ] :
66+ def validate_dropdown_choices (cls , values : Any ) -> Any :
5767 """Validate that choices are defined for a dropdown but not for other kinds."""
58- if values .get ("kind" ) != "Dropdown" and values .get ("choices" ):
59- raise ValueError (f"Can only specify 'choices' for kind=Dropdown: { values ['kind' ]} " )
60-
61- if values .get ("kind" ) == "Dropdown" and not values .get ("choices" ):
68+ if isinstance (values , dict ):
69+ kind = values .get ("kind" )
70+ choices = values .get ("choices" )
71+ elif isinstance (values , AttributeSchema ):
72+ kind = values .kind
73+ choices = values .choices
74+ else :
75+ return values
76+ if kind != "Dropdown" and choices :
77+ raise ValueError (f"Can only specify 'choices' for kind=Dropdown: { kind } " )
78+
79+ if kind == "Dropdown" and not choices :
6280 raise ValueError ("The property 'choices' is required for kind=Dropdown" )
6381
6482 return values
6583
84+ @field_validator ("parameters" , mode = "before" )
85+ @classmethod
86+ def set_parameters_type (cls , value : Any , info : ValidationInfo ) -> Any :
87+ """Override parameters class if using base AttributeParameters class and should be using a subclass"""
88+ kind = info .data ["kind" ]
89+ expected_parameters_class = get_attribute_parameters_class_for_kind (kind = kind )
90+ if value is None :
91+ return expected_parameters_class ()
92+ if not isinstance (value , expected_parameters_class ) and isinstance (value , AttributeParameters ):
93+ return expected_parameters_class (** value .model_dump ())
94+ return value
95+
6696 def get_class (self ) -> type [BaseAttribute ]:
6797 return ATTRIBUTE_TYPES [self .kind ].get_infrahub_class ()
6898
@@ -106,7 +136,7 @@ def update_from_generic(self, other: AttributeSchema) -> None:
106136
107137 def to_node (self ) -> dict [str , Any ]:
108138 fields_to_exclude = {"id" , "state" , "filters" }
109- fields_to_json = {"computed_attribute" }
139+ fields_to_json = {"computed_attribute" , "parameters" }
110140 data = self .model_dump (exclude = fields_to_exclude | fields_to_json )
111141
112142 for field_name in fields_to_json :
@@ -117,6 +147,15 @@ def to_node(self) -> dict[str, Any]:
117147
118148 return data
119149
150+ def get_regex (self ) -> str | None :
151+ return self .regex
152+
153+ def get_min_length (self ) -> int | None :
154+ return self .min_length
155+
156+ def get_max_length (self ) -> int | None :
157+ return self .max_length
158+
120159 async def get_query_filter (
121160 self ,
122161 name : str ,
@@ -144,3 +183,39 @@ async def get_query_filter(
144183 partial_match = partial_match ,
145184 support_profiles = support_profiles ,
146185 )
186+
187+
188+ class TextAttributeSchema (AttributeSchema ):
189+ parameters : TextAttributeParameters = Field (
190+ default_factory = TextAttributeParameters ,
191+ description = "Extra parameters specific to text attributes" ,
192+ json_schema_extra = {"update" : UpdateSupport .VALIDATE_CONSTRAINT .value },
193+ )
194+
195+ @model_validator (mode = "after" )
196+ def reconcile_parameters (self ) -> Self :
197+ if self .regex != self .parameters .regex :
198+ final_regex = self .parameters .regex or self .regex
199+ if not final_regex : # falsy parameters.regex override falsy regex
200+ final_regex = self .parameters .regex
201+ self .regex = self .parameters .regex = final_regex
202+ if self .min_length != self .parameters .min_length :
203+ final_min_length = self .parameters .min_length or self .min_length
204+ if not final_min_length : # falsy parameters.min_length override falsy min_length
205+ final_min_length = self .parameters .min_length
206+ self .min_length = self .parameters .min_length = final_min_length
207+ if self .max_length != self .parameters .max_length :
208+ final_max_length = self .parameters .max_length or self .max_length
209+ if not final_max_length : # falsy parameters.max_length override falsy max_length
210+ final_max_length = self .parameters .max_length
211+ self .max_length = self .parameters .max_length = final_max_length
212+ return self
213+
214+ def get_regex (self ) -> str | None :
215+ return self .parameters .regex
216+
217+ def get_min_length (self ) -> int | None :
218+ return self .parameters .min_length
219+
220+ def get_max_length (self ) -> int | None :
221+ return self .parameters .max_length
0 commit comments