|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import TYPE_CHECKING, Any |
| 4 | + |
| 5 | +from infrahub.core.constants import PathType |
| 6 | +from infrahub.core.path import DataPath, GroupedDataPaths |
| 7 | +from infrahub.core.schema.attribute_parameters import NumberPoolParameters |
| 8 | +from infrahub.core.validators.enum import ConstraintIdentifier |
| 9 | + |
| 10 | +from ..interface import ConstraintCheckerInterface |
| 11 | +from ..shared import AttributeSchemaValidatorQuery |
| 12 | + |
| 13 | +if TYPE_CHECKING: |
| 14 | + from infrahub.core.branch import Branch |
| 15 | + from infrahub.database import InfrahubDatabase |
| 16 | + |
| 17 | + from ..model import SchemaConstraintValidatorRequest |
| 18 | + |
| 19 | + |
| 20 | +class AttributeNumberPoolUpdateValidatorQuery(AttributeSchemaValidatorQuery): |
| 21 | + name: str = "attribute_constraints_numberpool_validator" |
| 22 | + |
| 23 | + async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002 |
| 24 | + branch_filter, branch_params = self.branch.get_query_filter_path(at=self.at.to_string()) |
| 25 | + self.params.update(branch_params) |
| 26 | + |
| 27 | + if not isinstance(self.attribute_schema.parameters, NumberPoolParameters): |
| 28 | + raise ValueError("attribute parameters are not a NumberPoolParameters") |
| 29 | + |
| 30 | + self.params["attr_name"] = self.attribute_schema.name |
| 31 | + self.params["start_range"] = self.attribute_schema.parameters.start_range |
| 32 | + self.params["end_range"] = self.attribute_schema.parameters.end_range |
| 33 | + |
| 34 | + query = """ |
| 35 | + MATCH (n:%(node_kind)s) |
| 36 | + CALL (n) { |
| 37 | + MATCH path = (root:Root)<-[rr:IS_PART_OF]-(n)-[ra:HAS_ATTRIBUTE]-(:Attribute { name: $attr_name } )-[rv:HAS_VALUE]-(av:AttributeValue) |
| 38 | + WHERE all( |
| 39 | + r in relationships(path) |
| 40 | + WHERE %(branch_filter)s |
| 41 | + ) |
| 42 | + RETURN path as full_path, n as node, rv as value_relationship, av.value as attribute_value |
| 43 | + ORDER BY rv.branch_level DESC, ra.branch_level DESC, rr.branch_level DESC, rv.from DESC, ra.from DESC, rr.from DESC |
| 44 | + LIMIT 1 |
| 45 | + } |
| 46 | + WITH full_path, node, attribute_value, value_relationship |
| 47 | + WHERE all(r in relationships(full_path) WHERE r.status = "active") |
| 48 | + AND ( |
| 49 | + (toInteger($start_range) IS NOT NULL AND attribute_value < toInteger($start_range)) |
| 50 | + OR (toInteger($end_range) IS NOT NULL AND attribute_value > toInteger($end_range)) |
| 51 | + ) |
| 52 | + """ % {"branch_filter": branch_filter, "node_kind": self.node_schema.kind} |
| 53 | + |
| 54 | + self.add_to_query(query) |
| 55 | + self.return_labels = ["node.uuid", "value_relationship", "attribute_value"] |
| 56 | + |
| 57 | + async def get_paths(self) -> GroupedDataPaths: |
| 58 | + grouped_data_paths = GroupedDataPaths() |
| 59 | + for result in self.results: |
| 60 | + grouped_data_paths.add_data_path( |
| 61 | + DataPath( |
| 62 | + branch=str(result.get("value_relationship").get("branch")), |
| 63 | + path_type=PathType.ATTRIBUTE, |
| 64 | + node_id=str(result.get("node.uuid")), |
| 65 | + field_name=self.attribute_schema.name, |
| 66 | + kind=self.node_schema.kind, |
| 67 | + value=result.get("attribute_value"), |
| 68 | + ), |
| 69 | + ) |
| 70 | + |
| 71 | + return grouped_data_paths |
| 72 | + |
| 73 | + |
| 74 | +class AttributeNumberPoolChecker(ConstraintCheckerInterface): |
| 75 | + query_classes = [AttributeNumberPoolUpdateValidatorQuery] |
| 76 | + |
| 77 | + def __init__(self, db: InfrahubDatabase, branch: Branch | None = None): |
| 78 | + self.db = db |
| 79 | + self.branch = branch |
| 80 | + |
| 81 | + @property |
| 82 | + def name(self) -> str: |
| 83 | + return "attribute.number.update" |
| 84 | + |
| 85 | + def supports(self, request: SchemaConstraintValidatorRequest) -> bool: |
| 86 | + return request.constraint_name in ( |
| 87 | + ConstraintIdentifier.ATTRIBUTE_PARAMETERS_START_RANGE_UPDATE.value, |
| 88 | + ConstraintIdentifier.ATTRIBUTE_PARAMETERS_END_RANGE_UPDATE.value, |
| 89 | + ) |
| 90 | + |
| 91 | + async def check(self, request: SchemaConstraintValidatorRequest) -> list[GroupedDataPaths]: |
| 92 | + grouped_data_paths_list: list[GroupedDataPaths] = [] |
| 93 | + if not request.schema_path.field_name: |
| 94 | + raise ValueError("field_name is not defined") |
| 95 | + attribute_schema = request.node_schema.get_attribute(name=request.schema_path.field_name) |
| 96 | + if not isinstance(attribute_schema.parameters, NumberPoolParameters): |
| 97 | + raise ValueError("attribute parameters are not a NumberPoolParameters") |
| 98 | + |
| 99 | + for query_class in self.query_classes: |
| 100 | + # TODO add exception handling |
| 101 | + query = await query_class.init( |
| 102 | + db=self.db, branch=self.branch, node_schema=request.node_schema, schema_path=request.schema_path |
| 103 | + ) |
| 104 | + await query.execute(db=self.db) |
| 105 | + grouped_data_paths_list.append(await query.get_paths()) |
| 106 | + return grouped_data_paths_list |
0 commit comments