diff --git a/backend/infrahub/computed_attribute/models.py b/backend/infrahub/computed_attribute/models.py index b77862acdc..409ee059ca 100644 --- a/backend/infrahub/computed_attribute/models.py +++ b/backend/infrahub/computed_attribute/models.py @@ -1,11 +1,18 @@ +from __future__ import annotations + from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING -from prefect.events.schemas.automations import Automation +from prefect.events.schemas.automations import Automation # noqa: TCH002 from pydantic import BaseModel, Field from typing_extensions import Self from .constants import AUTOMATION_NAME_PREFIX +if TYPE_CHECKING: + from infrahub.core.schema.schema_branch_computed import PythonDefinition + class ComputedAttributeAutomations(BaseModel): data: dict[str, dict[str, Automation]] = Field(default_factory=lambda: defaultdict(dict)) @@ -37,3 +44,14 @@ def has(self, identifier: str, scope: str) -> bool: if identifier in self.data and scope in self.data[identifier]: return True return False + + +@dataclass +class PythonTransformComputedAttribute: + name: str + repository_id: str + repository_name: str + repository_kind: str + query_name: str + query_models: list[str] + computed_attribute: PythonDefinition diff --git a/backend/infrahub/computed_attribute/tasks.py b/backend/infrahub/computed_attribute/tasks.py index 7aab078063..97c8a491a8 100644 --- a/backend/infrahub/computed_attribute/tasks.py +++ b/backend/infrahub/computed_attribute/tasks.py @@ -17,11 +17,12 @@ from infrahub.git.repository import get_initialized_repo from infrahub.services import services from infrahub.support.macro import MacroDefinition +from infrahub.tasks.registry import refresh_branches from infrahub.workflows.catalogue import PROCESS_COMPUTED_MACRO, UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM from infrahub.workflows.utils import add_branch_tag from .constants import AUTOMATION_NAME, AUTOMATION_NAME_PREFIX -from .models import ComputedAttributeAutomations +from .models import ComputedAttributeAutomations, PythonTransformComputedAttribute if TYPE_CHECKING: from infrahub.core.schema.computed_attribute import ComputedAttribute @@ -48,6 +49,8 @@ async def process_transform( branch_name: str, node_kind: str, object_id: str, + computed_attribute_name: str, # pylint: disable=unused-argument + computed_attribute_kind: str, # pylint: disable=unused-argument updated_fields: list[str] | None = None, # pylint: disable=unused-argument ) -> None: """Request to the creation of git branches in available repositories.""" @@ -262,3 +265,114 @@ async def computed_attribute_setup() -> None: else: await client.create_automation(automation=automation) log.info(f"{computed_attribute.key_name} Created") + + +@flow( + name="computed-attribute-setup-python", + flow_run_name="Setup computed attributes for Python transforms in task-manager", +) +async def computed_attribute_setup_python() -> None: + service = services.service + async with service.database.start_session() as db: + await refresh_branches(db=db) + + schema_branch = registry.schema.get_schema_branch(name=registry.default_branch) + log = get_run_logger() + + transform_attributes = schema_branch.computed_attributes.python_attributes_by_transform + + transform_names = list(transform_attributes.keys()) + + transforms = await service.client.filters( + kind="CoreTransformPython", + branch=registry.default_branch, + prefetch_relationships=True, + populate_store=True, + name__values=transform_names, + ) + + found_transforms_names = [transform.name.value for transform in transforms] + for transform_name in transform_names: + if transform_name not in found_transforms_names: + log.warning( + msg=f"The transform {transform_name} is assigned to a computed attribute but the transform could not be found in the database." + ) + + computed_attributes: list[PythonTransformComputedAttribute] = [] + for transform in transforms: + for attribute in transform_attributes[transform.name.value]: + computed_attributes.append( + PythonTransformComputedAttribute( + name=transform.name.value, + repository_id=transform.repository.peer.id, + repository_name=transform.repository.peer.name.value, + repository_kind=transform.repository.peer.typename, + query_name=transform.query.peer.name.value, + query_models=transform.query.peer.models.value, + computed_attribute=attribute, + ) + ) + + async with get_client(sync_client=False) as client: + deployments = { + item.name: item + for item in await client.read_deployments( + deployment_filter=DeploymentFilter( + name=DeploymentFilterName(any_=[UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM.name]) + ) + ) + } + if UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM.name not in deployments: + raise ValueError("Unable to find the deployment for UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM") + + deployment_id_python = deployments[UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM.name].id + + automations = await client.read_automations() + existing_computed_attr_automations = ComputedAttributeAutomations.from_prefect(automations=automations) + + for computed_attribute in computed_attributes: + log.info(f"processing {computed_attribute.computed_attribute.key_name}") + scope = "default" + + automation = AutomationCore( + name=AUTOMATION_NAME.format( + prefix=AUTOMATION_NAME_PREFIX, + identifier=computed_attribute.computed_attribute.key_name, + scope=scope, + ), + description=f"Process value of the computed attribute for {computed_attribute.computed_attribute.key_name} [{scope}]", + enabled=True, + trigger=EventTrigger( + posture=Posture.Reactive, + expect={"infrahub.node.*"}, + within=timedelta(0), + match=ResourceSpecification({"infrahub.node.kind": [computed_attribute.computed_attribute.kind]}), + threshold=1, + ), + actions=[ + RunDeployment( + source="selected", + deployment_id=deployment_id_python, + parameters={ + "branch_name": "{{ event.resource['infrahub.branch.name'] }}", + "node_kind": "{{ event.resource['infrahub.node.kind'] }}", + "object_id": "{{ event.resource['infrahub.node.id'] }}", + "computed_attribute_name": computed_attribute.computed_attribute.attribute.name, + "computed_attribute_kind": computed_attribute.computed_attribute.kind, + }, + job_variables={}, + ) + ], + ) + + if existing_computed_attr_automations.has( + identifier=computed_attribute.computed_attribute.key_name, scope=scope + ): + existing = existing_computed_attr_automations.get( + identifier=computed_attribute.computed_attribute.key_name, scope=scope + ) + await client.update_automation(automation_id=existing.id, automation=automation) + log.info(f"{computed_attribute.computed_attribute.key_name} Updated") + else: + await client.create_automation(automation=automation) + log.info(f"{computed_attribute.computed_attribute.key_name} Created") diff --git a/backend/infrahub/core/schema/schema_branch.py b/backend/infrahub/core/schema/schema_branch.py index 1159f89245..0531960594 100644 --- a/backend/infrahub/core/schema/schema_branch.py +++ b/backend/infrahub/core/schema/schema_branch.py @@ -54,6 +54,7 @@ from infrahub.visuals import select_color from .constants import INTERNAL_SCHEMA_NODE_KINDS, SchemaNamespace +from .schema_branch_computed import ComputedAttributes log = get_logger() @@ -114,6 +115,7 @@ def __init__(self, cache: dict, name: str | None = None, data: dict[str, dict[st self.nodes: dict[str, str] = {} self.generics: dict[str, str] = {} self.profiles: dict[str, str] = {} + self.computed_attributes = ComputedAttributes() self._computed_jinja2_attribute_map: dict[str, RegisteredNodeComputedAttribute] = {} if data: @@ -949,6 +951,7 @@ def validate_kinds(self) -> None: ) from None def validate_computed_attributes(self) -> None: + self.computed_attributes = ComputedAttributes() self._computed_jinja2_attribute_map = {} for name in self.nodes.keys(): node_schema = self.get_node(name=name, duplicate=False) @@ -1016,11 +1019,14 @@ def _validate_computed_attribute(self, node: NodeSchema, attribute: AttributeSch self._register_computed_attribute_target(node=node, attribute=attribute, schema_path=schema_path) - if attribute.computed_attribute.kind == ComputedAttributeKind.TRANSFORM_PYTHON and not attribute.optional: + elif attribute.computed_attribute.kind == ComputedAttributeKind.TRANSFORM_PYTHON and not attribute.optional: raise ValueError( f"{node.kind}: Attribute {attribute.name!r} is a computed transform, it can't be mandatory" ) + elif attribute.computed_attribute.kind == ComputedAttributeKind.TRANSFORM_PYTHON: + self.computed_attributes.add_python_attribute(node=node, attribute=attribute) + def get_impacted_macros(self, kind: str, updates: list[str] | None = None) -> list[ComputedAttributeTarget]: if mapping := self._computed_jinja2_attribute_map.get(kind): return mapping.get_targets(updates=updates) diff --git a/backend/infrahub/core/schema/schema_branch_computed.py b/backend/infrahub/core/schema/schema_branch_computed.py new file mode 100644 index 0000000000..22abcdf47d --- /dev/null +++ b/backend/infrahub/core/schema/schema_branch_computed.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from infrahub.core.schema import ( + AttributeSchema, + NodeSchema, +) + + +@dataclass +class PythonDefinition: + kind: str + attribute: AttributeSchema + + @property + def key_name(self) -> str: + return f"{self.kind}_{self.attribute.name}" + + +class ComputedAttributes: + def __init__(self) -> None: + self._computed_python_transform_attribute_map: dict[str, list[AttributeSchema]] = {} + + def add_python_attribute(self, node: NodeSchema, attribute: AttributeSchema) -> None: + if node.kind not in self._computed_python_transform_attribute_map: + self._computed_python_transform_attribute_map[node.kind] = [] + self._computed_python_transform_attribute_map[node.kind].append(attribute) + + def get_kinds_python_attributes(self) -> list[str]: + """Return kinds that have Python attributes defined""" + return list(self._computed_python_transform_attribute_map.keys()) + + @property + def python_attributes_by_transform(self) -> dict[str, list[PythonDefinition]]: + computed_attributes: dict[str, list[PythonDefinition]] = {} + for kind, attributes in self._computed_python_transform_attribute_map.items(): + for attribute in attributes: + if attribute.computed_attribute and attribute.computed_attribute.transform: + if attribute.computed_attribute.transform not in computed_attributes: + computed_attributes[attribute.computed_attribute.transform] = [] + + computed_attributes[attribute.computed_attribute.transform].append( + PythonDefinition(kind=kind, attribute=attribute) + ) + + return computed_attributes diff --git a/backend/infrahub/graphql/mutations/main.py b/backend/infrahub/graphql/mutations/main.py index ec9bcb3c43..46c983af9d 100644 --- a/backend/infrahub/graphql/mutations/main.py +++ b/backend/infrahub/graphql/mutations/main.py @@ -27,7 +27,6 @@ from infrahub.exceptions import ValidationError from infrahub.log import get_log_data, get_logger from infrahub.worker import WORKER_IDENTITY -from infrahub.workflows.catalogue import UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter from .node_getter.by_hfid import MutationNodeGetterByHfid @@ -101,8 +100,6 @@ async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: InputObjectTyp # Reset the time of the query to guarantee that all resolvers executed after this point will account for the changes context.at = Timestamp() - # Get relevant macros based on the current change - # macros = schema_branch.get_impacted_macros(kind=obj.get_kind(), updates=updated_fields) if config.SETTINGS.broker.enable and context.background: log_data = get_log_data() request_id = log_data.get("request_id", "") @@ -120,38 +117,6 @@ async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: InputObjectTyp context.background.add_task(context.service.event.send, event) - # Temporary workaround until there's a proper deployment for Transforms during schema load / repo sync - schema_branch = registry.schema.get_schema_branch(name=context.branch.name) - node_schema = schema_branch.get(obj._schema.kind) - if isinstance(node_schema, NodeSchema): - has_computed_attributes = [ - attribute for attribute in node_schema.attributes if attribute.computed_attribute - ] - if has_computed_attributes: - updated_fields = list(data.keys()) - - await context.service.workflow.submit_workflow( - workflow=UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM, - parameters={ - "branch_name": context.branch.name, - "node_kind": obj.get_kind(), - "object_id": obj.get_id(), - "updated_fields": updated_fields, - }, - ) - - # # Add event - # if macros: - # await context.service.workflow.submit_workflow( - # workflow=PROCESS_COMPUTED_MACRO, - # parameters={ - # "branch_name": context.branch.name, - # "node_kind": obj.get_kind(), - # "object_id": obj.get_id(), - # "updated_fields": updated_fields, - # }, - # ) - return mutation @classmethod diff --git a/backend/infrahub/schema/tasks.py b/backend/infrahub/schema/tasks.py index 1753f48525..0e8d12fb8b 100644 --- a/backend/infrahub/schema/tasks.py +++ b/backend/infrahub/schema/tasks.py @@ -10,7 +10,7 @@ from prefect.events.schemas.automations import EventTrigger, Posture from prefect.logging import get_run_logger -from infrahub.workflows.catalogue import COMPUTED_ATTRIBUTE_SETUP +from infrahub.workflows.catalogue import COMPUTED_ATTRIBUTE_SETUP, COMPUTED_ATTRIBUTE_SETUP_PYTHON from .constants import AUTOMATION_NAME @@ -25,9 +25,7 @@ async def schema_updated_setup() -> None: for item in await client.read_deployments( deployment_filter=DeploymentFilter( name=DeploymentFilterName( - any_=[ - COMPUTED_ATTRIBUTE_SETUP.name, - ] + any_=[COMPUTED_ATTRIBUTE_SETUP.name, COMPUTED_ATTRIBUTE_SETUP_PYTHON.name] ) ) ) @@ -36,6 +34,7 @@ async def schema_updated_setup() -> None: raise ValueError("Unable to find the deployment for PROCESS_COMPUTED_MACRO") deployment_id_computed_attribute_setup = deployments[COMPUTED_ATTRIBUTE_SETUP.name].id + deployment_id_computed_attribute_setup_python = deployments[COMPUTED_ATTRIBUTE_SETUP_PYTHON.name].id schema_update_automation = await client.find_automation(id_or_name=AUTOMATION_NAME) @@ -55,7 +54,13 @@ async def schema_updated_setup() -> None: deployment_id=deployment_id_computed_attribute_setup, parameters={}, job_variables={}, - ) + ), + RunDeployment( + source="selected", + deployment_id=deployment_id_computed_attribute_setup_python, + parameters={}, + job_variables={}, + ), ], ) diff --git a/backend/infrahub/services/__init__.py b/backend/infrahub/services/__init__.py index ba4adffc0f..ec4a195746 100644 --- a/backend/infrahub/services/__init__.py +++ b/backend/infrahub/services/__init__.py @@ -55,7 +55,7 @@ def client(self) -> InfrahubClient: return self._client - def set_client(self, client: InfrahubClient) -> None: + def set_client(self, client: InfrahubClient | None) -> None: self._client = client @property diff --git a/backend/infrahub/workflows/catalogue.py b/backend/infrahub/workflows/catalogue.py index 05ceb74778..607ab6b0b3 100644 --- a/backend/infrahub/workflows/catalogue.py +++ b/backend/infrahub/workflows/catalogue.py @@ -233,6 +233,14 @@ function="computed_attribute_setup", ) +COMPUTED_ATTRIBUTE_SETUP_PYTHON = WorkflowDefinition( + name="computed-attribute-setup-python", + type=WorkflowType.INTERNAL, + module="infrahub.computed_attribute.tasks", + function="computed_attribute_setup_python", +) + + UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM = WorkflowDefinition( name="process_computed_attribute_transform", type=WorkflowType.INTERNAL, @@ -289,6 +297,7 @@ GIT_REPOSITORY_ADD_READ_ONLY, PROCESS_COMPUTED_MACRO, COMPUTED_ATTRIBUTE_SETUP, + COMPUTED_ATTRIBUTE_SETUP_PYTHON, UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM, REQUEST_PROPOSED_CHANGE_DATA_INTEGRITY, SCHEMA_UPDATED_SETUP, diff --git a/backend/tests/fixtures/schemas/infra_simple_01.json b/backend/tests/fixtures/schemas/infra_simple_01.json index 68bfddb0e6..2e01095113 100644 --- a/backend/tests/fixtures/schemas/infra_simple_01.json +++ b/backend/tests/fixtures/schemas/infra_simple_01.json @@ -3,7 +3,7 @@ "nodes": [ { "name": "Device", - "namespace": "Infra", + "namespace": "Test", "default_filter": "name__value", "display_labels": ["name__value"], "attributes": [ @@ -16,7 +16,7 @@ "relationships": [ { "name": "interfaces", - "peer": "InfraInterface", + "peer": "TestInterface", "optional": true, "cardinality": "many", "kind": "Component", @@ -27,7 +27,7 @@ }, { "name": "Interface", - "namespace": "Infra", + "namespace": "Test", "display_labels": ["name__value"], "attributes": [ {"name": "name", "kind": "Text"}, @@ -36,11 +36,11 @@ {"name": "enabled", "kind": "Boolean", "default_value": true} ], "relationships": [ - {"name": "device", "peer": "InfraDevice", "optional": false, "cardinality": "one", "kind": "Parent"}, + {"name": "device", "peer": "TestDevice", "optional": false, "cardinality": "one", "kind": "Parent"}, {"name": "tags", "peer": "BuiltinTag", "optional": true, "cardinality": "many", "kind": "Attribute"}, { "name": "ip_addresses", - "peer": "InfraIPAddress", + "peer": "TestIPAddress", "optional": true, "cardinality": "many", "kind": "Component" @@ -49,7 +49,7 @@ }, { "name": "IPAddress", - "namespace": "Infra", + "namespace": "Test", "default_filter": "address__value", "display_labels": ["address__value"], "attributes": [ @@ -57,7 +57,58 @@ {"name": "description", "kind": "Text", "optional": true} ], "relationships": [ - {"name": "interface", "peer": "InfraInterface", "optional": false, "cardinality": "one", "kind": "Parent"} + {"name": "interface", "peer": "TestInterface", "optional": false, "cardinality": "one", "kind": "Parent"} + ] + }, + { + "name": "Organization", + "namespace": "Testing", + "description": "An organization represent a legal entity, a company.", + "include_in_menu": true, + "label": "Organization", + "icon": "mdi:domain", + "default_filter": "name__value", + "order_by": ["name__value"], + "display_labels": ["label__value"], + "branch": "aware", + "attributes": [ + {"name": "name", "kind": "Text", "unique": true}, + {"name": "label", "kind": "Text", "optional": true}, + {"name": "description", "kind": "Text", "optional": true} + ], + "relationships": [ + { + "name": "tags", + "peer": "BuiltinTag", + "kind": "Attribute", + "optional": true, + "cardinality": "many" + } + ] + }, + { + "name": "Site", + "namespace": "Infra", + "description": "A location represent a physical element site", + "include_in_menu": true, + "icon": "mdi:map-marker-radius-outline", + "label": "Site", + "default_filter": "name__value", + "order_by": ["name__value"], + "display_labels": ["name__value"], + "attributes": [ + {"name": "name", "kind": "Text", "unique": true}, + {"name": "description", "kind": "Text", "optional": true}, + {"name": "type", "kind": "Text"} + ], + "relationships": [ + { + "name": "tags", + "peer": "BuiltinTag", + "kind": "Attribute", + "optional": true, + "cardinality": "many" + } ] } ] diff --git a/backend/tests/fixtures/schemas/infra_w_extensions_01.json b/backend/tests/fixtures/schemas/infra_w_extensions_01.json index 494d0116ec..1c1995de7b 100644 --- a/backend/tests/fixtures/schemas/infra_w_extensions_01.json +++ b/backend/tests/fixtures/schemas/infra_w_extensions_01.json @@ -25,7 +25,7 @@ ] }, { - "kind": "CoreOrganization", + "kind": "TestingOrganization", "relationships": [ {"name": "devices", "peer": "InfraDevice", "kind": "Generic", "cardinality": "many", "optional": true } ] diff --git a/backend/tests/functional/api/__init__.py b/backend/tests/functional/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/functional/api/test_load_schema.py b/backend/tests/functional/api/test_load_schema.py new file mode 100644 index 0000000000..ffd4c14592 --- /dev/null +++ b/backend/tests/functional/api/test_load_schema.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from infrahub_sdk.schema import GenericSchema as SDKGenericSchema + +from infrahub.core.registry import registry +from infrahub.core.schema import core_models +from infrahub.core.utils import count_relationships +from infrahub.database import InfrahubDatabase +from tests.helpers.test_app import TestInfrahubApp + +if TYPE_CHECKING: + from infrahub_sdk import InfrahubClient + + from infrahub.core.branch import Branch + from infrahub.database import InfrahubDatabase + from tests.adapters.message_bus import BusSimulator + from tests.conftest import TestHelper + + +class TestLoadSchemaAPI(TestInfrahubApp): + @pytest.fixture(scope="class") + async def initial_dataset( + self, + db: InfrahubDatabase, + initialize_registry: None, + client: InfrahubClient, + bus_simulator: BusSimulator, + prefect_test_fixture: None, + ) -> None: + pass + + async def test_schema_load_endpoint_idempotent_simple( + self, initial_dataset: str, client: InfrahubClient, helper: TestHelper, db: InfrahubDatabase + ) -> None: + creation = await client.schema.load(schemas=[helper.schema_file("infra_simple_01.json")]) + assert creation.schema_updated + test_device = await client.schema.get(kind="TestDevice") + attributes = {attrib.name: attrib.order_weight for attrib in test_device.attributes} + relationships = {attrib.name: attrib.order_weight for attrib in test_device.relationships} + assert attributes["name"] == 1000 + assert attributes["description"] == 900 + assert attributes["type"] == 3000 + assert relationships["interfaces"] == 450 + assert relationships["tags"] == 7000 + + first_relationship_count = await count_relationships(db=db) + update = await client.schema.load(schemas=[helper.schema_file("infra_simple_01.json")]) + assert not update.schema_updated + updated_relationship_count = await count_relationships(db=db) + + assert first_relationship_count == updated_relationship_count + + async def test_schema_load_endpoint_idempotent_with_generics( + self, initial_dataset: str, client: InfrahubClient, helper: TestHelper, db: InfrahubDatabase + ) -> None: + creation = await client.schema.load(schemas=[helper.schema_file("infra_w_generics_01.json")]) + assert creation.schema_updated + assert creation.schema_updated + first_relationship_count = await count_relationships(db=db) + update = await client.schema.load(schemas=[helper.schema_file("infra_w_generics_01.json")]) + assert not update.schema_updated + updated_relationship_count = await count_relationships(db=db) + + assert first_relationship_count == updated_relationship_count + + all_schemas = await client.schema.all(refresh=True) + generic_schemas = [schema for schema in all_schemas.values() if isinstance(schema, SDKGenericSchema)] + + assert len(generic_schemas) == len(core_models["generics"]) + 1 + + async def test_schema_load_endpoint_valid_with_extensions( + self, + initial_dataset: str, + client: InfrahubClient, + helper: TestHelper, + db: InfrahubDatabase, + default_branch: Branch, + ) -> None: + schema = registry.schema.get_schema_branch(name=default_branch.name) + await registry.schema.load_schema_to_db(schema=schema, branch=default_branch, db=db) + simple = await client.schema.load(schemas=[helper.schema_file("infra_simple_01.json")]) + assert not simple.errors + org_schema = registry.schema.get(name="TestingOrganization", branch=default_branch.name) + initial_nbr_relationships = len(org_schema.relationships) + + extended_schema = await client.schema.load(schemas=[helper.schema_file("infra_w_extensions_01.json")]) + assert not extended_schema.errors + assert extended_schema.schema_updated + + org_schema = registry.schema.get(name="TestingOrganization", branch=default_branch.name) + assert len(org_schema.relationships) == initial_nbr_relationships + 1 diff --git a/backend/tests/functional/graphql/__init__.py b/backend/tests/functional/graphql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/functional/graphql/test_api.py b/backend/tests/functional/graphql/test_api.py new file mode 100644 index 0000000000..ca2c3dc47e --- /dev/null +++ b/backend/tests/functional/graphql/test_api.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from infrahub.core.node import Node +from infrahub.core.timestamp import Timestamp +from infrahub.database import InfrahubDatabase +from tests.helpers.schema import CAR_SCHEMA, load_schema +from tests.helpers.test_app import TestInfrahubApp + +if TYPE_CHECKING: + from infrahub_sdk import InfrahubClient + + from infrahub.database import InfrahubDatabase + from tests.adapters.message_bus import BusSimulator + + +class TestPreviousVersions(TestInfrahubApp): + @pytest.fixture(scope="class") + async def initial_dataset( + self, + db: InfrahubDatabase, + initialize_registry: None, + client: InfrahubClient, + bus_simulator: BusSimulator, + prefect_test_fixture: None, + ) -> None: + await load_schema(db, schema=CAR_SCHEMA, update_db=True) + + p1 = await Node.init(db=db, schema="TestingPerson") + await p1.new(db=db, name="John", height=180) + await p1.save(db=db) + p2 = await Node.init(db=db, schema="TestingPerson") + await p2.new(db=db, name="Jane", height=170) + await p2.save(db=db) + + async def test_query_at_previous_schema(self, initial_dataset: str, client: InfrahubClient) -> None: + time_before = Timestamp() + + query = """ + query { + TestingPerson { + edges { + node { + display_label + } + } + } + } + """ + + response = await client.execute_graphql(query=query) + + assert response == { + "TestingPerson": { + "edges": [ + {"node": {"display_label": "John"}}, + {"node": {"display_label": "Jane"}}, + ], + }, + } + + creation = await client.schema.load( + schemas=[ + { + "version": "1.0", + "nodes": [ + { + "name": "Person", + "namespace": "Testing", + "default_filter": "name__value", + "display_labels": ["name__value", "height__value"], + "attributes": [ + {"name": "name", "kind": "Text", "unique": True}, + {"name": "height", "kind": "Number", "optional": True}, + {"name": "description", "kind": "Text", "optional": True}, + {"name": "age", "kind": "Number", "optional": True}, + ], + "inherit_from": ["LineageOwner", "LineageSource"], + }, + ], + } + ] + ) + assert creation.schema_updated + response = await client.execute_graphql(query=query) + assert response == { + "TestingPerson": { + "edges": [ + {"node": {"display_label": "John 180"}}, + {"node": {"display_label": "Jane 170"}}, + ], + }, + } + + # Query before we updated the schema to validate that we can pull the latest schema + response = await client.execute_graphql(query=query, at=time_before.to_string()) + assert response == { + "TestingPerson": { + "edges": [ + {"node": {"display_label": "John"}}, + {"node": {"display_label": "Jane"}}, + ], + }, + } diff --git a/backend/tests/helpers/schema/person.py b/backend/tests/helpers/schema/person.py index 310c7da83c..3656540341 100644 --- a/backend/tests/helpers/schema/person.py +++ b/backend/tests/helpers/schema/person.py @@ -8,6 +8,7 @@ include_in_menu=True, label="Person", default_filter="name__value", + display_labels=["name__value"], attributes=[ AttributeSchema(name="name", kind="Text"), AttributeSchema(name="description", kind="Text", optional=True), diff --git a/backend/tests/integration/git/test_git_repository.py b/backend/tests/integration/git/test_git_repository.py index 091a3b49a2..0476249d25 100644 --- a/backend/tests/integration/git/test_git_repository.py +++ b/backend/tests/integration/git/test_git_repository.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from typing import AsyncGenerator import pytest import yaml @@ -16,6 +17,7 @@ from infrahub.database import InfrahubDatabase from infrahub.git import InfrahubRepository from infrahub.server import app, app_initialization +from infrahub.services import services from infrahub.services.adapters.workflow.local import WorkflowLocalExecution from infrahub.utils import get_models_dir from infrahub.workflows.initialization import setup_task_manager @@ -72,11 +74,14 @@ async def test_client( return InfrahubTestClient(app=app) @pytest.fixture - async def client(self, test_client: InfrahubTestClient, integration_helper): + async def client(self, test_client: InfrahubTestClient, integration_helper) -> AsyncGenerator[InfrahubClient, None]: admin_token = await integration_helper.create_token() config = Config(api_token=admin_token, requester=test_client.async_request) - - return InfrahubClient(config=config) + sdk_client = InfrahubClient(config=config) + original_service_client = services.service._client + services.service.set_client(sdk_client) + yield sdk_client + services.service.set_client(original_service_client) @pytest.fixture(scope="class") async def query_99(self, db: InfrahubDatabase, test_client): diff --git a/backend/tests/integration/schema_lifecycle/test_schema_migration_main.py b/backend/tests/integration/schema_lifecycle/test_schema_migration_main.py index a9c6807f0f..34aeb4e3c4 100644 --- a/backend/tests/integration/schema_lifecycle/test_schema_migration_main.py +++ b/backend/tests/integration/schema_lifecycle/test_schema_migration_main.py @@ -210,6 +210,7 @@ async def test_step03_load(self, db: InfrahubDatabase, client: InfrahubClient, i assert not hasattr(john, "height") # Ensure that we can query the existing node with graphql endpoint + await client.schema.all(refresh=True) api_persons = await client.filters(kind=PERSON_KIND, firstname__value="John") assert len(api_persons) == 1 api_john = api_persons[0] diff --git a/backend/tests/unit/api/test_20_graphql_api.py b/backend/tests/unit/api/test_20_graphql_api.py index e5813b2b6b..3261ba2c48 100644 --- a/backend/tests/unit/api/test_20_graphql_api.py +++ b/backend/tests/unit/api/test_20_graphql_api.py @@ -1,6 +1,5 @@ import pytest -from infrahub.core import registry from infrahub.core.branch import Branch from infrahub.core.initialization import create_branch from infrahub.core.timestamp import Timestamp @@ -208,132 +207,3 @@ async def test_download_schema(db: InfrahubDatabase, client, client_headers): response = client.get("/schema.graphql?branch=notvalid", headers=client_headers) assert response.status_code == 400 - - -async def test_query_at_previous_schema( - db: InfrahubDatabase, - client, - admin_headers, - default_branch: Branch, - authentication_base, - prefect_test_fixture, - workflow_local, - car_person_data, -): - # Load the schema in the database - schema = registry.schema.get_schema_branch(name=default_branch.name) - await registry.schema.load_schema_to_db(schema=schema, branch=default_branch, db=db) - - time_before = Timestamp() - - query = """ - query { - TestPerson { - edges { - node { - display_label - } - } - } - } - """ - - # Must execute in a with block to execute the startup/shutdown events - with client: - response = client.post( - "/graphql", - json={"query": query}, - headers=admin_headers, - ) - - assert response.status_code == 200 - assert "errors" not in response.json() - assert response.json()["data"] is not None - result = response.json()["data"] - assert result == { - "TestPerson": { - "edges": [ - {"node": {"display_label": "John"}}, - {"node": {"display_label": "Jane"}}, - ], - }, - } - - creation = client.post( - "/api/schema/load", - headers=admin_headers, - json={ - "schemas": [ - { - "version": "1.0", - "nodes": [ - { - "name": "Person", - "namespace": "Test", - "default_filter": "name__value", - "display_labels": ["name__value", "height__value"], - "attributes": [ - {"name": "name", "kind": "Text", "unique": True}, - {"name": "height", "kind": "Number", "optional": True}, - ], - "relationships": [ - {"name": "cars", "peer": "TestCar", "cardinality": "many", "direction": "inbound"} - ], - }, - ], - } - ] - }, - ) - data = creation.json() - assert data - assert data.get("schema_updated") is True - assert "diff" in data - assert data["diff"] == { - "added": {}, - "changed": { - "TestPerson": {"added": {}, "changed": {"display_labels": None}, "removed": {}}, - }, - "removed": {}, - } - assert creation.status_code == 200 - - # Do another query to validate that the schema has been updated - response = client.post( - "/graphql", - json={"query": query}, - headers=admin_headers, - ) - - assert response.status_code == 200 - assert "errors" not in response.json() - assert response.json()["data"] is not None - result = response.json()["data"] - assert result == { - "TestPerson": { - "edges": [ - {"node": {"display_label": "John 180"}}, - {"node": {"display_label": "Jane 170"}}, - ], - }, - } - - # Query before we updated the schema to validate that we can pull the latest schema - response = client.post( - f"/graphql?at={time_before.to_string()}", - json={"query": query}, - headers=admin_headers, - ) - - assert "errors" not in response.json() - assert response.status_code == 200 - assert response.json()["data"] is not None - result = response.json()["data"] - assert result == { - "TestPerson": { - "edges": [ - {"node": {"display_label": "John"}}, - {"node": {"display_label": "Jane"}}, - ], - }, - } diff --git a/backend/tests/unit/api/test_40_schema_api.py b/backend/tests/unit/api/test_40_schema_api.py index 18cd52ad63..12ad155e23 100644 --- a/backend/tests/unit/api/test_40_schema_api.py +++ b/backend/tests/unit/api/test_40_schema_api.py @@ -2,17 +2,11 @@ from infrahub.core import registry from infrahub.core.branch import Branch -from infrahub.core.constants import InfrahubKind, SchemaPathType +from infrahub.core.constants import InfrahubKind from infrahub.core.initialization import create_branch from infrahub.core.node import Node -from infrahub.core.path import SchemaPath from infrahub.core.schema import SchemaRoot, core_models -from infrahub.core.utils import count_relationships from infrahub.database import InfrahubDatabase -from infrahub.message_bus.messages.schema_migration_path import ( - SchemaMigrationPathResponse, - SchemaMigrationPathResponseData, -) async def test_schema_read_endpoint_default_branch( @@ -185,43 +179,6 @@ async def test_schema_kind_not_valid( assert response.json()["errors"][0]["message"] == "Unable to find the schema 'NotPresent' in the registry" -async def test_schema_load_endpoint_valid_simple( - db: InfrahubDatabase, - client: TestClient, - admin_headers, - default_branch: Branch, - prefect_test_fixture, - workflow_local, - authentication_base, - helper, -): - # Load the schema in the database - schema = registry.schema.get_schema_branch(name=default_branch.name) - await registry.schema.load_schema_to_db(schema=schema, branch=default_branch, db=db) - - # Must execute in a with block to execute the startup/shutdown event - with client: - creation = client.post( - "/api/schema/load", headers=admin_headers, json={"schemas": [helper.schema_file("infra_simple_01.json")]} - ) - read = client.get("/api/schema", headers=admin_headers) - - assert creation.json()["schema_updated"] - assert creation.status_code == 200 - assert read.status_code == 200 - nodes = read.json()["nodes"] - device = [node for node in nodes if node["name"] == "Device"] - assert device - device = device[0] - attributes = {attrib["name"]: attrib["order_weight"] for attrib in device["attributes"]} - relationships = {attrib["name"]: attrib["order_weight"] for attrib in device["relationships"]} - assert attributes["name"] == 1000 - assert attributes["description"] == 900 - assert attributes["type"] == 3000 - assert relationships["interfaces"] == 450 - assert relationships["tags"] == 7000 - - async def test_schema_load_permission_failure( db: InfrahubDatabase, client: TestClient, @@ -273,183 +230,6 @@ async def test_schema_load_restricted_namespace( assert response.json()["errors"][0]["message"] == "Restricted namespace 'Internal' used on 'Timestamp'" -async def test_schema_load_endpoint_idempotent_simple( - db: InfrahubDatabase, - client: TestClient, - admin_headers, - default_branch: Branch, - prefect_test_fixture, - workflow_local, - register_core_schema_db, - authentication_base, - helper, -): - # Load the schema in the database - schema = registry.schema.get_schema_branch(name=default_branch.name) - await registry.schema.load_schema_to_db(schema=schema, branch=default_branch, db=db) - - # Must execute in a with block to execute the startup/shutdown events - with client: - creation = client.post( - "/api/schema/load", headers=admin_headers, json={"schemas": [helper.schema_file("infra_simple_01.json")]} - ) - read = client.get("/api/schema", headers=admin_headers) - - nbr_rels = await count_relationships(db=db) - - assert creation.status_code == 200 - assert read.status_code == 200 - nodes = read.json()["nodes"] - device = [node for node in nodes if node["name"] == "Device"] - assert device - device = device[0] - attributes = {attrib["name"]: attrib["order_weight"] for attrib in device["attributes"]} - relationships = {attrib["name"]: attrib["order_weight"] for attrib in device["relationships"]} - assert attributes["name"] == 1000 - assert attributes["description"] == 900 - assert attributes["type"] == 3000 - assert relationships["interfaces"] == 450 - assert relationships["tags"] == 7000 - - creation = client.post( - "/api/schema/load", headers=admin_headers, json={"schemas": [helper.schema_file("infra_simple_01.json")]} - ) - read = client.get("/api/schema", headers=admin_headers) - - assert creation.status_code == 200 - assert read.status_code == 200 - - assert nbr_rels == await count_relationships(db=db) - - -async def test_schema_load_endpoint_valid_with_generics( - db: InfrahubDatabase, - client: TestClient, - admin_headers, - default_branch: Branch, - prefect_test_fixture, - workflow_local, - register_core_schema_db, - authentication_base, - helper, -): - # Load the schema in the database - schema = registry.schema.get_schema_branch(name=default_branch.name) - await registry.schema.load_schema_to_db(schema=schema, branch=default_branch, db=db) - - # Must execute in a with block to execute the startup/shutdown events - with client: - response1 = client.post( - "/api/schema/load", - headers=admin_headers, - json={"schemas": [helper.schema_file("infra_w_generics_01.json")]}, - ) - assert response1.status_code == 200 - - response2 = client.get("/api/schema", headers=admin_headers) - assert response2.status_code == 200 - - schema = response2.json() - assert len(schema["generics"]) == len(core_models.get("generics")) + 1 - - -async def test_schema_load_endpoint_idempotent_with_generics( - db: InfrahubDatabase, - client: TestClient, - admin_headers, - default_branch: Branch, - prefect_test_fixture, - workflow_local, - register_core_schema_db, - authentication_base, - helper, -): - # Load the schema in the database - schema = registry.schema.get_schema_branch(name=default_branch.name) - await registry.schema.load_schema_to_db(schema=schema, branch=default_branch, db=db) - - # Must execute in a with block to execute the startup/shutdown events - with client: - response1 = client.post( - "/api/schema/load", - headers=admin_headers, - json={"schemas": [helper.schema_file("infra_w_generics_01.json")]}, - ) - assert response1.json()["schema_updated"] - assert response1.status_code == 200 - - response2 = client.get("/api/schema", headers=admin_headers) - assert response2.status_code == 200 - - schema = response2.json() - assert len(schema["generics"]) == len(core_models.get("generics")) + 1 - - nbr_rels = await count_relationships(db=db) - - response3 = client.post( - "/api/schema/load", - headers=admin_headers, - json={"schemas": [helper.schema_file("infra_w_generics_01.json")]}, - ) - assert response3.json()["schema_updated"] is False - assert response3.status_code == 200 - - response4 = client.get("/api/schema", headers=admin_headers) - assert response4.status_code == 200 - - nbr_rels_after = await count_relationships(db=db) - assert nbr_rels == nbr_rels_after - - -async def test_schema_load_endpoint_valid_with_extensions( - db: InfrahubDatabase, - client: TestClient, - admin_headers, - rpc_bus, - default_branch: Branch, - prefect_test_fixture, - workflow_local, - authentication_base, - helper, -): - # Load the schema in the database - schema = registry.schema.get_schema_branch(name=default_branch.name) - await registry.schema.load_schema_to_db(schema=schema, branch=default_branch, db=db) - - org_schema = registry.schema.get(name="CoreOrganization", branch=default_branch.name) - initial_nbr_relationships = len(org_schema.relationships) - - schema = registry.schema.get_schema_branch(name=default_branch.name) - await registry.schema.load_schema_to_db( - db=db, schema=schema, branch=default_branch, limit=["CoreOrganization", "InfraSite"] - ) - - rpc_bus.response.append( - SchemaMigrationPathResponse( - data=SchemaMigrationPathResponseData( - migration_name="test.test.update", - errors=[], - nbr_migrations_executed=3, - schema_path=SchemaPath(path_type=SchemaPathType.NODE, schema_kind="CoreOrganization"), - ) - ) - ) - - # Must execute in a with block to execute the startup/shutdown events - with client: - response = client.post( - "/api/schema/load", - headers=admin_headers, - json={"schemas": [helper.schema_file("infra_w_extensions_01.json")]}, - ) - - assert response.json()["schema_updated"] - assert response.status_code == 200 - - org_schema = registry.schema.get(name="CoreOrganization", branch=default_branch.name) - assert len(org_schema.relationships) == initial_nbr_relationships + 1 - - async def test_schema_load_endpoint_not_valid_simple_02( db: InfrahubDatabase, client: TestClient, diff --git a/backend/tests/unit/core/schema_manager/test_manager_schema.py b/backend/tests/unit/core/schema_manager/test_manager_schema.py index 5ef988fa09..bcbf8a9a7c 100644 --- a/backend/tests/unit/core/schema_manager/test_manager_schema.py +++ b/backend/tests/unit/core/schema_manager/test_manager_schema.py @@ -31,6 +31,7 @@ from infrahub.core.schema.schema_branch import SchemaBranch from infrahub.database import InfrahubDatabase from infrahub.exceptions import SchemaNotFoundError, ValidationError +from tests.conftest import TestHelper from tests.helpers.schema import CHILD, THING from .conftest import _get_schema_by_kind @@ -1316,28 +1317,29 @@ async def test_validate_default_value_error(schema_all_in_one, default_value_att async def test_schema_branch_load_schema_extension( - db: InfrahubDatabase, default_branch, organization_schema, builtin_schema, helper + db: InfrahubDatabase, default_branch, builtin_schema, helper: TestHelper ): schema = SchemaRoot(**core_models) schema_branch = SchemaBranch(cache={}, name="test") schema_branch.load_schema(schema=schema) schema_branch.load_schema(schema=builtin_schema) - schema_branch.load_schema(schema=organization_schema) + schema_branch.load_schema(schema=SchemaRoot(**helper.schema_file("infra_simple_01.json"))) + schema_branch.process() - org = schema_branch.get(name="CoreOrganization") + org = schema_branch.get(name="TestingOrganization") initial_nbr_relationships = len(org.relationships) schema_branch.load_schema(schema=SchemaRoot(**helper.schema_file("infra_w_extensions_01.json"))) - org = schema_branch.get(name="CoreOrganization") + org = schema_branch.get(name="TestingOrganization") assert len(org.relationships) == initial_nbr_relationships + 1 assert schema_branch.get(name="InfraDevice") # Load it a second time to check if it's idempotent schema_branch.load_schema(schema=SchemaRoot(**helper.schema_file("infra_w_extensions_01.json"))) - org = schema_branch.get(name="CoreOrganization") + org = schema_branch.get(name="TestingOrganization") assert len(org.relationships) == initial_nbr_relationships + 1 assert schema_branch.get(name="InfraDevice")