diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e2c7f3398e..efe40aeaa5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1020,7 +1020,7 @@ jobs: uses: CodSpeedHQ/action@v3 with: token: ${{ secrets.CODSPEED_TOKEN }} - run: "poetry run pytest -v backend/tests/benchmark/ --codspeed" + run: "poetry run pytest -vvv backend/tests/benchmark/ --benchmark-verbose --codspeed" - name: Run non-intensive benchmarks if: | @@ -1028,7 +1028,7 @@ uses: CodSpeedHQ/action@v3 with: token: ${{ secrets.CODSPEED_TOKEN }} - run: "poetry run pytest -v backend/tests/benchmark/ --codspeed --ignore=backend/tests/benchmark/intensive" + run: "poetry run pytest -vvv backend/tests/benchmark/ --benchmark-verbose --codspeed --ignore=backend/tests/benchmark/intensive" # ------------------------------------------ Coverall Report ------------------------------------------ coverall-report: diff --git a/backend/infrahub/core/constraint/node/runner.py b/backend/infrahub/core/constraint/node/runner.py index 7a1d01da5b..5b562a9acd 100644 --- a/backend/infrahub/core/constraint/node/runner.py +++ b/backend/infrahub/core/constraint/node/runner.py @@ -2,7 +2,7 @@ from infrahub.core.branch import Branch from infrahub.core.node import Node -from infrahub.core.node.constraints.interface import NodeConstraintInterface +from infrahub.core.node.constraints.grouped_uniqueness import NodeGroupedUniquenessConstraint from infrahub.core.relationship.constraints.interface import RelationshipManagerConstraintInterface from infrahub.database import InfrahubDatabase @@ -15,21 +15,18 @@ def __init__( self, db: InfrahubDatabase, branch: Branch, - node_constraints: list[NodeConstraintInterface], + uniqueness_constraint: NodeGroupedUniquenessConstraint, relationship_manager_constraints: list[RelationshipManagerConstraintInterface], ) -> None: self.db = db self.branch = branch - self.node_constraints = node_constraints + self.uniqueness_constraint = uniqueness_constraint self.relationship_manager_constraints = relationship_manager_constraints async def check(self, node: Node, field_filters: list[str] | None = None) -> None: async with self.db.start_session() as db: await node.resolve_relationships(db=db) - for node_constraint in self.node_constraints: - await node_constraint.check(node, filters=field_filters) - for relationship_name in node.get_schema().relationship_names: if field_filters and relationship_name not in field_filters: continue @@ -37,3 +34,7 @@ async def check(self, node: Node, field_filters: list[str] | None = None) -> Non await relationship_manager.fetch_relationship_ids(db=db, force_refresh=True) for relationship_constraint in self.relationship_manager_constraints: await relationship_constraint.check(relm=relationship_manager, node_schema=node.get_schema()) + + # If the HFID constraint is the only constraint violated, all other constraints need to have run before, + # as it means there is an existing node that we might want to update in the case of an upsert + await self.uniqueness_constraint.check(node, filters=field_filters) diff --git a/backend/infrahub/core/migrations/graph/m018_uniqueness_nulls.py b/backend/infrahub/core/migrations/graph/m018_uniqueness_nulls.py index d869df437a..1dc41c6621 100644 --- a/backend/infrahub/core/migrations/graph/m018_uniqueness_nulls.py +++ b/backend/infrahub/core/migrations/graph/m018_uniqueness_nulls.py @@ -55,13 +55,13 @@ async def execute(self, db: InfrahubDatabase) -> MigrationResult: if not isinstance(schema, NodeSchema | GenericSchema): continue - 
schema_constraint_path_groups = schema.get_unique_constraint_schema_attribute_paths( + uniqueness_constraint_paths = schema.get_unique_constraint_schema_attribute_paths( schema_branch=schema_branch ) includes_optional_attr: bool = False - for constraint_group in schema_constraint_path_groups: - for schema_attribute_path in constraint_group: + for uniqueness_constraint_path in uniqueness_constraint_paths: + for schema_attribute_path in uniqueness_constraint_path.attributes_paths: if ( schema_attribute_path.attribute_schema and schema_attribute_path.attribute_schema.optional is True diff --git a/backend/infrahub/core/node/constraints/grouped_uniqueness.py b/backend/infrahub/core/node/constraints/grouped_uniqueness.py index 3fe06350b9..8a49061dd1 100644 --- a/backend/infrahub/core/node/constraints/grouped_uniqueness.py +++ b/backend/infrahub/core/node/constraints/grouped_uniqueness.py @@ -9,6 +9,11 @@ SchemaAttributePath, SchemaAttributePathValue, ) +from infrahub.core.schema.basenode_schema import ( + SchemaUniquenessConstraintPath, + UniquenessConstraintType, + UniquenessConstraintViolation, +) from infrahub.core.validators.uniqueness.index import UniquenessQueryResultsIndex from infrahub.core.validators.uniqueness.model import ( NodeUniquenessQueryRequest, @@ -16,7 +21,7 @@ QueryRelationshipAttributePath, ) from infrahub.core.validators.uniqueness.query import NodeUniqueAttributeConstraintQuery -from infrahub.exceptions import ValidationError +from infrahub.exceptions import HFIDViolatedError, ValidationError from .interface import NodeConstraintInterface @@ -39,15 +44,15 @@ async def _build_query_request( self, updated_node: Node, node_schema: MainSchemaTypes, - path_groups: list[list[SchemaAttributePath]], + uniqueness_constraint_paths: list[SchemaUniquenessConstraintPath], filters: list[str] | None = None, ) -> NodeUniquenessQueryRequest: query_request = NodeUniquenessQueryRequest(kind=node_schema.kind) - for path_group in path_groups: + for uniqueness_constraint_path in uniqueness_constraint_paths: include_in_query = not filters query_relationship_paths: set[QueryRelationshipAttributePath] = set() query_attribute_paths: set[QueryAttributePath] = set() - for attribute_path in path_group: + for attribute_path in uniqueness_constraint_path.attributes_paths: if attribute_path.related_schema and attribute_path.relationship_schema: if filters and attribute_path.relationship_schema.name in filters: include_in_query = True @@ -118,63 +123,77 @@ async def _get_node_attribute_path_values( ) return node_value_combination - def _check_one_constraint_group( - self, schema_attribute_path_values: list[SchemaAttributePathValue], results_index: UniquenessQueryResultsIndex - ) -> None: - # constraint cannot be violated if this node is missing any values - if any(sapv.value is None for sapv in schema_attribute_path_values): - return - - matching_node_ids = results_index.get_node_ids_for_value_group(schema_attribute_path_values) - if not matching_node_ids: - return - uniqueness_constraint_fields = [] - for sapv in schema_attribute_path_values: - if sapv.relationship_schema: - uniqueness_constraint_fields.append(sapv.relationship_schema.name) - elif sapv.attribute_schema: - uniqueness_constraint_fields.append(sapv.attribute_schema.name) - uniqueness_constraint_string = "-".join(uniqueness_constraint_fields) - error_msg = f"Violates uniqueness constraint '{uniqueness_constraint_string}'" - errors = [ValidationError({field_name: error_msg}) for field_name in uniqueness_constraint_fields] - raise 
ValidationError(errors) - - async def _check_results( + async def _get_violations( self, updated_node: Node, - path_groups: list[list[SchemaAttributePath]], + uniqueness_constraint_paths: list[SchemaUniquenessConstraintPath], query_results: Iterable[QueryResult], - ) -> None: + ) -> list[UniquenessConstraintViolation]: results_index = UniquenessQueryResultsIndex( query_results=query_results, exclude_node_ids={updated_node.get_id()} ) - for path_group in path_groups: + violations = [] + for uniqueness_constraint_path in uniqueness_constraint_paths: + # path_group = one constraint (that can contain multiple items) schema_attribute_path_values = await self._get_node_attribute_path_values( - updated_node=updated_node, path_group=path_group + updated_node=updated_node, path_group=uniqueness_constraint_path.attributes_paths ) - self._check_one_constraint_group( - schema_attribute_path_values=schema_attribute_path_values, results_index=results_index + + # constraint cannot be violated if this node is missing any values + if any(sapv.value is None for sapv in schema_attribute_path_values): + continue + + matching_node_ids = results_index.get_node_ids_for_value_group(schema_attribute_path_values) + if not matching_node_ids: + continue + + uniqueness_constraint_fields = [] + for sapv in schema_attribute_path_values: + if sapv.relationship_schema: + uniqueness_constraint_fields.append(sapv.relationship_schema.name) + elif sapv.attribute_schema: + uniqueness_constraint_fields.append(sapv.attribute_schema.name) + + violations.append( + UniquenessConstraintViolation( + nodes_ids=matching_node_ids, + fields=uniqueness_constraint_fields, + typ=uniqueness_constraint_path.typ, + ) ) - async def _check_one_schema( + return violations + + async def _get_single_schema_violations( self, node: Node, node_schema: MainSchemaTypes, at: Timestamp | None = None, filters: list[str] | None = None, - ) -> None: + ) -> list[UniquenessConstraintViolation]: schema_branch = self.db.schema.get_schema_branch(name=self.branch.name) - path_groups = node_schema.get_unique_constraint_schema_attribute_paths(schema_branch=schema_branch) + + uniqueness_constraint_paths = node_schema.get_unique_constraint_schema_attribute_paths( + schema_branch=schema_branch + ) query_request = await self._build_query_request( - updated_node=node, node_schema=node_schema, path_groups=path_groups, filters=filters + updated_node=node, + node_schema=node_schema, + uniqueness_constraint_paths=uniqueness_constraint_paths, + filters=filters, ) if not query_request: - return + return [] + query = await NodeUniqueAttributeConstraintQuery.init( db=self.db, branch=self.branch, at=at, query_request=query_request, min_count_required=0 ) await query.execute(db=self.db) - await self._check_results(updated_node=node, path_groups=path_groups, query_results=query.get_results()) + return await self._get_violations( + updated_node=node, + uniqueness_constraint_paths=uniqueness_constraint_paths, + query_results=query.get_results(), + ) async def check(self, node: Node, at: Timestamp | None = None, filters: list[str] | None = None) -> None: def _frozen_constraints(schema: MainSchemaTypes) -> frozenset[frozenset[str]]: @@ -195,7 +214,27 @@ def _frozen_constraints(schema: MainSchemaTypes) -> frozenset[frozenset[str]]: frozen_parent_constraints = _frozen_constraints(parent_schema) if frozen_node_constraints <= frozen_parent_constraints: include_node_schema = False + if include_node_schema: schemas_to_check.append(node_schema) + + violations = [] for schema in 
schemas_to_check: - await self._check_one_schema(node=node, node_schema=schema, at=at, filters=filters) + schema_violations = await self._get_single_schema_violations( + node=node, node_schema=schema, at=at, filters=filters + ) + violations.extend(schema_violations) + + is_hfid_violated = any(violation.typ == UniquenessConstraintType.HFID for violation in violations) + + for violation in violations: + if violation.typ == UniquenessConstraintType.STANDARD or ( + violation.typ == UniquenessConstraintType.SUBSET_OF_HFID and not is_hfid_violated + ): + error_msg = f"Violates uniqueness constraint '{'-'.join(violation.fields)}'" + raise ValidationError(error_msg) + + for violation in violations: + if violation.typ == UniquenessConstraintType.HFID: + error_msg = f"Violates uniqueness constraint '{'-'.join(violation.fields)}'" + raise HFIDViolatedError(error_msg, matching_nodes_ids=violation.nodes_ids) diff --git a/backend/infrahub/core/schema/basenode_schema.py b/backend/infrahub/core/schema/basenode_schema.py index 47ccbfd206..173cc2b998 100644 --- a/backend/infrahub/core/schema/basenode_schema.py +++ b/backend/infrahub/core/schema/basenode_schema.py @@ -432,25 +432,71 @@ def parse_schema_path(self, path: str, schema: SchemaBranch | None = None) -> Sc def get_unique_constraint_schema_attribute_paths( self, schema_branch: SchemaBranch, - include_unique_attributes: bool = False, - ) -> list[list[SchemaAttributePath]]: - constraint_paths_groups = [] - if include_unique_attributes: - for attribute_schema in self.unique_attributes: - constraint_paths_groups.append( - [SchemaAttributePath(attribute_schema=attribute_schema, attribute_property_name="value")] - ) - - if not self.uniqueness_constraints: - return constraint_paths_groups + ) -> list[SchemaUniquenessConstraintPath]: + if self.uniqueness_constraints is None: + return [] + + uniqueness_constraint_paths = [] for uniqueness_path_group in self.uniqueness_constraints: - constraint_paths_group = [] - for uniqueness_path_part in uniqueness_path_group: - constraint_paths_group.append(self.parse_schema_path(path=uniqueness_path_part, schema=schema_branch)) - if constraint_paths_group not in constraint_paths_groups: - constraint_paths_groups.append(constraint_paths_group) - return constraint_paths_groups + attributes_paths = [ + self.parse_schema_path(path=uniqueness_path_part, schema=schema_branch) + for uniqueness_path_part in uniqueness_path_group + ] + uniqueness_constraint_type = self.get_uniqueness_constraint_type( + uniqueness_constraint=set(uniqueness_path_group), schema_branch=schema_branch + ) + uniqueness_constraint_path = SchemaUniquenessConstraintPath( + attributes_paths=attributes_paths, typ=uniqueness_constraint_type + ) + uniqueness_constraint_paths.append(uniqueness_constraint_path) + + return uniqueness_constraint_paths + + def convert_hfid_to_uniqueness_constraint(self, schema_branch: SchemaBranch) -> list[str] | None: + if self.human_friendly_id is None: + return None + + uniqueness_constraint = [] + for item in self.human_friendly_id: + schema_attribute_path = self.parse_schema_path(path=item, schema=schema_branch) + if schema_attribute_path.is_type_attribute: + uniqueness_constraint.append(item) + elif schema_attribute_path.is_type_relationship: + uniqueness_constraint.append(schema_attribute_path.relationship_schema.name) + return uniqueness_constraint + + def get_uniqueness_constraint_type( + self, uniqueness_constraint: set[str], schema_branch: SchemaBranch + ) -> UniquenessConstraintType: + hfid = 
self.convert_hfid_to_uniqueness_constraint(schema_branch=schema_branch) + if hfid is None: + return UniquenessConstraintType.STANDARD + hfid_set = set(hfid) + if uniqueness_constraint == hfid_set: + return UniquenessConstraintType.HFID + if uniqueness_constraint <= hfid_set: + return UniquenessConstraintType.SUBSET_OF_HFID + return UniquenessConstraintType.STANDARD + + +@dataclass +class SchemaUniquenessConstraintPath: + attributes_paths: list[SchemaAttributePath] + typ: UniquenessConstraintType + + +class UniquenessConstraintType(Enum): + HFID = "HFID" + SUBSET_OF_HFID = "SUBSET_OF_HFID" + STANDARD = "STANDARD" + + +@dataclass +class UniquenessConstraintViolation: + nodes_ids: set[str] + fields: list[str] + typ: UniquenessConstraintType @dataclass diff --git a/backend/infrahub/core/schema/schema_branch.py b/backend/infrahub/core/schema/schema_branch.py index 046b159633..f5e8778753 100644 --- a/backend/infrahub/core/schema/schema_branch.py +++ b/backend/infrahub/core/schema/schema_branch.py @@ -1174,25 +1174,25 @@ def process_relationships(self) -> None: self.set(name=schema_to_update.kind, schema=schema_to_update) def process_human_friendly_id(self) -> None: + """ + For each schema node, if there is no HFID defined, set it with: + - The first unique attribute, if one exists + - Otherwise the first uniqueness constraint with a single attribute + + Also, the HFID is added to the uniqueness constraints. + """ for name in self.generic_names_without_templates + self.node_names: node = self.get(name=name, duplicate=False) - # If human_friendly_id IS NOT defined - # but some the model has some unique attribute, we generate a human_friendly_id - # If human_friendly_id IS defined - # but no unique attributes and no uniquess constraints, we add a uniqueness_constraint if not node.human_friendly_id: if node.unique_attributes: - for attr in node.unique_attributes: - node = self.get(name=name, duplicate=True) - node.human_friendly_id = [f"{attr.name}__value"] - self.set(name=node.kind, schema=node) - break - continue + node = self.get(name=name, duplicate=True) + node.human_friendly_id = [f"{node.unique_attributes[0].name}__value"] + self.set(name=node.kind, schema=node) # if no human_friendly_id and a uniqueness_constraint with a single attribute exists # then use that attribute as the human_friendly_id - if node.uniqueness_constraints: + elif node.uniqueness_constraints: for constraint_paths in node.uniqueness_constraints: if len(constraint_paths) > 1: continue @@ -1209,22 +1209,15 @@ def process_human_friendly_id(self) -> None: break # Add hfid to uniqueness constraint - if node.human_friendly_id: - uniqueness_constraint: list[str] = [] - for item in node.human_friendly_id: - schema_attribute_path = node.parse_schema_path(path=item, schema=self) - if schema_attribute_path.is_type_attribute: - uniqueness_constraint.append(item) - elif schema_attribute_path.is_type_relationship: - uniqueness_constraint.append(schema_attribute_path.relationship_schema.name) - + hfid_uniqueness_constraint = node.convert_hfid_to_uniqueness_constraint(schema_branch=self) + if hfid_uniqueness_constraint: node = self.get(name=name, duplicate=True) # Make sure there is no duplicate regarding generics values. 
if node.uniqueness_constraints: - if uniqueness_constraint not in node.uniqueness_constraints: - node.uniqueness_constraints.append(uniqueness_constraint) + if hfid_uniqueness_constraint not in node.uniqueness_constraints: + node.uniqueness_constraints.append(hfid_uniqueness_constraint) else: - node.uniqueness_constraints = [uniqueness_constraint] + node.uniqueness_constraints = [hfid_uniqueness_constraint] self.set(name=node.kind, schema=node) def process_hierarchy(self) -> None: diff --git a/backend/infrahub/core/validators/uniqueness/checker.py b/backend/infrahub/core/validators/uniqueness/checker.py index 82e6529ef1..6d2481e526 100644 --- a/backend/infrahub/core/validators/uniqueness/checker.py +++ b/backend/infrahub/core/validators/uniqueness/checker.py @@ -129,22 +129,22 @@ async def _parse_results(self, schema: MainSchemaTypes, query_results: list[Quer branch = await self.get_branch() schema_branch = self.db.schema.get_schema_branch(name=branch.name) - path_groups = schema.get_unique_constraint_schema_attribute_paths( - include_unique_attributes=True, schema_branch=schema_branch - ) - for constraint_group in path_groups: + uniqueness_constraint_paths = schema.get_unique_constraint_schema_attribute_paths(schema_branch=schema_branch) + for uniqueness_constraint_path in uniqueness_constraint_paths: non_unique_nodes_by_id: dict[str, NonUniqueNode] = {} constraint_group_relationship_identifiers = [ schema_attribute_path.relationship_schema.get_identifier() - for schema_attribute_path in constraint_group + for schema_attribute_path in uniqueness_constraint_path.attributes_paths if schema_attribute_path.relationship_schema ] constraint_group_attribute_names = [ schema_attribute_path.attribute_schema.name - for schema_attribute_path in constraint_group + for schema_attribute_path in uniqueness_constraint_path.attributes_paths if schema_attribute_path.attribute_schema ] - node_ids_in_violation = results_index.get_node_ids_for_path_group(path_group=constraint_group) + node_ids_in_violation = results_index.get_node_ids_for_path_group( + path_group=uniqueness_constraint_path.attributes_paths + ) for result in query_results: node_id = str(result.get("node_id")) if node_id not in node_ids_in_violation: diff --git a/backend/infrahub/dependencies/builder/constraint/grouped/node_runner.py b/backend/infrahub/dependencies/builder/constraint/grouped/node_runner.py index 2a495e079b..0f8ce8ce24 100644 --- a/backend/infrahub/dependencies/builder/constraint/grouped/node_runner.py +++ b/backend/infrahub/dependencies/builder/constraint/grouped/node_runner.py @@ -13,9 +13,7 @@ def build(cls, context: DependencyBuilderContext) -> NodeConstraintRunner: return NodeConstraintRunner( db=context.db, branch=context.branch, - node_constraints=[ - NodeGroupedUniquenessConstraintDependency.build(context=context), - ], + uniqueness_constraint=NodeGroupedUniquenessConstraintDependency.build(context=context), relationship_manager_constraints=[ RelationshipPeerKindConstraintDependency.build(context=context), RelationshipCountConstraintDependency.build(context=context), diff --git a/backend/infrahub/exceptions.py b/backend/infrahub/exceptions.py index 5fc3657ce2..312800a1c3 100644 --- a/backend/infrahub/exceptions.py +++ b/backend/infrahub/exceptions.py @@ -293,33 +293,24 @@ class ValidationError(Error): def __init__(self, input_value: str | dict | list) -> None: self.message = "" - self.location = None - self.messages = {} if isinstance(input_value, str): self.message = input_value - elif isinstance(input_value, dict) and 
len(input_value) == 1: - self.message = list(input_value.values())[0] - self.location = list(input_value.keys())[0] - elif isinstance(input_value, dict) and len(input_value) > 1: - for key, value in input_value.items(): - self.messages[key] = value - + elif isinstance(input_value, dict): + self.message = ", ".join([f"{message} at {location}" for location, message in input_value.items()]) elif isinstance(input_value, list): - for item in input_value: - if isinstance(item, self.__class__): - self.messages[item.location] = item.message - elif isinstance(item, dict): - for key, value in item.items(): - self.messages[key] = value - - super().__init__(self.message) + if all(isinstance(item, ValidationError) for item in input_value): + self.message = ", ".join([validation_error.message for validation_error in input_value]) + if all(isinstance(item, dict) for item in input_value): + messages = [] + for item in input_value: + messages.append(", ".join([f"{message} at {location}" for location, message in item.items()])) + self.message = ", ".join(messages) - def __str__(self) -> str: - if self.messages: - return ", ".join([f"{message} at {location}" for location, message in self.messages.items()]) + if not self.message: + raise ValueError("Could not build validation error message") - return f"{self.message} at {self.location or ''}" + super().__init__(self.message) class DiffError(Error): @@ -329,6 +320,14 @@ def __init__(self, message: str) -> None: self.message = message +class HFIDViolatedError(ValidationError): + matching_nodes_ids: set[str] + + def __init__(self, input_value: str | dict | list, matching_nodes_ids: set[str]) -> None: + self.matching_nodes_ids = matching_nodes_ids + super().__init__(input_value) + + class DiffRangeValidationError(DiffError): ... 
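Note: the reworked ValidationError above now flattens every accepted input shape (string, dict, list of errors or list of dicts) into a single message at construction time, instead of carrying separate `location`/`messages` state. A minimal standalone sketch of that flattening, using plain `Exception` stand-ins rather than the real `infrahub.exceptions` classes:

# Standalone sketch mirroring the message-building logic from the diff above.
class ValidationError(Exception):
    def __init__(self, input_value: str | dict | list) -> None:
        self.message = ""
        if isinstance(input_value, str):
            self.message = input_value
        elif isinstance(input_value, dict):
            # {"name": "Violates ..."} -> "Violates ... at name"
            self.message = ", ".join(f"{msg} at {loc}" for loc, msg in input_value.items())
        elif isinstance(input_value, list):
            if all(isinstance(item, ValidationError) for item in input_value):
                self.message = ", ".join(err.message for err in input_value)
            if all(isinstance(item, dict) for item in input_value):
                self.message = ", ".join(
                    ", ".join(f"{msg} at {loc}" for loc, msg in item.items()) for item in input_value
                )
        if not self.message:
            raise ValueError("Could not build validation error message")
        super().__init__(self.message)


class HFIDViolatedError(ValidationError):
    def __init__(self, input_value: str | dict | list, matching_nodes_ids: set[str]) -> None:
        self.matching_nodes_ids = matching_nodes_ids
        super().__init__(input_value)


print(ValidationError({"name": "Violates uniqueness constraint 'name'"}))
# -> Violates uniqueness constraint 'name' at name
err = HFIDViolatedError("Violates uniqueness constraint 'name-owner'", matching_nodes_ids={"node-1"})
print(err.matching_nodes_ids)  # -> {'node-1'}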
diff --git a/backend/infrahub/graphql/mutations/ipam.py b/backend/infrahub/graphql/mutations/ipam.py index 95194a7ef8..9e8af56ba3 100644 --- a/backend/infrahub/graphql/mutations/ipam.py +++ b/backend/infrahub/graphql/mutations/ipam.py @@ -16,11 +16,11 @@ from infrahub.core.schema import NodeSchema from infrahub.database import InfrahubDatabase, retry_db_transaction from infrahub.exceptions import NodeNotFoundError, ValidationError -from infrahub.graphql.mutations.node_getter.interface import MutationNodeGetterInterface from infrahub.lock import InfrahubMultiLock, build_object_lock_name from infrahub.log import get_logger from .main import DeleteResult, InfrahubMutationMixin, InfrahubMutationOptions +from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter if TYPE_CHECKING: from infrahub.graphql.initialization import GraphqlContext @@ -192,21 +192,19 @@ async def mutate_update( ) namespace = await address.ip_namespace.get_peer(db) namespace_id = await validate_namespace(db=db, branch=branch, data=data, existing_namespace_id=namespace.id) - try: - async with db.start_transaction() as dbt: - if lock_name := cls._get_lock_name(namespace_id, branch): - async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]): - reconciled_address = await cls._mutate_update_object_and_reconcile( - info=info, data=data, branch=branch, address=address, namespace_id=namespace_id, db=dbt - ) - else: + + async with db.start_transaction() as dbt: + if lock_name := cls._get_lock_name(namespace_id, branch): + async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]): reconciled_address = await cls._mutate_update_object_and_reconcile( info=info, data=data, branch=branch, address=address, namespace_id=namespace_id, db=dbt ) + else: + reconciled_address = await cls._mutate_update_object_and_reconcile( + info=info, data=data, branch=branch, address=address, namespace_id=namespace_id, db=dbt + ) - result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_address) - except ValidationError as exc: - raise ValueError(str(exc)) from exc + result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_address) return address, result @@ -216,7 +214,7 @@ async def mutate_upsert( info: GraphQLResolveInfo, data: InputObjectType, branch: Branch, - node_getters: list[MutationNodeGetterInterface], + node_getter_default_filter: MutationNodeGetterByDefaultFilter, database: InfrahubDatabase | None = None, ) -> tuple[Node, Self, bool]: graphql_context: GraphqlContext = info.context @@ -224,7 +222,7 @@ async def mutate_upsert( await validate_namespace(db=db, branch=branch, data=data) prefix, result, created = await super().mutate_upsert( - info=info, data=data, branch=branch, node_getters=node_getters, database=db + info=info, data=data, branch=branch, node_getter_default_filter=node_getter_default_filter, database=db ) return prefix, result, created @@ -343,20 +341,18 @@ async def mutate_update( ) namespace = await prefix.ip_namespace.get_peer(db) namespace_id = await validate_namespace(db=db, branch=branch, data=data, existing_namespace_id=namespace.id) - try: - async with db.start_transaction() as dbt: - if lock_name := cls._get_lock_name(namespace_id, branch): - async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]): - reconciled_prefix = await cls._mutate_update_object_and_reconcile( - info=info, data=data, prefix=prefix, db=dbt, namespace_id=namespace_id, branch=branch - ) - else: + + async with db.start_transaction() as 
dbt: + if lock_name := cls._get_lock_name(namespace_id, branch): + async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]): reconciled_prefix = await cls._mutate_update_object_and_reconcile( info=info, data=data, prefix=prefix, db=dbt, namespace_id=namespace_id, branch=branch ) - result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_prefix) - except ValidationError as exc: - raise ValueError(str(exc)) from exc + else: + reconciled_prefix = await cls._mutate_update_object_and_reconcile( + info=info, data=data, prefix=prefix, db=dbt, namespace_id=namespace_id, branch=branch + ) + result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=reconciled_prefix) return prefix, result @@ -366,7 +362,7 @@ async def mutate_upsert( info: GraphQLResolveInfo, data: InputObjectType, branch: Branch, - node_getters: list[MutationNodeGetterInterface], + node_getter_default_filter: MutationNodeGetterByDefaultFilter, database: InfrahubDatabase | None = None, ) -> tuple[Node, Self, bool]: graphql_context: GraphqlContext = info.context @@ -374,7 +370,7 @@ async def mutate_upsert( await validate_namespace(db=db, branch=branch, data=data) prefix, result, created = await super().mutate_upsert( - info=info, data=data, branch=branch, node_getters=node_getters, database=db + info=info, data=data, branch=branch, node_getter_default_filter=node_getter_default_filter, database=db ) return prefix, result, created @@ -414,21 +410,17 @@ async def mutate_delete( namespace_rels = await prefix.ip_namespace.get_relationships(db=db) namespace_id = namespace_rels[0].peer_id - try: - async with graphql_context.db.start_transaction() as dbt: - if lock_name := cls._get_lock_name(namespace_id, branch): - async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]): - reconciled_prefix = await cls._reconcile_prefix( - branch=branch, db=dbt, prefix=prefix, namespace_id=namespace_id, is_delete=True - ) - else: + + async with graphql_context.db.start_transaction() as dbt: + if lock_name := cls._get_lock_name(namespace_id, branch): + async with InfrahubMultiLock(lock_registry=lock.registry, locks=[lock_name]): reconciled_prefix = await cls._reconcile_prefix( branch=branch, db=dbt, prefix=prefix, namespace_id=namespace_id, is_delete=True ) - - except ValidationError as exc: - raise ValueError(str(exc)) from exc - + else: + reconciled_prefix = await cls._reconcile_prefix( + branch=branch, db=dbt, prefix=prefix, namespace_id=namespace_id, is_delete=True + ) ok = True return DeleteResult(node=reconciled_prefix, mutation=cls(ok=ok)) diff --git a/backend/infrahub/graphql/mutations/main.py b/backend/infrahub/graphql/mutations/main.py index 2e4e0e924d..29f8dff9d8 100644 --- a/backend/infrahub/graphql/mutations/main.py +++ b/backend/infrahub/graphql/mutations/main.py @@ -22,14 +22,12 @@ from infrahub.database import retry_db_transaction from infrahub.dependencies.registry import get_component_registry from infrahub.events.generator import generate_node_mutation_events -from infrahub.exceptions import InitializationError, ValidationError +from infrahub.exceptions import HFIDViolatedError, InitializationError from infrahub.graphql.context import apply_external_context from infrahub.lock import InfrahubMultiLock, build_object_lock_name from infrahub.log import get_log_data, get_logger from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter -from .node_getter.by_hfid import MutationNodeGetterByHfid -from .node_getter.by_id import MutationNodeGetterById if 
TYPE_CHECKING: from graphql import GraphQLResolveInfo @@ -42,7 +40,6 @@ from infrahub.graphql.types.context import ContextInput from ..initialization import GraphqlContext - from .node_getter.interface import MutationNodeGetterInterface log = get_logger() @@ -98,13 +95,15 @@ async def mutate( action = MutationAction.UPDATED elif "Upsert" in cls.__name__: node_manager = NodeManager() - node_getters = [ - MutationNodeGetterById(db=graphql_context.db, node_manager=node_manager), - MutationNodeGetterByHfid(db=graphql_context.db, node_manager=node_manager), - MutationNodeGetterByDefaultFilter(db=graphql_context.db, node_manager=node_manager), - ] + node_getter_default_filter = MutationNodeGetterByDefaultFilter( + db=graphql_context.db, node_manager=node_manager + ) obj, mutation, created = await cls.mutate_upsert( - info=info, branch=graphql_context.branch, data=data, node_getters=node_getters, **kwargs + info=info, + branch=graphql_context.branch, + data=data, + node_getter_default_filter=node_getter_default_filter, + **kwargs, ) if created: action = MutationAction.CREATED @@ -311,40 +310,37 @@ async def mutate_create_object( node_class = registry.node[cls._meta.active_schema.kind] fields_to_validate = list(data) - try: - if db.is_transaction: - obj = await node_class.init(db=db, schema=cls._meta.schema, branch=branch) - await obj.new(db=db, **data) + if db.is_transaction: + obj = await node_class.init(db=db, schema=cls._meta.schema, branch=branch) + await obj.new(db=db, **data) + await node_constraint_runner.check(node=obj, field_filters=fields_to_validate) + await obj.save(db=db) + + object_template = await obj.get_object_template(db=db) + if object_template: + await cls._handle_template_relationships( + db=db, + branch=branch, + template=object_template, + obj=obj, + data=data, + ) + else: + async with db.start_transaction() as dbt: + obj = await node_class.init(db=dbt, schema=cls._meta.schema, branch=branch) + await obj.new(db=dbt, **data) await node_constraint_runner.check(node=obj, field_filters=fields_to_validate) - await obj.save(db=db) + await obj.save(db=dbt) - object_template = await obj.get_object_template(db=db) + object_template = await obj.get_object_template(db=dbt) if object_template: await cls._handle_template_relationships( - db=db, + db=dbt, branch=branch, template=object_template, obj=obj, data=data, ) - else: - async with db.start_transaction() as dbt: - obj = await node_class.init(db=dbt, schema=cls._meta.schema, branch=branch) - await obj.new(db=dbt, **data) - await node_constraint_runner.check(node=obj, field_filters=fields_to_validate) - await obj.save(db=dbt) - - object_template = await obj.get_object_template(db=dbt) - if object_template: - await cls._handle_template_relationships( - db=dbt, - branch=branch, - template=object_template, - obj=obj, - data=data, - ) - except ValidationError as exc: - raise ValidationError(input_value=str(exc)) from exc if await cls._get_profile_ids(db=db, obj=obj): obj = await cls._refresh_for_profile_update(db=db, branch=branch, obj=obj) @@ -367,6 +363,7 @@ async def _call_mutate_update( branch: Branch, db: InfrahubDatabase, obj: Node, + run_constraint_checks: bool = True, ) -> tuple[Node, Self]: """ Wrapper around mutate_update to potentially activate locking and call it within a database transaction. 
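Note: the `run_constraint_checks` flag threaded through `_call_mutate_update` and `mutate_update_object` below exists for one caller, the upsert fallback. When `mutate_create` fails with only an HFID violation, the target node is known to exist and its values were just validated, so the follow-up update can safely skip the constraint runner. A hedged, self-contained sketch of that control flow, with stub `create`/`update` coroutines standing in for the real mutation classes:

import asyncio

class HFIDViolatedError(Exception):
    # Stand-in for infrahub.exceptions.HFIDViolatedError.
    def __init__(self, message: str, matching_nodes_ids: set[str]) -> None:
        super().__init__(message)
        self.matching_nodes_ids = matching_nodes_ids

async def create(data: dict, run_constraint_checks: bool) -> dict:
    # Pretend the HFID uniqueness constraint is the only one violated.
    raise HFIDViolatedError("Violates uniqueness constraint 'name-owner'", {"node-1"})

async def update(node_id: str, data: dict, run_constraint_checks: bool) -> dict:
    return {"id": node_id, **data, "constraints_rechecked": run_constraint_checks}

async def upsert(data: dict) -> tuple[dict, bool]:
    # Mirrors the new mutate_upsert fallback: try to create; if only the HFID
    # constraint failed, the node exists, so update it without re-running checks.
    try:
        return await create(data, run_constraint_checks=True), True
    except HFIDViolatedError as exc:
        if len(exc.matching_nodes_ids) > 1:
            raise RuntimeError("multiple nodes share the same hfid") from exc
        node_id = next(iter(exc.matching_nodes_ids))
        return await update(node_id, data, run_constraint_checks=False), False

print(asyncio.run(upsert({"name": "mercedes"})))
# -> ({'id': 'node-1', 'name': 'mercedes', 'constraints_rechecked': False}, False)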
@@ -380,18 +377,31 @@ async def _call_mutate_update( if db.is_transaction: if lock_names: async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names): - obj = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj) + obj = await cls.mutate_update_object( + db=db, info=info, data=data, branch=branch, obj=obj, run_constraint_checks=run_constraint_checks + ) else: - obj = await cls.mutate_update_object(db=db, info=info, data=data, branch=branch, obj=obj) + obj = await cls.mutate_update_object( + db=db, info=info, data=data, branch=branch, obj=obj, run_constraint_checks=run_constraint_checks + ) result = await cls.mutate_update_to_graphql(db=db, info=info, obj=obj) return obj, result async with db.start_transaction() as dbt: if lock_names: async with InfrahubMultiLock(lock_registry=lock.registry, locks=lock_names): - obj = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=obj) + obj = await cls.mutate_update_object( + db=dbt, + info=info, + data=data, + branch=branch, + obj=obj, + run_constraint_checks=run_constraint_checks, + ) else: - obj = await cls.mutate_update_object(db=dbt, info=info, data=data, branch=branch, obj=obj) + obj = await cls.mutate_update_object( + db=dbt, info=info, data=data, branch=branch, obj=obj, run_constraint_checks=run_constraint_checks + ) result = await cls.mutate_update_to_graphql(db=dbt, info=info, obj=obj) return obj, result @@ -412,10 +422,7 @@ async def mutate_update( db=db, kind=cls._meta.active_schema.kind, id=data.get("id"), hfid=data.get("hfid"), branch=branch ) - try: - obj, result = await cls._call_mutate_update(info=info, data=data, db=db, branch=branch, obj=obj) - except ValidationError as exc: - raise ValueError(str(exc)) from exc + obj, result = await cls._call_mutate_update(info=info, data=data, db=db, branch=branch, obj=obj) return obj, result @@ -427,6 +434,7 @@ async def mutate_update_object( data: InputObjectType, branch: Branch, obj: Node, + run_constraint_checks: bool = True, ) -> Node: component_registry = get_component_registry() node_constraint_runner = await component_registry.get_component(NodeConstraintRunner, db=db, branch=branch) @@ -434,7 +442,8 @@ async def mutate_update_object( before_mutate_profile_ids = await cls._get_profile_ids(db=db, obj=obj) await obj.from_graphql(db=db, data=data) fields_to_validate = list(data) - await node_constraint_runner.check(node=obj, field_filters=fields_to_validate) + if run_constraint_checks: + await node_constraint_runner.check(node=obj, field_filters=fields_to_validate) fields = list(data.keys()) for field_to_remove in ("id", "hfid"): @@ -469,31 +478,76 @@ async def mutate_upsert( info: GraphQLResolveInfo, data: InputObjectType, branch: Branch, - node_getters: list[MutationNodeGetterInterface], + node_getter_default_filter: MutationNodeGetterByDefaultFilter, database: InfrahubDatabase | None = None, ) -> tuple[Node, Self, bool]: + """ + First, check whether the payload contains data identifying the node, such as id, hfid, or relevant fields for + the default_filter. If not, we will try to create the node, but this creation might fail if the payload contains + hfid fields (not the `hfid` field itself) that would match an existing node in the database. In that case, + we would update the node without rerunning the uniqueness constraints. 
+ """ + schema_name = cls._meta.active_schema.kind graphql_context: GraphqlContext = info.context db = database or graphql_context.db - - node_schema = db.schema.get(name=schema_name, branch=branch) - + dict_data = dict(data) node = None - for getter in node_getters: - node = await getter.get_node(node_schema=node_schema, data=data, branch=branch) - if node: - break + run_constraint_checks = True - if node: - updated_obj, mutation = await cls.mutate_update(info=info, data=data, branch=branch, database=db, node=node) + if "id" in dict_data: + node = await NodeManager.get_one( + db=db, id=dict_data["id"], kind=schema_name, branch=branch, raise_on_error=True + ) + updated_obj, mutation = await cls._call_mutate_update( + info=info, + data=data, + db=db, + branch=branch, + obj=node, + run_constraint_checks=run_constraint_checks, + ) return updated_obj, mutation, False - # We need to convert the InputObjectType into a dict in order to remove hfid that isn't a valid input when creating the object - data_dict = dict(data) + + if cls._meta.active_schema.default_filter is not None: + node = await node_getter_default_filter.get_node( + node_schema=cls._meta.active_schema, data=data, branch=branch + ) + if "hfid" in data: - del data_dict["hfid"] - created_obj, mutation = await cls.mutate_create(info=info, data=data_dict, branch=branch) - return created_obj, mutation, True + node = await NodeManager.get_one_by_hfid(db=db, hfid=dict_data["hfid"], kind=schema_name, branch=branch) + + if node is not None: + updated_obj, mutation = await cls._call_mutate_update( + info=info, + data=data, + db=db, + branch=branch, + obj=node, + run_constraint_checks=run_constraint_checks, + ) + return updated_obj, mutation, False + + try: + dict_data.pop("hfid", "unused") # `hfid` is invalid for creation. 
+ created_obj, mutation = await cls.mutate_create(info=info, data=dict_data, branch=branch) + return created_obj, mutation, True + except HFIDViolatedError as exc: + # Only the HFID constraint was violated, which means the node exists and we can update it without rerunning constraints + if len(exc.matching_nodes_ids) > 1: + raise RuntimeError(f"Multiple {schema_name} nodes have the same hfid (database corrupted)") from exc + node_id = list(exc.matching_nodes_ids)[0] + node = await NodeManager.get_one(db=db, id=node_id, kind=schema_name, branch=branch, raise_on_error=True) + updated_obj, mutation = await cls._call_mutate_update( + info=info, + data=data, + db=db, + branch=branch, + obj=node, + run_constraint_checks=run_constraint_checks, + ) + return updated_obj, mutation, False @classmethod @retry_db_transaction(name="object_delete") @@ -513,11 +567,8 @@ async def mutate_delete( branch=branch, ) - try: - async with graphql_context.db.start_transaction() as db: - deleted = await NodeManager.delete(db=db, branch=branch, nodes=[obj]) - except ValidationError as exc: - raise ValueError(str(exc)) from exc + async with graphql_context.db.start_transaction() as db: + deleted = await NodeManager.delete(db=db, branch=branch, nodes=[obj]) deleted_str = ", ".join([f"{d.get_kind()}({d.get_id()})" for d in deleted]) log.info(f"nodes deleted: {deleted_str}") diff --git a/backend/infrahub/graphql/mutations/node_getter/by_default_filter.py b/backend/infrahub/graphql/mutations/node_getter/by_default_filter.py index ef8aebd1bd..4d37144d1a 100644 --- a/backend/infrahub/graphql/mutations/node_getter/by_default_filter.py +++ b/backend/infrahub/graphql/mutations/node_getter/by_default_filter.py @@ -1,3 +1,5 @@ +from copy import copy + from graphene import InputObjectType from infrahub.core.branch import Branch @@ -20,20 +22,20 @@ async def get_node( data: InputObjectType, branch: Branch, ) -> Node | None: - node = None - default_filter_value = None if not node_schema.default_filter: - return node - this_datum = data + return None + + data = copy(data) for filter_key in node_schema.default_filter.split("__"): - if filter_key not in this_datum: + if filter_key not in data: break - this_datum = this_datum[filter_key] - default_filter_value = this_datum + data = data[filter_key] + + default_filter_value = data if not default_filter_value: - return node + return None return await self.node_manager.get_one_by_default_filter( db=self.db, diff --git a/backend/infrahub/server.py b/backend/infrahub/server.py index 6305c17dab..9b8416e423 100644 --- a/backend/infrahub/server.py +++ b/backend/infrahub/server.py @@ -17,7 +17,6 @@ from infrahub_sdk.exceptions import TimestampFormatError from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.trace import Span -from pydantic import ValidationError from starlette_exporter import PrometheusMiddleware, handle_metrics from infrahub import __version__, config @@ -28,7 +27,7 @@ from infrahub.core.initialization import initialization from infrahub.database import InfrahubDatabase, InfrahubDatabaseMode, get_db from infrahub.dependencies.registry import build_component_registry -from infrahub.exceptions import Error +from infrahub.exceptions import Error, ValidationError from infrahub.graphql.api.endpoints import router as graphql_router from infrahub.lock import initialize_lock from infrahub.log import clear_log_context, get_logger, set_log_data diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index c4ef6ca081..0ab7e4aad9 100644 --- 
a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -526,6 +526,31 @@ async def car_person_schema_unregistered(db: InfrahubDatabase, node_group_schema return SchemaRoot(**schema) +@pytest.fixture +async def person_schema_default_filter(db: InfrahubDatabase, node_group_schema, data_schema) -> SchemaRoot: + """ + Person schema with no uniqueness constraint set except the default filter. + """ + + schema: dict[str, Any] = { + "nodes": [ + { + "name": "PersonDF", + "namespace": "Test", + "default_filter": "name__value", + "display_labels": ["name__value"], + "branch": BranchSupportType.AWARE.value, + "attributes": [ + {"name": "name", "kind": "Text"}, + {"name": "height", "kind": "Number", "optional": True}, + ], + }, + ], + } + + return SchemaRoot(**schema) + + @pytest.fixture async def car_person_schema( db: InfrahubDatabase, default_branch: Branch, car_person_schema_unregistered @@ -656,6 +681,7 @@ async def animal_person_schema_unregistered(db: InfrahubDatabase, node_group_sch "attributes": [ {"name": "name", "kind": "Text", "unique": True}, {"name": "height", "kind": "Number", "optional": True}, + {"name": "bag", "kind": "Text", "optional": True, "unique": True}, ], "relationships": [ { @@ -763,7 +789,6 @@ async def dependent_generics_unregistered(db: InfrahubDatabase, node_group_schem "namespace": "Test", "display_labels": ["name__value"], "inherit_from": ["TestPerson"], - "default_filter": "name__value", "human_friendly_id": ["name__value"], }, { @@ -771,7 +796,6 @@ async def dependent_generics_unregistered(db: InfrahubDatabase, node_group_schem "namespace": "Test", "display_labels": ["name__value"], "inherit_from": ["TestPerson"], - "default_filter": "name__value", "human_friendly_id": ["name__value"], "attributes": [ {"name": "model_number", "kind": "Number", "optional": False}, @@ -972,8 +996,8 @@ def car_person_branch_agnostic_schema() -> dict[str, Any]: { "name": "Car", "namespace": "Test", - "default_filter": "name__value", - "uniqueness_constraints": [["name__value"]], + "uniqueness_constraints": [["agnostic_owner"]], + "human_friendly_id": ["name__value"], "branch": BranchSupportType.AGNOSTIC.value, "attributes": [ {"name": "name", "kind": "Text", "unique": True}, diff --git a/backend/tests/node_creation.py b/backend/tests/node_creation.py new file mode 100644 index 0000000000..c58d1bb506 --- /dev/null +++ b/backend/tests/node_creation.py @@ -0,0 +1,12 @@ +from typing import Any + +from infrahub.core.branch import Branch +from infrahub.core.node import Node +from infrahub.database import InfrahubDatabase + + +async def create_and_save(db: InfrahubDatabase, schema: str, branch: Branch | str | None = None, **kwargs: Any) -> Node: + node = await Node.init(db=db, schema=schema, branch=branch) + await node.new(db=db, **kwargs) + await node.save(db=db) + return node diff --git a/backend/tests/unit/core/changelog/test_models.py b/backend/tests/unit/core/changelog/test_models.py index d810ef7d71..b191f6ed03 100644 --- a/backend/tests/unit/core/changelog/test_models.py +++ b/backend/tests/unit/core/changelog/test_models.py @@ -66,6 +66,24 @@ async def test_node_changelog_creation(db: InfrahubDatabase, default_branch, ani }, kind="Number", ), + "bag": AttributeChangelog( + name="bag", + value=None, + value_previous=None, + properties={ + "is_protected": PropertyChangelog( + name="is_protected", + value=False, + value_previous=None, + ), + "is_visible": PropertyChangelog( + name="is_visible", + value=True, + value_previous=None, + ), + }, + kind="Text", + ), }, relationships={}, ) diff 
--git a/backend/tests/unit/core/constraint_validators/conftest.py b/backend/tests/unit/core/constraint_validators/conftest.py index 3e93fe4c03..42d4c1de82 100644 --- a/backend/tests/unit/core/constraint_validators/conftest.py +++ b/backend/tests/unit/core/constraint_validators/conftest.py @@ -116,3 +116,46 @@ async def car_person_generics_data_simple(db: InfrahubDatabase, car_person_schem } return nodes + + +@pytest.fixture +async def car_person_schema_hfid(db: InfrahubDatabase, default_branch: Branch) -> SchemaRoot: + SCHEMA = { + "nodes": [ + { + "name": "Car", + "namespace": "Test", + "human_friendly_id": ["name__value", "owner__name__value"], + "order_by": ["name__value"], + "attributes": [ + {"name": "name", "kind": "Text", "unique": True}, + ], + "relationships": [ + { + "name": "owner", + "peer": "TestPerson", + "identifier": "person__car", + "optional": False, + "cardinality": "one", + }, + ], + }, + { + "name": "Person", + "namespace": "Test", + "human_friendly_id": ["name__value"], + "display_labels": ["name__value"], + "branch": BranchSupportType.AWARE.value, + "attributes": [ + {"name": "name", "kind": "Text", "unique": True}, + ], + "relationships": [ + {"name": "cars", "peer": "TestCar", "identifier": "person__car", "cardinality": "many"} + ], + }, + ], + } + + schema = SchemaRoot(**SCHEMA) + registry.schema.register_schema(schema=schema, branch=default_branch.name) + return schema diff --git a/backend/tests/unit/core/constraint_validators/test_node_grouped_uniqueness.py b/backend/tests/unit/core/constraint_validators/test_node_grouped_uniqueness.py index 7e7c3ddb69..5a48fb3413 100644 --- a/backend/tests/unit/core/constraint_validators/test_node_grouped_uniqueness.py +++ b/backend/tests/unit/core/constraint_validators/test_node_grouped_uniqueness.py @@ -8,7 +8,8 @@ from infrahub.core.node.constraints.grouped_uniqueness import NodeGroupedUniquenessConstraint from infrahub.core.validators.uniqueness.query import NodeUniqueAttributeConstraintQuery from infrahub.database import InfrahubDatabase -from infrahub.exceptions import ValidationError +from infrahub.exceptions import HFIDViolatedError, ValidationError +from tests.node_creation import create_and_save class TestNodeGroupedUniquenessConstraint: @@ -327,3 +328,21 @@ async def test_generic_constraints_failure( with pytest.raises(ValidationError, match="Violates uniqueness constraint 'color-owner'"): await self.__call_system_under_test(db=db, branch=default_branch, node=car_node_1) + + async def test_hfid_violated(self, db: InfrahubDatabase, default_branch: Branch, car_person_schema_hfid): + person_john = await create_and_save(db=db, schema="TestPerson", name="John") + _ = await create_and_save(db=db, schema="TestCar", name="mercedes", owner=person_john) + car_mercedes_2 = await create_and_save(db=db, schema="TestCar", name="mercedes", owner=person_john) + + with pytest.raises(HFIDViolatedError, match="Violates uniqueness constraint 'name-owner'"): + await self.__call_system_under_test(db=db, branch=default_branch, node=car_mercedes_2) + + async def test_subset_hfid_violated(self, db: InfrahubDatabase, default_branch: Branch, car_person_schema_hfid): + person_john = await create_and_save(db=db, schema="TestPerson", name="John") + person_maria = await create_and_save(db=db, schema="TestPerson", name="Maria") + _ = await create_and_save(db=db, schema="TestCar", name="mercedes", owner=person_john) + car_mercedes_of_maria = await create_and_save(db=db, schema="TestCar", name="mercedes", owner=person_maria) + + with 
pytest.raises(ValidationError, match="Violates uniqueness constraint 'name'") as exc_info: + await self.__call_system_under_test(db=db, branch=default_branch, node=car_mercedes_of_maria) + assert not isinstance(exc_info.value, HFIDViolatedError), "HFIDViolatedError should not be raised here" diff --git a/backend/tests/unit/core/constraint_validators/test_node_uniqueness.py b/backend/tests/unit/core/constraint_validators/test_node_uniqueness.py index c9bd9cc768..6b2ea89a82 100644 --- a/backend/tests/unit/core/constraint_validators/test_node_uniqueness.py +++ b/backend/tests/unit/core/constraint_validators/test_node_uniqueness.py @@ -69,5 +69,5 @@ async def test_hierarchical_uniqueness_constraint( ld62 = await Node.init(db=db, schema="LocationRack", branch=default_branch) await ld62.new(db=db, name="ld6-ldn2", parent=uk) - with pytest.raises(ValidationError, match=r"Violates uniqueness constraint 'parent-status' at status"): + with pytest.raises(ValidationError, match=r"Violates uniqueness constraint 'parent-status'"): await constraint.check(ld62) diff --git a/backend/tests/unit/core/constraint_validators/test_uniqueness_checker.py b/backend/tests/unit/core/constraint_validators/test_uniqueness_checker.py index 7ba4a59e65..8c940e902d 100644 --- a/backend/tests/unit/core/constraint_validators/test_uniqueness_checker.py +++ b/backend/tests/unit/core/constraint_validators/test_uniqueness_checker.py @@ -55,7 +55,8 @@ async def test_one_violation( schema_root = SchemaRoot(nodes=[schema]) registry.schema.register_schema(schema=schema_root, branch=branch.name) - grouped_data_paths = await self.__call_system_under_test(db, branch, schema) + schema_uniqueness_constraint_synced = registry.schema.get(name="TestCar", branch=branch) + grouped_data_paths = await self.__call_system_under_test(db, branch, schema_uniqueness_constraint_synced) assert len(grouped_data_paths) == 1 all_data_paths = grouped_data_paths[0].get_all_data_paths() diff --git a/backend/tests/unit/graphql/mutations/test_ipam.py b/backend/tests/unit/graphql/mutations/test_ipam.py index 3f08427bae..4248f51070 100644 --- a/backend/tests/unit/graphql/mutations/test_ipam.py +++ b/backend/tests/unit/graphql/mutations/test_ipam.py @@ -48,6 +48,30 @@ } """ + +UPSERT_IPPREFIX_NO_ID = """ +mutation UpsertPrefix($prefix: String!, $description: String!) { + IpamIPPrefixUpsert( + data: { + prefix: { + value: $prefix + } + description: { + value: $description + } + } + ) { + ok + object { + id + description { + value + } + } + } +} +""" + UPSERT_IPPREFIX = """ mutation UpsertPrefix($id: String!, $prefix: String!, $description: String!) { IpamIPPrefixUpsert( @@ -168,6 +192,29 @@ } """ +UPSERT_IPADDRESS_NO_ID = """ +mutation UpsertAddress($address: String!, $description: String!) { + IpamIPAddressUpsert( + data: { + address: { + value: $address + } + description: { + value: $description + } + } + ) { + ok + object { + id + description { + value + } + } + } +} +""" + UPSERT_IPADDRESS = """ mutation UpsertAddress($id: String!, $address: String!, $description: String!) 
{ IpamIPAddressUpsert( @@ -548,9 +595,9 @@ async def test_ipprefix_upsert( subnet = ipaddress.ip_network("2001:db8::/48") result = await graphql( schema=gql_params.schema, - source=UPSERT_IPPREFIX, + source=UPSERT_IPPREFIX_NO_ID, context_value=gql_params.context, - variable_values={"id": "", "prefix": str(subnet), "description": ""}, + variable_values={"prefix": str(subnet), "description": ""}, ) assert not result.errors @@ -853,9 +900,9 @@ async def test_ipaddress_upsert( address = ipaddress.ip_interface("192.0.2.1/24") result = await graphql( schema=gql_params.schema, - source=UPSERT_IPADDRESS, + source=UPSERT_IPADDRESS_NO_ID, context_value=gql_params.context, - variable_values={"id": "", "address": str(address), "description": ""}, + variable_values={"address": str(address), "description": ""}, ) assert not result.errors diff --git a/backend/tests/unit/graphql/test_mutation_relationship.py b/backend/tests/unit/graphql/test_mutation_relationship.py index 8425e8bdd0..72630132bc 100644 --- a/backend/tests/unit/graphql/test_mutation_relationship.py +++ b/backend/tests/unit/graphql/test_mutation_relationship.py @@ -285,7 +285,7 @@ async def test_relationship_wrong_name( ) assert result.errors - assert result.errors[0].message == "'notvalid' is not a valid relationship for 'TestPerson'" + assert result.errors[0].message == "'notvalid' is not a valid relationship for 'TestPerson' at name" # Relationship existing relationship with the wrong cardinality query = """ @@ -313,7 +313,7 @@ async def test_relationship_wrong_name( ) assert result.errors - assert result.errors[0].message == "'primary_tag' must be a relationship of cardinality Many" + assert result.errors[0].message == "'primary_tag' must be a relationship of cardinality Many at name" async def test_relationship_wrong_node( diff --git a/backend/tests/unit/graphql/test_mutation_upsert.py b/backend/tests/unit/graphql/test_mutation_upsert.py index 34a39b36c5..3f690a3747 100644 --- a/backend/tests/unit/graphql/test_mutation_upsert.py +++ b/backend/tests/unit/graphql/test_mutation_upsert.py @@ -1,5 +1,3 @@ -from uuid import uuid4 - from infrahub.auth import AccountSession from infrahub.core.branch import Branch from infrahub.core.manager import NodeManager @@ -14,6 +12,7 @@ from tests.constants import TestKind from tests.helpers.graphql import graphql from tests.helpers.schema import TICKET +from tests.node_creation import create_and_save async def test_upsert_existing_simple_object_by_id(db: InfrahubDatabase, person_john_main: Node, branch: Branch): @@ -46,16 +45,29 @@ async def test_upsert_existing_simple_object_by_id(db: InfrahubDatabase, person_ async def test_upsert_existing_simple_object_by_default_filter( - db: InfrahubDatabase, person_john_main: Node, branch: Branch + db: InfrahubDatabase, person_schema_default_filter, default_branch ): + registry.schema.register_schema(schema=person_schema_default_filter) + + person = await Node.init(db=db, schema="TestPersonDF") + await person.new(db=db, name="John", height=180) + await person.save(db=db) + query = """ mutation { - TestPersonUpsert(data: {name: { value: "John"}, height: {value: 138}}) { + TestPersonDFUpsert(data: {name: { value: "John"}, height: {value: 138}}) { ok + object { + id + name { + value + } + } } } """ - gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch) + + gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=default_branch) result = await graphql( schema=gql_params.schema, source=query, @@ -66,9 
+78,10 @@ async def test_upsert_existing_simple_object_by_default_filter( assert result.errors is None assert result.data - assert result.data["TestPersonUpsert"]["ok"] is True + assert result.data["TestPersonDFUpsert"]["ok"] is True + assert result.data["TestPersonDFUpsert"]["object"]["id"] == person.id - obj1 = await NodeManager.get_one(db=db, id=person_john_main.id, branch=branch) + obj1 = await NodeManager.get_one(db=db, id=person.id) assert obj1.name.value == "John" assert obj1.height.value == 138 @@ -171,50 +184,18 @@ async def test_upsert_create_simple_object_no_id(db: InfrahubDatabase, person_jo assert obj1.height.value == 179 -async def test_upsert_create_simple_object_with_id(db: InfrahubDatabase, person_john_main, branch: Branch): - fresh_id = str(uuid4()) - query = """ - mutation { - TestPersonUpsert(data: {id: "%s", name: { value: "%s"}, height: {value: %s}}) { - ok - object { - id - } - } - } - """ % (fresh_id, "Dwayne Hicks", 168) - - gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch) - result = await graphql( - schema=gql_params.schema, - source=query, - context_value=gql_params.context, - root_value=None, - variable_values={}, - ) - - assert result.errors is None - assert result.data - assert result.data["TestPersonUpsert"]["ok"] is True - - person_id = result.data["TestPersonUpsert"]["object"]["id"] - assert person_id == fresh_id - obj1 = await NodeManager.get_one(db=db, id=person_id, branch=branch) - assert obj1.name.value == "Dwayne Hicks" - assert obj1.height.value == 168 - - -async def test_cannot_upsert_new_object_without_required_fields(db: InfrahubDatabase, person_john_main, branch: Branch): - fresh_id = str(uuid4()) +async def test_id_for_other_schema_raises_error( + db: InfrahubDatabase, person_john_main, car_accord_main, branch: Branch +): query = ( """ mutation { - TestPersonUpsert(data: {id: "%s", height: { value: 182}}) { + TestPersonUpsert(data: {id: "%s", name: {value: "John"}, height: { value: 182}}) { ok } } """ - % fresh_id + % car_accord_main.id ) gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch) result = await graphql( @@ -225,25 +206,23 @@ async def test_cannot_upsert_new_object_without_required_fields(db: InfrahubData variable_values={}, ) - expected_error = "Field 'TestPersonUpsertInput.name' of required type 'TextAttributeUpdate!' was not provided." 
+ expected_error = f"Node with id {car_accord_main.id} exists, but it is a TestCar, not TestPerson" assert result.errors assert any(expected_error in error.message for error in result.errors) - assert await NodeManager.get_one(db=db, id=fresh_id, branch=branch) is None - -async def test_id_for_other_schema_raises_error( - db: InfrahubDatabase, person_john_main, car_accord_main, branch: Branch +async def test_update_by_id_to_nonunique_value_raises_error( + db: InfrahubDatabase, person_john_main, person_jim_main, branch: Branch ): query = ( """ mutation { - TestPersonUpsert(data: {id: "%s", name: {value: "John"}, height: { value: 182}}) { + TestPersonUpsert(data: {id: "%s", name: {value: "Jim"}}) { ok } } """ - % car_accord_main.id + % person_john_main.id ) gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch) result = await graphql( @@ -254,24 +233,23 @@ async def test_id_for_other_schema_raises_error( variable_values={}, ) - expected_error = f"Node with id {car_accord_main.id} exists, but it is a TestCar, not TestPerson" + expected_error = "Violates uniqueness constraint 'name'" assert result.errors assert any(expected_error in error.message for error in result.errors) -async def test_update_by_id_to_nonunique_value_raises_error( - db: InfrahubDatabase, person_john_main, person_jim_main, branch: Branch -): - query = ( - """ +async def test_non_unique_value_raises_error(db: InfrahubDatabase, animal_person_schema, branch: Branch): + _ = await create_and_save(db=db, schema="TestPerson", name="Jack", bag="bag-jacks") + + # Make sure correct raised error is raised while violating uniqueness constraint of a non hfid-related attribute. + query = """ mutation { - TestPersonUpsert(data: {id: "%s", name: {value: "Jim"}}) { + TestPersonUpsert(data: {name: {value: "Jim"}, bag: {value: "bag-jacks"}}) { ok } } """ - % person_john_main.id - ) + gql_params = await prepare_graphql_params(db=db, include_subscription=False, branch=branch) result = await graphql( schema=gql_params.schema, @@ -280,10 +258,8 @@ async def test_update_by_id_to_nonunique_value_raises_error( root_value=None, variable_values={}, ) - - expected_error = "Violates uniqueness constraint 'name' at name" - assert result.errors - assert any(expected_error in error.message for error in result.errors) + assert len(result.errors) == 1 + assert "Violates uniqueness constraint 'bag'" in result.errors[0].message async def test_with_hfid_existing(db: InfrahubDatabase, default_branch, animal_person_schema): @@ -397,6 +373,7 @@ async def test_with_hfid_new(db: InfrahubDatabase, default_branch, animal_person async def test_with_constructed_hfid(db: InfrahubDatabase, default_branch, animal_person_schema) -> None: """Validate that we can construct an HFID out of the payload without specifying all parts.""" + person_schema = animal_person_schema.get(name="TestPerson") person1 = await Node.init(db=db, schema=person_schema, branch=default_branch) @@ -482,6 +459,7 @@ async def test_with_constructed_hfid_with_numbers( db: InfrahubDatabase, default_branch: Branch, data_schema: None ) -> None: """Validate that we can construct an HFID out of the payload without specifying all parts.""" + registry.schema.register_schema(schema=SchemaRoot(nodes=[TICKET]), branch=default_branch.name) first_ticket = await Node.init(schema=TestKind.TICKET, db=db) diff --git a/changelog/+schema_strict_mode.added.md b/changelog/+schemastrictmode.added.md similarity index 100% rename from changelog/+schema_strict_mode.added.md rename to 
changelog/+schemastrictmode.added.md diff --git a/changelog/upsertperformances.changed.md b/changelog/upsertperformances.changed.md new file mode 100644 index 0000000000..e69de29bb2
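Note: the classification driving the new error handling across this patch lives in `get_uniqueness_constraint_type`: a constraint equal to the HFID (converted to constraint form) is HFID, a proper subset of it is SUBSET_OF_HFID, and everything else is STANDARD. A self-contained sketch of that comparison, using plain field-name sets in place of SchemaAttributePath objects; the TestCar example values follow the new `car_person_schema_hfid` fixture:

from enum import Enum

class UniquenessConstraintType(Enum):
    HFID = "HFID"
    SUBSET_OF_HFID = "SUBSET_OF_HFID"
    STANDARD = "STANDARD"

def get_uniqueness_constraint_type(constraint: set[str], hfid: list[str] | None) -> UniquenessConstraintType:
    # Mirrors BaseNodeSchema.get_uniqueness_constraint_type: the equality check
    # runs first, so the HFID itself is never reported as SUBSET_OF_HFID.
    if hfid is None:
        return UniquenessConstraintType.STANDARD
    hfid_set = set(hfid)
    if constraint == hfid_set:
        return UniquenessConstraintType.HFID
    if constraint <= hfid_set:
        return UniquenessConstraintType.SUBSET_OF_HFID
    return UniquenessConstraintType.STANDARD

hfid = ["name", "owner"]  # TestCar: human_friendly_id = ["name__value", "owner__name__value"]
print(get_uniqueness_constraint_type({"name", "owner"}, hfid))   # UniquenessConstraintType.HFID
print(get_uniqueness_constraint_type({"name"}, hfid))            # UniquenessConstraintType.SUBSET_OF_HFID
print(get_uniqueness_constraint_type({"color", "owner"}, hfid))  # UniquenessConstraintType.STANDARD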