diff --git a/backend/infrahub/core/constants/__init__.py b/backend/infrahub/core/constants/__init__.py index 02b39677c3..ca3f6fc158 100644 --- a/backend/infrahub/core/constants/__init__.py +++ b/backend/infrahub/core/constants/__init__.py @@ -388,3 +388,4 @@ class AttributeDBNodeType(Flag): OBJECT_TEMPLATE_RELATIONSHIP_NAME = "object_template" OBJECT_TEMPLATE_NAME_ATTR = "template_name" PROFILE_NODE_RELATIONSHIP_IDENTIFIER = "node__profile" +PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER = "template__profile" diff --git a/backend/infrahub/core/node/__init__.py b/backend/infrahub/core/node/__init__.py index ef54ae9ac0..1e43d03324 100644 --- a/backend/infrahub/core/node/__init__.py +++ b/backend/infrahub/core/node/__init__.py @@ -429,9 +429,10 @@ async def handle_object_template(self, fields: dict, db: InfrahubDatabase, error for attribute_name in template._attributes: if attribute_name in list(fields) + [OBJECT_TEMPLATE_NAME_ATTR]: continue - attr_value = getattr(template, attribute_name).value + attr = getattr(template, attribute_name) + attr_value = attr.value if attr_value is not None: - fields[attribute_name] = {"value": attr_value, "source": template.id} + fields[attribute_name] = {"value": attr_value, "source": attr.source_id or template.id} for relationship_name in template._relationships: relationship_schema = template._schema.get_relationship(name=relationship_name) diff --git a/backend/infrahub/core/node/create.py b/backend/infrahub/core/node/create.py index 103a43b49b..30d2204977 100644 --- a/backend/infrahub/core/node/create.py +++ b/backend/infrahub/core/node/create.py @@ -63,7 +63,8 @@ async def extract_peer_data( for rel in template_peer.get_schema().relationship_names: rel_manager: RelationshipManager = getattr(template_peer, rel) if ( - rel_manager.schema.kind not in [RelationshipKind.COMPONENT, RelationshipKind.PARENT] + rel_manager.schema.kind + not in [RelationshipKind.COMPONENT, RelationshipKind.PARENT, RelationshipKind.PROFILE] or rel_manager.schema.name not in obj_peer_schema.relationship_names ): continue @@ -71,6 +72,10 @@ async def extract_peer_data( if list(await rel_manager.get_peers(db=db)) == [current_template.id]: obj_peer_data[rel] = {"id": parent_obj.id} + if rel_manager.schema.kind == RelationshipKind.PROFILE: + profiles = list(await rel_manager.get_peers(db=db)) + obj_peer_data[rel] = profiles + return obj_peer_data @@ -114,6 +119,12 @@ async def handle_template_relationships( await constraint_runner.check(node=obj_peer, field_filters=list(obj_peer_data)) await obj_peer.save(db=db) + template_profile_ids = await get_profile_ids(db=db, obj=template_relationship_peer) + if template_profile_ids: + node_profiles_applier = NodeProfilesApplier(db=db, branch=branch) + await node_profiles_applier.apply_profiles(node=obj_peer) + await obj_peer.save(db=db) + await handle_template_relationships( db=db, branch=branch, diff --git a/backend/infrahub/core/schema/schema_branch.py b/backend/infrahub/core/schema/schema_branch.py index 4f26e0b09e..addaeaaebe 100644 --- a/backend/infrahub/core/schema/schema_branch.py +++ b/backend/infrahub/core/schema/schema_branch.py @@ -19,6 +19,7 @@ OBJECT_TEMPLATE_NAME_ATTR, OBJECT_TEMPLATE_RELATIONSHIP_NAME, PROFILE_NODE_RELATIONSHIP_IDENTIFIER, + PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER, RESERVED_ATTR_GEN_NAMES, RESERVED_ATTR_REL_NAMES, RESTRICTED_NAMESPACES, @@ -69,6 +70,15 @@ log = get_logger() +profiles_rel_settings: dict[str, Any] = { + "name": "profiles", + "identifier": PROFILE_NODE_RELATIONSHIP_IDENTIFIER, + "peer": InfrahubKind.PROFILE, + "kind": RelationshipKind.PROFILE, + "cardinality": RelationshipCardinality.MANY, + "branch": BranchSupportType.AWARE, +} + class SchemaBranch: def __init__( @@ -1933,15 +1943,6 @@ def manage_profile_relationships(self) -> None: ): continue - profiles_rel_settings: dict[str, Any] = { - "name": "profiles", - "identifier": PROFILE_NODE_RELATIONSHIP_IDENTIFIER, - "peer": InfrahubKind.PROFILE, - "kind": RelationshipKind.PROFILE, - "cardinality": RelationshipCardinality.MANY, - "branch": BranchSupportType.AWARE, - } - # Add relationship between node and profile if "profiles" not in node.relationship_names: node_schema = self.get(name=node_name, duplicate=True) @@ -2006,6 +2007,18 @@ def generate_profile_from_node(self, node: NodeSchema) -> ProfileSchema: ) ], ) + if f"Template{node.kind}" in self.all_names: + template = self.get(name=f"Template{node.kind}", duplicate=False) + profile.relationships.append( + RelationshipSchema( + name="related_templates", + identifier=PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER, + peer=template.kind, + kind=RelationshipKind.PROFILE, + cardinality=RelationshipCardinality.MANY, + branch=BranchSupportType.AWARE, + ) + ) for node_attr in node.attributes: if node_attr.read_only or node_attr.optional is False: @@ -2136,6 +2149,14 @@ def add_relationships_to_template(self, node: NodeSchema | GenericSchema) -> Non template_schema.human_friendly_id = [parent_hfid] + template_schema.human_friendly_id template_schema.uniqueness_constraints[0].append(relationship.name) + if getattr(node, "generate_profile", False) and getattr(node, "generate_template", False): + if "profiles" not in [r.name for r in template_schema.relationships]: + settings = dict(profiles_rel_settings) + settings["identifier"] = PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER + template_schema.relationships.append(RelationshipSchema(**settings)) + + self.set(name=template_schema.kind, schema=template_schema) + def generate_object_template_from_node( self, node: NodeSchema | GenericSchema, need_templates: set[NodeSchema | GenericSchema] ) -> TemplateSchema | GenericSchema: diff --git a/backend/infrahub/graphql/mutations/profile.py b/backend/infrahub/graphql/mutations/profile.py index 1fa38a0b68..e30e5b25fc 100644 --- a/backend/infrahub/graphql/mutations/profile.py +++ b/backend/infrahub/graphql/mutations/profile.py @@ -57,6 +57,8 @@ async def _send_profile_refresh_workflows( ) -> None: if not node_ids: related_nodes = await obj.related_nodes.get_relationships(db=db) # type: ignore[attr-defined] + if hasattr(obj, "related_templates"): + related_nodes.extend(await obj.related_templates.get_relationships(db=db)) # type: ignore[attr-defined] node_ids = [rel.peer_id for rel in related_nodes] if node_ids: await workflow_service.submit_workflow( @@ -79,7 +81,12 @@ def _get_profile_attr_values_map(cls, obj: Node) -> dict[str, Any]: @classmethod async def _get_profile_related_node_ids(cls, db: InfrahubDatabase, obj: Node) -> set[str]: - related_nodes = await obj.related_nodes.get_relationships(db=db) # type: ignore[attr-defined] + related_nodes = [] + related_nodes.extend(await obj.related_nodes.get_relationships(db=db)) + + if hasattr(obj, "related_templates"): + related_nodes.extend(await obj.related_templates.get_relationships(db=db)) + if related_nodes: related_node_ids = {rel.peer_id for rel in related_nodes} else: diff --git a/backend/infrahub/profiles/node_applier.py b/backend/infrahub/profiles/node_applier.py index f8aebf79e8..f8a9faa974 100644 --- a/backend/infrahub/profiles/node_applier.py +++ b/backend/infrahub/profiles/node_applier.py @@ -3,6 +3,7 @@ from infrahub.core.attribute import BaseAttribute from infrahub.core.branch import Branch from infrahub.core.node import Node +from infrahub.core.schema import TemplateSchema from infrahub.database import InfrahubDatabase from .queries.get_profile_data import GetProfileDataQuery, ProfileData @@ -29,7 +30,12 @@ async def _get_attr_names_for_profiles(self, node: Node) -> list[str]: for attr_schema in node_schema.attributes: attr_name = attr_schema.name node_attr: BaseAttribute = getattr(node, attr_name) - if node_attr.is_from_profile or node_attr.is_default: + is_template = None + if node_attr.source_id: + await node_attr.get_source(db=self.db) + if isinstance(node_attr.source.get_schema(), TemplateSchema): + is_template = True + if node_attr.is_from_profile or node_attr.is_default or is_template: attr_names_for_profiles.append(attr_name) return attr_names_for_profiles diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py index d23ded3009..2078de34fe 100644 --- a/backend/tests/unit/conftest.py +++ b/backend/tests/unit/conftest.py @@ -1805,6 +1805,7 @@ async def criticality_schema_root(register_core_models_schema: None) -> SchemaRo "display_labels": ["label__value"], "inherit_from": ["TestGenericCriticality"], "branch": BranchSupportType.AWARE.value, + "generate_template": True, "attributes": [ {"name": "name", "kind": "Text", "unique": True}, {"name": "label", "kind": "Text", "optional": True}, diff --git a/backend/tests/unit/core/profiles/test_node_applier.py b/backend/tests/unit/core/profiles/test_node_applier.py index 3c2c347e48..a89fdaf193 100644 --- a/backend/tests/unit/core/profiles/test_node_applier.py +++ b/backend/tests/unit/core/profiles/test_node_applier.py @@ -26,14 +26,17 @@ async def _validate_node_profile_attrs( ): expected_profile_attrs_by_name = {attr.name: attr for attr in expected_profile_attrs} for attr_name in schema.attribute_names: + # Skip if the attribute is not present on the node (e.g., not set on template) + if not hasattr(updated_node, attr_name): + continue updated_node_attr = getattr(updated_node, attr_name) updated_source = await updated_node_attr.get_source(db=db) - original_node_attr = getattr(original_node, attr_name) + original_node_attr = getattr(original_node, attr_name) if hasattr(original_node, attr_name) else None expected_profile_attr = expected_profile_attrs_by_name.get(attr_name) if expected_profile_attr: assert updated_node_attr.value == expected_profile_attr.value assert updated_source.id == expected_profile_attr.source_uuid - else: + elif original_node_attr is not None: assert updated_node_attr.value == original_node_attr.value assert updated_source is None @@ -214,3 +217,47 @@ async def test_get_many_with_multiple_profiles_same_priority( assert updated_field_names == [] updated_field_names = await node_applier.apply_profiles(node=updated_crit_low) assert updated_field_names == [] + + +async def test_template_profile_application( + db: InfrahubDatabase, + criticality_schema: NodeSchema, + criticality_low: Node, + branch: Branch, +): + profile_schema = registry.schema.get("ProfileTestCriticality", branch=branch) + template_schema = registry.schema.get("TemplateTestCriticality", branch=branch) + + crit_profile_1 = await Node.init(db=db, branch=branch, schema=profile_schema) + await crit_profile_1.new(db=db, profile_name="crit_profile_1", color="green", profile_priority=1001) + await crit_profile_1.save(db=db) + + crit_template = await Node.init(db=db, branch=branch, schema=template_schema) + await crit_template.new(db=db, template_name="crit_template", name="crit_template") + await crit_template.save(db=db) + + # TODO: Fix profile assignment to template + await crit_template.profiles.update(db=db, data=[crit_profile_1]) + + node_applier = NodeProfilesApplier(db=db, branch=branch) + + updated_template_field_names = await node_applier.apply_profiles(node=crit_template) + assert updated_template_field_names == ["color"] + await crit_template.save(db=db) + + node = await NodeManager.get_one(db=db, branch=branch, id=crit_template.id, include_source=True) + assert node.id == crit_template.id + expected_profile_attrs = [ + ExpectedProfileAttr(name="color", value="green", source_uuid=crit_profile_1.id), + ] + await _validate_node_profile_attrs( + db=db, + schema=criticality_schema, + original_node=crit_template, + updated_node=node, + expected_profile_attrs=expected_profile_attrs, + ) + + # make sure field names returned by apply_profiles is idempotent for templates + updated_field_names = await node_applier.apply_profiles(node=crit_template) + assert updated_field_names == [] diff --git a/backend/tests/unit/core/test_node.py b/backend/tests/unit/core/test_node.py index 6e09bb6e87..c219fb1845 100644 --- a/backend/tests/unit/core/test_node.py +++ b/backend/tests/unit/core/test_node.py @@ -817,6 +817,126 @@ async def test_node_create_with_object_template( ) +async def test_node_create_with_object_template_with_profile( + db: InfrahubDatabase, default_branch: Branch, register_core_models_schema: SchemaBranch +): + DUMMY = NodeSchema( + name="Dummy", + namespace="Testing", + generate_template=True, + attributes=[ + AttributeSchema(name="name", kind="Text", unique=True), + ], + ) + + SIMPLE_DEVICE = NodeSchema( + name="Device", + namespace="Testing", + generate_template=True, + generate_profile=True, + attributes=[ + AttributeSchema(name="name", kind="Text", unique=True, order_weight=500), + AttributeSchema(name="manufacturer", kind="Text", order_weight=500), + AttributeSchema(name="height", kind="Number", order_weight=300), + AttributeSchema(name="weight", kind="Number", order_weight=1000), + AttributeSchema( + name="airflow", + kind="Text", + enum=["Front to rear", "Rear to front"], + optional=True, + ), + ], + relationships=[ + RelationshipSchema( + name="dummy", + peer="TestingDummy", + cardinality=RelationshipCardinality.ONE, + order_weight=5000, + optional=True, + ) + ], + ) + registry.schema.set(name=DUMMY.kind, schema=DUMMY, branch=default_branch.name) + registry.schema.set(name=SIMPLE_DEVICE.kind, schema=SIMPLE_DEVICE, branch=default_branch.name) + registry.schema.process_schema_branch(name=default_branch.name) + + template_schema = registry.schema.get(name=f"Template{SIMPLE_DEVICE.kind}", branch=default_branch.name) + node_schema = registry.schema.get(name=SIMPLE_DEVICE.kind, branch=default_branch.name) + + # Validate that the attributes respect the order_weight defined on the original schema + template_weights = { + attr.name: attr.order_weight for attr in template_schema.attributes + template_schema.relationships + } + + assert "name" not in template_weights + assert template_weights["manufacturer"] == 10500 + assert template_weights["dummy"] == 15000 + + profile_schema = registry.schema.get(name=f"Profile{SIMPLE_DEVICE.kind}", branch=default_branch.name) + + profile = await Node.init(db=db, branch=default_branch.name, schema=profile_schema) + await profile.new( + db=db, + profile_name="Airflow Rear to Front", + airflow="Rear to front", + ) + await profile.save(db=db) + + template = await Node.init(db=db, schema=template_schema) + await template.new( + db=db, + template_name="Juniper MX204", + manufacturer="Juniper", + height=1, + weight=8, + ) + await template.save(db=db) + # TODO: Fix profile assignment + await template.profiles.update(db=db, data=[profile]) + template_profiles = await template.profiles.get_peers(db=db) + assert len(template_profiles) == 1 + + from infrahub.profiles.node_applier import NodeProfilesApplier + + node_applier = NodeProfilesApplier(db=db, branch=default_branch) + await node_applier.apply_profiles(node=template) + await template.save(db=db) + assert template.airflow.value == "Rear to front" + assert template.airflow.source_id == profile.id + + device: Node = await Node.init(db=db, schema=node_schema) + await device.new(db=db, name="par-th2-br01", object_template={"id": template.id}) + await device.save(db=db) + + assert device.id + assert device.db_id + assert device.name.value == device.node_changelog.attributes["name"].value == "par-th2-br01" + assert device.node_changelog.attributes["name"].value_update_status == DiffAction.ADDED + assert "source" not in device.node_changelog.attributes["name"].properties + assert device.manufacturer.value == device.node_changelog.attributes["manufacturer"].value == "Juniper" + assert device.node_changelog.attributes["manufacturer"].value_update_status == DiffAction.ADDED + assert ( + device.manufacturer.source_id + == device.node_changelog.attributes["manufacturer"].properties["source"].value + == template.id + ) + assert device.height.value == device.node_changelog.attributes["height"].value == 1 + assert device.node_changelog.attributes["height"].value_update_status == DiffAction.ADDED + assert ( + device.height.source_id == device.node_changelog.attributes["height"].properties["source"].value == template.id + ) + assert device.weight.value == device.node_changelog.attributes["weight"].value == 8 + assert device.node_changelog.attributes["weight"].value_update_status == DiffAction.ADDED + assert ( + device.weight.source_id == device.node_changelog.attributes["weight"].properties["source"].value == template.id + ) + assert device.airflow.value.value == device.node_changelog.attributes["airflow"].value.value == "Rear to front" + assert device.node_changelog.attributes["airflow"].value_update_status == DiffAction.ADDED + assert ( + device.airflow.source_id == device.node_changelog.attributes["airflow"].properties["source"].value == profile.id + ) + + # -------------------------------------------------------------------------- # Update # -------------------------------------------------------------------------- diff --git a/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts b/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts index 84848d699a..1fb3534f0f 100644 --- a/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts +++ b/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts @@ -3,6 +3,7 @@ import { jsonToGraphQLQuery } from "json-to-graphql-query"; import { addAttributesToRequest, addRelationshipsToRequest } from "@/shared/api/graphql/utils"; import { getRelationshipsForForm } from "@/shared/components/form/utils/getRelationshipsForForm"; +import { getSchema } from "@/entities/schema/domain/get-schema"; import type { NodeSchema, ProfileSchema } from "@/entities/schema/types"; export const generateObjectEditFormQuery = ({ @@ -12,6 +13,10 @@ export const generateObjectEditFormQuery = ({ schema: NodeSchema | ProfileSchema; objectId: string; }): string => { + let parentSchema: NodeSchema | ProfileSchema | undefined; + if (schema.kind && schema.kind.includes("Template")) { + parentSchema = getSchema(schema.name).schema; + } const request = { query: { __name: "GetObjectForEditForm", @@ -32,7 +37,8 @@ export const generateObjectEditFormQuery = ({ getRelationshipsForForm(schema.relationships ?? [], true, schema), { withMetadata: true } ), - ...("generate_profile" in schema && schema.generate_profile + ...(("generate_profile" in schema && schema.generate_profile) || + (parentSchema && "generate_profile" in parentSchema && parentSchema.generate_profile) ? { profiles: { edges: { diff --git a/frontend/app/src/entities/nodes/object-template/object-template-form.tsx b/frontend/app/src/entities/nodes/object-template/object-template-form.tsx index 1b5a74b919..964cbf2da6 100644 --- a/frontend/app/src/entities/nodes/object-template/object-template-form.tsx +++ b/frontend/app/src/entities/nodes/object-template/object-template-form.tsx @@ -89,7 +89,15 @@ export default function ObjectTemplateForm({ } if (selectedObjectTemplate !== undefined) { - return ; + return ( + edge?.node) ?? [] + } + objectTemplate={selectedObjectTemplate} + /> + ); } return ( diff --git a/frontend/app/src/shared/components/form/object-form.tsx b/frontend/app/src/shared/components/form/object-form.tsx index 4c727db8c7..29a663072e 100644 --- a/frontend/app/src/shared/components/form/object-form.tsx +++ b/frontend/app/src/shared/components/form/object-form.tsx @@ -63,7 +63,7 @@ export interface ObjectFormProps extends Omit { - const { schema, isNode, isGeneric } = useSchema(kind); + const { schema, isNode, isGeneric, isTemplate } = useSchema(kind); if (!schema) { return ( @@ -148,7 +148,7 @@ const ObjectForm = ({ kind, currentProfiles, ...props }: ObjectFormProps) => { ); } - if (isNode && schema.generate_profile) { + if ((isNode && schema.generate_profile) || isTemplate) { return ; } diff --git a/frontend/app/src/shared/components/form/profiles-selector.tsx b/frontend/app/src/shared/components/form/profiles-selector.tsx index b274eec6a7..06d61f638f 100644 --- a/frontend/app/src/shared/components/form/profiles-selector.tsx +++ b/frontend/app/src/shared/components/form/profiles-selector.tsx @@ -1,7 +1,7 @@ import { gql } from "@apollo/client"; import { Icon } from "@iconify-icon/react"; import { useAtomValue } from "jotai/index"; -import { useEffect, useId } from "react"; +import { useId } from "react"; import useQuery from "@/shared/api/graphql/useQuery"; import { Button } from "@/shared/components/buttons/button-primitive"; @@ -42,12 +42,6 @@ export const ProfilesSelector = ({ }: ProfilesSelectorProps) => { const id = useId(); - useEffect(() => { - if (!value && defaultValue) { - onChange(defaultValue); - } - }, [defaultValue]); - const genericSchemas = useAtomValue(genericSchemasAtom); const profileSchemas = useAtomValue(profileSchemasAtom); @@ -70,7 +64,9 @@ export const ProfilesSelector = ({ const profilesList = kindList .map((profile) => { // Get the profile schema for the current kind - const profileSchema = profileSchemas.find((profileSchema) => profileSchema.name === profile); + const profileSchema = profileSchemas.find( + (profileSchema) => profileSchema.name === profile?.replace("Template", "") + ); // Get attributes for query + form data const attributes = getObjectAttributes({ schema: profileSchema, forProfiles: true }); @@ -114,6 +110,10 @@ export const ProfilesSelector = ({ [] ); + if (!value && defaultValue) { + onChange(profiles.filter((profile) => defaultValue.some((def) => def.id === profile.id))); + } + if (!profiles || profiles.length === 0) return null; const selectedValues = value ?? [];