diff --git a/backend/infrahub/core/constants/__init__.py b/backend/infrahub/core/constants/__init__.py
index 02b39677c3..ca3f6fc158 100644
--- a/backend/infrahub/core/constants/__init__.py
+++ b/backend/infrahub/core/constants/__init__.py
@@ -388,3 +388,4 @@ class AttributeDBNodeType(Flag):
OBJECT_TEMPLATE_RELATIONSHIP_NAME = "object_template"
OBJECT_TEMPLATE_NAME_ATTR = "template_name"
PROFILE_NODE_RELATIONSHIP_IDENTIFIER = "node__profile"
+PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER = "template__profile"
diff --git a/backend/infrahub/core/node/__init__.py b/backend/infrahub/core/node/__init__.py
index ef54ae9ac0..1e43d03324 100644
--- a/backend/infrahub/core/node/__init__.py
+++ b/backend/infrahub/core/node/__init__.py
@@ -429,9 +429,10 @@ async def handle_object_template(self, fields: dict, db: InfrahubDatabase, error
for attribute_name in template._attributes:
if attribute_name in list(fields) + [OBJECT_TEMPLATE_NAME_ATTR]:
continue
- attr_value = getattr(template, attribute_name).value
+ attr = getattr(template, attribute_name)
+ attr_value = attr.value
if attr_value is not None:
- fields[attribute_name] = {"value": attr_value, "source": template.id}
+ fields[attribute_name] = {"value": attr_value, "source": attr.source_id or template.id}
for relationship_name in template._relationships:
relationship_schema = template._schema.get_relationship(name=relationship_name)
diff --git a/backend/infrahub/core/node/create.py b/backend/infrahub/core/node/create.py
index 103a43b49b..30d2204977 100644
--- a/backend/infrahub/core/node/create.py
+++ b/backend/infrahub/core/node/create.py
@@ -63,7 +63,8 @@ async def extract_peer_data(
for rel in template_peer.get_schema().relationship_names:
rel_manager: RelationshipManager = getattr(template_peer, rel)
if (
- rel_manager.schema.kind not in [RelationshipKind.COMPONENT, RelationshipKind.PARENT]
+ rel_manager.schema.kind
+ not in [RelationshipKind.COMPONENT, RelationshipKind.PARENT, RelationshipKind.PROFILE]
or rel_manager.schema.name not in obj_peer_schema.relationship_names
):
continue
@@ -71,6 +72,10 @@ async def extract_peer_data(
if list(await rel_manager.get_peers(db=db)) == [current_template.id]:
obj_peer_data[rel] = {"id": parent_obj.id}
+ if rel_manager.schema.kind == RelationshipKind.PROFILE:
+ profiles = list(await rel_manager.get_peers(db=db))
+ obj_peer_data[rel] = profiles
+
return obj_peer_data
@@ -114,6 +119,12 @@ async def handle_template_relationships(
await constraint_runner.check(node=obj_peer, field_filters=list(obj_peer_data))
await obj_peer.save(db=db)
+ template_profile_ids = await get_profile_ids(db=db, obj=template_relationship_peer)
+ if template_profile_ids:
+ node_profiles_applier = NodeProfilesApplier(db=db, branch=branch)
+ await node_profiles_applier.apply_profiles(node=obj_peer)
+ await obj_peer.save(db=db)
+
await handle_template_relationships(
db=db,
branch=branch,
diff --git a/backend/infrahub/core/schema/schema_branch.py b/backend/infrahub/core/schema/schema_branch.py
index 4f26e0b09e..addaeaaebe 100644
--- a/backend/infrahub/core/schema/schema_branch.py
+++ b/backend/infrahub/core/schema/schema_branch.py
@@ -19,6 +19,7 @@
OBJECT_TEMPLATE_NAME_ATTR,
OBJECT_TEMPLATE_RELATIONSHIP_NAME,
PROFILE_NODE_RELATIONSHIP_IDENTIFIER,
+ PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER,
RESERVED_ATTR_GEN_NAMES,
RESERVED_ATTR_REL_NAMES,
RESTRICTED_NAMESPACES,
@@ -69,6 +70,15 @@
log = get_logger()
+profiles_rel_settings: dict[str, Any] = {
+ "name": "profiles",
+ "identifier": PROFILE_NODE_RELATIONSHIP_IDENTIFIER,
+ "peer": InfrahubKind.PROFILE,
+ "kind": RelationshipKind.PROFILE,
+ "cardinality": RelationshipCardinality.MANY,
+ "branch": BranchSupportType.AWARE,
+}
+
class SchemaBranch:
def __init__(
@@ -1933,15 +1943,6 @@ def manage_profile_relationships(self) -> None:
):
continue
- profiles_rel_settings: dict[str, Any] = {
- "name": "profiles",
- "identifier": PROFILE_NODE_RELATIONSHIP_IDENTIFIER,
- "peer": InfrahubKind.PROFILE,
- "kind": RelationshipKind.PROFILE,
- "cardinality": RelationshipCardinality.MANY,
- "branch": BranchSupportType.AWARE,
- }
-
# Add relationship between node and profile
if "profiles" not in node.relationship_names:
node_schema = self.get(name=node_name, duplicate=True)
@@ -2006,6 +2007,18 @@ def generate_profile_from_node(self, node: NodeSchema) -> ProfileSchema:
)
],
)
+ if f"Template{node.kind}" in self.all_names:
+ template = self.get(name=f"Template{node.kind}", duplicate=False)
+ profile.relationships.append(
+ RelationshipSchema(
+ name="related_templates",
+ identifier=PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER,
+ peer=template.kind,
+ kind=RelationshipKind.PROFILE,
+ cardinality=RelationshipCardinality.MANY,
+ branch=BranchSupportType.AWARE,
+ )
+ )
for node_attr in node.attributes:
if node_attr.read_only or node_attr.optional is False:
@@ -2136,6 +2149,14 @@ def add_relationships_to_template(self, node: NodeSchema | GenericSchema) -> Non
template_schema.human_friendly_id = [parent_hfid] + template_schema.human_friendly_id
template_schema.uniqueness_constraints[0].append(relationship.name)
+ if getattr(node, "generate_profile", False) and getattr(node, "generate_template", False):
+ if "profiles" not in [r.name for r in template_schema.relationships]:
+ settings = dict(profiles_rel_settings)
+ settings["identifier"] = PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER
+ template_schema.relationships.append(RelationshipSchema(**settings))
+
+ self.set(name=template_schema.kind, schema=template_schema)
+
def generate_object_template_from_node(
self, node: NodeSchema | GenericSchema, need_templates: set[NodeSchema | GenericSchema]
) -> TemplateSchema | GenericSchema:
diff --git a/backend/infrahub/graphql/mutations/profile.py b/backend/infrahub/graphql/mutations/profile.py
index 1fa38a0b68..e30e5b25fc 100644
--- a/backend/infrahub/graphql/mutations/profile.py
+++ b/backend/infrahub/graphql/mutations/profile.py
@@ -57,6 +57,8 @@ async def _send_profile_refresh_workflows(
) -> None:
if not node_ids:
related_nodes = await obj.related_nodes.get_relationships(db=db) # type: ignore[attr-defined]
+ if hasattr(obj, "related_templates"):
+ related_nodes.extend(await obj.related_templates.get_relationships(db=db)) # type: ignore[attr-defined]
node_ids = [rel.peer_id for rel in related_nodes]
if node_ids:
await workflow_service.submit_workflow(
@@ -79,7 +81,12 @@ def _get_profile_attr_values_map(cls, obj: Node) -> dict[str, Any]:
@classmethod
async def _get_profile_related_node_ids(cls, db: InfrahubDatabase, obj: Node) -> set[str]:
- related_nodes = await obj.related_nodes.get_relationships(db=db) # type: ignore[attr-defined]
+ related_nodes = []
+ related_nodes.extend(await obj.related_nodes.get_relationships(db=db))
+
+ if hasattr(obj, "related_templates"):
+ related_nodes.extend(await obj.related_templates.get_relationships(db=db))
+
if related_nodes:
related_node_ids = {rel.peer_id for rel in related_nodes}
else:
diff --git a/backend/infrahub/profiles/node_applier.py b/backend/infrahub/profiles/node_applier.py
index f8aebf79e8..f8a9faa974 100644
--- a/backend/infrahub/profiles/node_applier.py
+++ b/backend/infrahub/profiles/node_applier.py
@@ -3,6 +3,7 @@
from infrahub.core.attribute import BaseAttribute
from infrahub.core.branch import Branch
from infrahub.core.node import Node
+from infrahub.core.schema import TemplateSchema
from infrahub.database import InfrahubDatabase
from .queries.get_profile_data import GetProfileDataQuery, ProfileData
@@ -29,7 +30,12 @@ async def _get_attr_names_for_profiles(self, node: Node) -> list[str]:
for attr_schema in node_schema.attributes:
attr_name = attr_schema.name
node_attr: BaseAttribute = getattr(node, attr_name)
- if node_attr.is_from_profile or node_attr.is_default:
+ is_template = None
+ if node_attr.source_id:
+ await node_attr.get_source(db=self.db)
+ if isinstance(node_attr.source.get_schema(), TemplateSchema):
+ is_template = True
+ if node_attr.is_from_profile or node_attr.is_default or is_template:
attr_names_for_profiles.append(attr_name)
return attr_names_for_profiles
diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py
index d23ded3009..2078de34fe 100644
--- a/backend/tests/unit/conftest.py
+++ b/backend/tests/unit/conftest.py
@@ -1805,6 +1805,7 @@ async def criticality_schema_root(register_core_models_schema: None) -> SchemaRo
"display_labels": ["label__value"],
"inherit_from": ["TestGenericCriticality"],
"branch": BranchSupportType.AWARE.value,
+ "generate_template": True,
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
{"name": "label", "kind": "Text", "optional": True},
diff --git a/backend/tests/unit/core/profiles/test_node_applier.py b/backend/tests/unit/core/profiles/test_node_applier.py
index 3c2c347e48..a89fdaf193 100644
--- a/backend/tests/unit/core/profiles/test_node_applier.py
+++ b/backend/tests/unit/core/profiles/test_node_applier.py
@@ -26,14 +26,17 @@ async def _validate_node_profile_attrs(
):
expected_profile_attrs_by_name = {attr.name: attr for attr in expected_profile_attrs}
for attr_name in schema.attribute_names:
+ # Skip if the attribute is not present on the node (e.g., not set on template)
+ if not hasattr(updated_node, attr_name):
+ continue
updated_node_attr = getattr(updated_node, attr_name)
updated_source = await updated_node_attr.get_source(db=db)
- original_node_attr = getattr(original_node, attr_name)
+ original_node_attr = getattr(original_node, attr_name) if hasattr(original_node, attr_name) else None
expected_profile_attr = expected_profile_attrs_by_name.get(attr_name)
if expected_profile_attr:
assert updated_node_attr.value == expected_profile_attr.value
assert updated_source.id == expected_profile_attr.source_uuid
- else:
+ elif original_node_attr is not None:
assert updated_node_attr.value == original_node_attr.value
assert updated_source is None
@@ -214,3 +217,47 @@ async def test_get_many_with_multiple_profiles_same_priority(
assert updated_field_names == []
updated_field_names = await node_applier.apply_profiles(node=updated_crit_low)
assert updated_field_names == []
+
+
+async def test_template_profile_application(
+ db: InfrahubDatabase,
+ criticality_schema: NodeSchema,
+ criticality_low: Node,
+ branch: Branch,
+):
+ profile_schema = registry.schema.get("ProfileTestCriticality", branch=branch)
+ template_schema = registry.schema.get("TemplateTestCriticality", branch=branch)
+
+ crit_profile_1 = await Node.init(db=db, branch=branch, schema=profile_schema)
+ await crit_profile_1.new(db=db, profile_name="crit_profile_1", color="green", profile_priority=1001)
+ await crit_profile_1.save(db=db)
+
+ crit_template = await Node.init(db=db, branch=branch, schema=template_schema)
+ await crit_template.new(db=db, template_name="crit_template", name="crit_template")
+ await crit_template.save(db=db)
+
+ # TODO: Fix profile assignment to template
+ await crit_template.profiles.update(db=db, data=[crit_profile_1])
+
+ node_applier = NodeProfilesApplier(db=db, branch=branch)
+
+ updated_template_field_names = await node_applier.apply_profiles(node=crit_template)
+ assert updated_template_field_names == ["color"]
+ await crit_template.save(db=db)
+
+ node = await NodeManager.get_one(db=db, branch=branch, id=crit_template.id, include_source=True)
+ assert node.id == crit_template.id
+ expected_profile_attrs = [
+ ExpectedProfileAttr(name="color", value="green", source_uuid=crit_profile_1.id),
+ ]
+ await _validate_node_profile_attrs(
+ db=db,
+ schema=criticality_schema,
+ original_node=crit_template,
+ updated_node=node,
+ expected_profile_attrs=expected_profile_attrs,
+ )
+
+ # make sure field names returned by apply_profiles is idempotent for templates
+ updated_field_names = await node_applier.apply_profiles(node=crit_template)
+ assert updated_field_names == []
diff --git a/backend/tests/unit/core/test_node.py b/backend/tests/unit/core/test_node.py
index 6e09bb6e87..c219fb1845 100644
--- a/backend/tests/unit/core/test_node.py
+++ b/backend/tests/unit/core/test_node.py
@@ -817,6 +817,126 @@ async def test_node_create_with_object_template(
)
+async def test_node_create_with_object_template_with_profile(
+ db: InfrahubDatabase, default_branch: Branch, register_core_models_schema: SchemaBranch
+):
+ DUMMY = NodeSchema(
+ name="Dummy",
+ namespace="Testing",
+ generate_template=True,
+ attributes=[
+ AttributeSchema(name="name", kind="Text", unique=True),
+ ],
+ )
+
+ SIMPLE_DEVICE = NodeSchema(
+ name="Device",
+ namespace="Testing",
+ generate_template=True,
+ generate_profile=True,
+ attributes=[
+ AttributeSchema(name="name", kind="Text", unique=True, order_weight=500),
+ AttributeSchema(name="manufacturer", kind="Text", order_weight=500),
+ AttributeSchema(name="height", kind="Number", order_weight=300),
+ AttributeSchema(name="weight", kind="Number", order_weight=1000),
+ AttributeSchema(
+ name="airflow",
+ kind="Text",
+ enum=["Front to rear", "Rear to front"],
+ optional=True,
+ ),
+ ],
+ relationships=[
+ RelationshipSchema(
+ name="dummy",
+ peer="TestingDummy",
+ cardinality=RelationshipCardinality.ONE,
+ order_weight=5000,
+ optional=True,
+ )
+ ],
+ )
+ registry.schema.set(name=DUMMY.kind, schema=DUMMY, branch=default_branch.name)
+ registry.schema.set(name=SIMPLE_DEVICE.kind, schema=SIMPLE_DEVICE, branch=default_branch.name)
+ registry.schema.process_schema_branch(name=default_branch.name)
+
+ template_schema = registry.schema.get(name=f"Template{SIMPLE_DEVICE.kind}", branch=default_branch.name)
+ node_schema = registry.schema.get(name=SIMPLE_DEVICE.kind, branch=default_branch.name)
+
+ # Validate that the attributes respect the order_weight defined on the original schema
+ template_weights = {
+ attr.name: attr.order_weight for attr in template_schema.attributes + template_schema.relationships
+ }
+
+ assert "name" not in template_weights
+ assert template_weights["manufacturer"] == 10500
+ assert template_weights["dummy"] == 15000
+
+ profile_schema = registry.schema.get(name=f"Profile{SIMPLE_DEVICE.kind}", branch=default_branch.name)
+
+ profile = await Node.init(db=db, branch=default_branch.name, schema=profile_schema)
+ await profile.new(
+ db=db,
+ profile_name="Airflow Rear to Front",
+ airflow="Rear to front",
+ )
+ await profile.save(db=db)
+
+ template = await Node.init(db=db, schema=template_schema)
+ await template.new(
+ db=db,
+ template_name="Juniper MX204",
+ manufacturer="Juniper",
+ height=1,
+ weight=8,
+ )
+ await template.save(db=db)
+ # TODO: Fix profile assignment
+ await template.profiles.update(db=db, data=[profile])
+ template_profiles = await template.profiles.get_peers(db=db)
+ assert len(template_profiles) == 1
+
+ from infrahub.profiles.node_applier import NodeProfilesApplier
+
+ node_applier = NodeProfilesApplier(db=db, branch=default_branch)
+ await node_applier.apply_profiles(node=template)
+ await template.save(db=db)
+ assert template.airflow.value == "Rear to front"
+ assert template.airflow.source_id == profile.id
+
+ device: Node = await Node.init(db=db, schema=node_schema)
+ await device.new(db=db, name="par-th2-br01", object_template={"id": template.id})
+ await device.save(db=db)
+
+ assert device.id
+ assert device.db_id
+ assert device.name.value == device.node_changelog.attributes["name"].value == "par-th2-br01"
+ assert device.node_changelog.attributes["name"].value_update_status == DiffAction.ADDED
+ assert "source" not in device.node_changelog.attributes["name"].properties
+ assert device.manufacturer.value == device.node_changelog.attributes["manufacturer"].value == "Juniper"
+ assert device.node_changelog.attributes["manufacturer"].value_update_status == DiffAction.ADDED
+ assert (
+ device.manufacturer.source_id
+ == device.node_changelog.attributes["manufacturer"].properties["source"].value
+ == template.id
+ )
+ assert device.height.value == device.node_changelog.attributes["height"].value == 1
+ assert device.node_changelog.attributes["height"].value_update_status == DiffAction.ADDED
+ assert (
+ device.height.source_id == device.node_changelog.attributes["height"].properties["source"].value == template.id
+ )
+ assert device.weight.value == device.node_changelog.attributes["weight"].value == 8
+ assert device.node_changelog.attributes["weight"].value_update_status == DiffAction.ADDED
+ assert (
+ device.weight.source_id == device.node_changelog.attributes["weight"].properties["source"].value == template.id
+ )
+ assert device.airflow.value.value == device.node_changelog.attributes["airflow"].value.value == "Rear to front"
+ assert device.node_changelog.attributes["airflow"].value_update_status == DiffAction.ADDED
+ assert (
+ device.airflow.source_id == device.node_changelog.attributes["airflow"].properties["source"].value == profile.id
+ )
+
+
# --------------------------------------------------------------------------
# Update
# --------------------------------------------------------------------------
diff --git a/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts b/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts
index 84848d699a..1fb3534f0f 100644
--- a/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts
+++ b/frontend/app/src/entities/nodes/object-item-edit/generateObjectEditFormQuery.ts
@@ -3,6 +3,7 @@ import { jsonToGraphQLQuery } from "json-to-graphql-query";
import { addAttributesToRequest, addRelationshipsToRequest } from "@/shared/api/graphql/utils";
import { getRelationshipsForForm } from "@/shared/components/form/utils/getRelationshipsForForm";
+import { getSchema } from "@/entities/schema/domain/get-schema";
import type { NodeSchema, ProfileSchema } from "@/entities/schema/types";
export const generateObjectEditFormQuery = ({
@@ -12,6 +13,10 @@ export const generateObjectEditFormQuery = ({
schema: NodeSchema | ProfileSchema;
objectId: string;
}): string => {
+ let parentSchema: NodeSchema | ProfileSchema | undefined;
+ if (schema.kind && schema.kind.includes("Template")) {
+ parentSchema = getSchema(schema.name).schema;
+ }
const request = {
query: {
__name: "GetObjectForEditForm",
@@ -32,7 +37,8 @@ export const generateObjectEditFormQuery = ({
getRelationshipsForForm(schema.relationships ?? [], true, schema),
{ withMetadata: true }
),
- ...("generate_profile" in schema && schema.generate_profile
+ ...(("generate_profile" in schema && schema.generate_profile) ||
+ (parentSchema && "generate_profile" in parentSchema && parentSchema.generate_profile)
? {
profiles: {
edges: {
diff --git a/frontend/app/src/entities/nodes/object-template/object-template-form.tsx b/frontend/app/src/entities/nodes/object-template/object-template-form.tsx
index 1b5a74b919..964cbf2da6 100644
--- a/frontend/app/src/entities/nodes/object-template/object-template-form.tsx
+++ b/frontend/app/src/entities/nodes/object-template/object-template-form.tsx
@@ -89,7 +89,15 @@ export default function ObjectTemplateForm({
}
if (selectedObjectTemplate !== undefined) {
- return ;
+ return (
+ edge?.node) ?? []
+ }
+ objectTemplate={selectedObjectTemplate}
+ />
+ );
}
return (
diff --git a/frontend/app/src/shared/components/form/object-form.tsx b/frontend/app/src/shared/components/form/object-form.tsx
index 4c727db8c7..29a663072e 100644
--- a/frontend/app/src/shared/components/form/object-form.tsx
+++ b/frontend/app/src/shared/components/form/object-form.tsx
@@ -63,7 +63,7 @@ export interface ObjectFormProps extends Omit {
- const { schema, isNode, isGeneric } = useSchema(kind);
+ const { schema, isNode, isGeneric, isTemplate } = useSchema(kind);
if (!schema) {
return (
@@ -148,7 +148,7 @@ const ObjectForm = ({ kind, currentProfiles, ...props }: ObjectFormProps) => {
);
}
- if (isNode && schema.generate_profile) {
+ if ((isNode && schema.generate_profile) || isTemplate) {
return ;
}
diff --git a/frontend/app/src/shared/components/form/profiles-selector.tsx b/frontend/app/src/shared/components/form/profiles-selector.tsx
index b274eec6a7..06d61f638f 100644
--- a/frontend/app/src/shared/components/form/profiles-selector.tsx
+++ b/frontend/app/src/shared/components/form/profiles-selector.tsx
@@ -1,7 +1,7 @@
import { gql } from "@apollo/client";
import { Icon } from "@iconify-icon/react";
import { useAtomValue } from "jotai/index";
-import { useEffect, useId } from "react";
+import { useId } from "react";
import useQuery from "@/shared/api/graphql/useQuery";
import { Button } from "@/shared/components/buttons/button-primitive";
@@ -42,12 +42,6 @@ export const ProfilesSelector = ({
}: ProfilesSelectorProps) => {
const id = useId();
- useEffect(() => {
- if (!value && defaultValue) {
- onChange(defaultValue);
- }
- }, [defaultValue]);
-
const genericSchemas = useAtomValue(genericSchemasAtom);
const profileSchemas = useAtomValue(profileSchemasAtom);
@@ -70,7 +64,9 @@ export const ProfilesSelector = ({
const profilesList = kindList
.map((profile) => {
// Get the profile schema for the current kind
- const profileSchema = profileSchemas.find((profileSchema) => profileSchema.name === profile);
+ const profileSchema = profileSchemas.find(
+ (profileSchema) => profileSchema.name === profile?.replace("Template", "")
+ );
// Get attributes for query + form data
const attributes = getObjectAttributes({ schema: profileSchema, forProfiles: true });
@@ -114,6 +110,10 @@ export const ProfilesSelector = ({
[]
);
+ if (!value && defaultValue) {
+ onChange(profiles.filter((profile) => defaultValue.some((def) => def.id === profile.id)));
+ }
+
if (!profiles || profiles.length === 0) return null;
const selectedValues = value ?? [];