Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/infrahub/core/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,4 @@ class AttributeDBNodeType(Flag):
OBJECT_TEMPLATE_RELATIONSHIP_NAME = "object_template"
OBJECT_TEMPLATE_NAME_ATTR = "template_name"
PROFILE_NODE_RELATIONSHIP_IDENTIFIER = "node__profile"
PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER = "template__profile"
5 changes: 3 additions & 2 deletions backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,10 @@ async def handle_object_template(self, fields: dict, db: InfrahubDatabase, error
for attribute_name in template._attributes:
if attribute_name in list(fields) + [OBJECT_TEMPLATE_NAME_ATTR]:
continue
attr_value = getattr(template, attribute_name).value
attr = getattr(template, attribute_name)
attr_value = attr.value
if attr_value is not None:
fields[attribute_name] = {"value": attr_value, "source": template.id}
fields[attribute_name] = {"value": attr_value, "source": attr.source_id or template.id}

for relationship_name in template._relationships:
relationship_schema = template._schema.get_relationship(name=relationship_name)
Expand Down
13 changes: 12 additions & 1 deletion backend/infrahub/core/node/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,19 @@ async def extract_peer_data(
for rel in template_peer.get_schema().relationship_names:
rel_manager: RelationshipManager = getattr(template_peer, rel)
if (
rel_manager.schema.kind not in [RelationshipKind.COMPONENT, RelationshipKind.PARENT]
rel_manager.schema.kind
not in [RelationshipKind.COMPONENT, RelationshipKind.PARENT, RelationshipKind.PROFILE]
or rel_manager.schema.name not in obj_peer_schema.relationship_names
):
continue

if list(await rel_manager.get_peers(db=db)) == [current_template.id]:
obj_peer_data[rel] = {"id": parent_obj.id}

if rel_manager.schema.kind == RelationshipKind.PROFILE:
profiles = list(await rel_manager.get_peers(db=db))
obj_peer_data[rel] = profiles

return obj_peer_data


Expand Down Expand Up @@ -114,6 +119,12 @@ async def handle_template_relationships(
await constraint_runner.check(node=obj_peer, field_filters=list(obj_peer_data))
await obj_peer.save(db=db)

template_profile_ids = await get_profile_ids(db=db, obj=template_relationship_peer)
if template_profile_ids:
node_profiles_applier = NodeProfilesApplier(db=db, branch=branch)
await node_profiles_applier.apply_profiles(node=obj_peer)
await obj_peer.save(db=db)

await handle_template_relationships(
db=db,
branch=branch,
Expand Down
39 changes: 30 additions & 9 deletions backend/infrahub/core/schema/schema_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
OBJECT_TEMPLATE_NAME_ATTR,
OBJECT_TEMPLATE_RELATIONSHIP_NAME,
PROFILE_NODE_RELATIONSHIP_IDENTIFIER,
PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER,
RESERVED_ATTR_GEN_NAMES,
RESERVED_ATTR_REL_NAMES,
RESTRICTED_NAMESPACES,
Expand Down Expand Up @@ -69,6 +70,15 @@

log = get_logger()

profiles_rel_settings: dict[str, Any] = {
"name": "profiles",
"identifier": PROFILE_NODE_RELATIONSHIP_IDENTIFIER,
"peer": InfrahubKind.PROFILE,
"kind": RelationshipKind.PROFILE,
"cardinality": RelationshipCardinality.MANY,
"branch": BranchSupportType.AWARE,
}


class SchemaBranch:
def __init__(
Expand Down Expand Up @@ -1933,15 +1943,6 @@ def manage_profile_relationships(self) -> None:
):
continue

profiles_rel_settings: dict[str, Any] = {
"name": "profiles",
"identifier": PROFILE_NODE_RELATIONSHIP_IDENTIFIER,
"peer": InfrahubKind.PROFILE,
"kind": RelationshipKind.PROFILE,
"cardinality": RelationshipCardinality.MANY,
"branch": BranchSupportType.AWARE,
}

# Add relationship between node and profile
if "profiles" not in node.relationship_names:
node_schema = self.get(name=node_name, duplicate=True)
Expand Down Expand Up @@ -2006,6 +2007,18 @@ def generate_profile_from_node(self, node: NodeSchema) -> ProfileSchema:
)
],
)
if f"Template{node.kind}" in self.all_names:
template = self.get(name=f"Template{node.kind}", duplicate=False)
profile.relationships.append(
RelationshipSchema(
name="related_templates",
identifier=PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER,
peer=template.kind,
kind=RelationshipKind.PROFILE,
cardinality=RelationshipCardinality.MANY,
branch=BranchSupportType.AWARE,
)
)

for node_attr in node.attributes:
if node_attr.read_only or node_attr.optional is False:
Expand Down Expand Up @@ -2136,6 +2149,14 @@ def add_relationships_to_template(self, node: NodeSchema | GenericSchema) -> Non
template_schema.human_friendly_id = [parent_hfid] + template_schema.human_friendly_id
template_schema.uniqueness_constraints[0].append(relationship.name)

if getattr(node, "generate_profile", False) and getattr(node, "generate_template", False):
if "profiles" not in [r.name for r in template_schema.relationships]:
settings = dict(profiles_rel_settings)
settings["identifier"] = PROFILE_TEMPLATE_RELATIONSHIP_IDENTIFIER
template_schema.relationships.append(RelationshipSchema(**settings))

self.set(name=template_schema.kind, schema=template_schema)

def generate_object_template_from_node(
self, node: NodeSchema | GenericSchema, need_templates: set[NodeSchema | GenericSchema]
) -> TemplateSchema | GenericSchema:
Expand Down
9 changes: 8 additions & 1 deletion backend/infrahub/graphql/mutations/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ async def _send_profile_refresh_workflows(
) -> None:
if not node_ids:
related_nodes = await obj.related_nodes.get_relationships(db=db) # type: ignore[attr-defined]
if hasattr(obj, "related_templates"):
related_nodes.extend(await obj.related_templates.get_relationships(db=db)) # type: ignore[attr-defined]
node_ids = [rel.peer_id for rel in related_nodes]
if node_ids:
await workflow_service.submit_workflow(
Expand All @@ -79,7 +81,12 @@ def _get_profile_attr_values_map(cls, obj: Node) -> dict[str, Any]:

@classmethod
async def _get_profile_related_node_ids(cls, db: InfrahubDatabase, obj: Node) -> set[str]:
related_nodes = await obj.related_nodes.get_relationships(db=db) # type: ignore[attr-defined]
related_nodes = []
related_nodes.extend(await obj.related_nodes.get_relationships(db=db))

if hasattr(obj, "related_templates"):
related_nodes.extend(await obj.related_templates.get_relationships(db=db))

if related_nodes:
related_node_ids = {rel.peer_id for rel in related_nodes}
else:
Expand Down
8 changes: 7 additions & 1 deletion backend/infrahub/profiles/node_applier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from infrahub.core.attribute import BaseAttribute
from infrahub.core.branch import Branch
from infrahub.core.node import Node
from infrahub.core.schema import TemplateSchema
from infrahub.database import InfrahubDatabase

from .queries.get_profile_data import GetProfileDataQuery, ProfileData
Expand All @@ -29,7 +30,12 @@ async def _get_attr_names_for_profiles(self, node: Node) -> list[str]:
for attr_schema in node_schema.attributes:
attr_name = attr_schema.name
node_attr: BaseAttribute = getattr(node, attr_name)
if node_attr.is_from_profile or node_attr.is_default:
is_template = None
if node_attr.source_id:
await node_attr.get_source(db=self.db)
if isinstance(node_attr.source.get_schema(), TemplateSchema):
is_template = True
if node_attr.is_from_profile or node_attr.is_default or is_template:
Comment on lines +33 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean that profile values will override template values? if so, that's not what I would expect, but maybe you could explain what the expected behavior is in the PR description

I would think that template values would take precedence over profile values. and I'd think that adding the profiles relationship to a template would be enough and this change to NodeProfilesApplier wouldn't be necessary, but I could definitely be missing something

attr_names_for_profiles.append(attr_name)
return attr_names_for_profiles

Expand Down
1 change: 1 addition & 0 deletions backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,7 @@ async def criticality_schema_root(register_core_models_schema: None) -> SchemaRo
"display_labels": ["label__value"],
"inherit_from": ["TestGenericCriticality"],
"branch": BranchSupportType.AWARE.value,
"generate_template": True,
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
{"name": "label", "kind": "Text", "optional": True},
Expand Down
51 changes: 49 additions & 2 deletions backend/tests/unit/core/profiles/test_node_applier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ async def _validate_node_profile_attrs(
):
expected_profile_attrs_by_name = {attr.name: attr for attr in expected_profile_attrs}
for attr_name in schema.attribute_names:
# Skip if the attribute is not present on the node (e.g., not set on template)
if not hasattr(updated_node, attr_name):
continue
updated_node_attr = getattr(updated_node, attr_name)
updated_source = await updated_node_attr.get_source(db=db)
original_node_attr = getattr(original_node, attr_name)
original_node_attr = getattr(original_node, attr_name) if hasattr(original_node, attr_name) else None
expected_profile_attr = expected_profile_attrs_by_name.get(attr_name)
if expected_profile_attr:
assert updated_node_attr.value == expected_profile_attr.value
assert updated_source.id == expected_profile_attr.source_uuid
else:
elif original_node_attr is not None:
assert updated_node_attr.value == original_node_attr.value
assert updated_source is None

Expand Down Expand Up @@ -214,3 +217,47 @@ async def test_get_many_with_multiple_profiles_same_priority(
assert updated_field_names == []
updated_field_names = await node_applier.apply_profiles(node=updated_crit_low)
assert updated_field_names == []


async def test_template_profile_application(
db: InfrahubDatabase,
criticality_schema: NodeSchema,
criticality_low: Node,
branch: Branch,
):
profile_schema = registry.schema.get("ProfileTestCriticality", branch=branch)
template_schema = registry.schema.get("TemplateTestCriticality", branch=branch)

crit_profile_1 = await Node.init(db=db, branch=branch, schema=profile_schema)
await crit_profile_1.new(db=db, profile_name="crit_profile_1", color="green", profile_priority=1001)
await crit_profile_1.save(db=db)

crit_template = await Node.init(db=db, branch=branch, schema=template_schema)
await crit_template.new(db=db, template_name="crit_template", name="crit_template")
await crit_template.save(db=db)

# TODO: Fix profile assignment to template
await crit_template.profiles.update(db=db, data=[crit_profile_1])

node_applier = NodeProfilesApplier(db=db, branch=branch)

updated_template_field_names = await node_applier.apply_profiles(node=crit_template)
assert updated_template_field_names == ["color"]
await crit_template.save(db=db)

node = await NodeManager.get_one(db=db, branch=branch, id=crit_template.id, include_source=True)
assert node.id == crit_template.id
expected_profile_attrs = [
ExpectedProfileAttr(name="color", value="green", source_uuid=crit_profile_1.id),
]
await _validate_node_profile_attrs(
db=db,
schema=criticality_schema,
original_node=crit_template,
updated_node=node,
expected_profile_attrs=expected_profile_attrs,
)

# make sure field names returned by apply_profiles is idempotent for templates
updated_field_names = await node_applier.apply_profiles(node=crit_template)
assert updated_field_names == []
120 changes: 120 additions & 0 deletions backend/tests/unit/core/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,126 @@ async def test_node_create_with_object_template(
)


async def test_node_create_with_object_template_with_profile(
db: InfrahubDatabase, default_branch: Branch, register_core_models_schema: SchemaBranch
):
DUMMY = NodeSchema(
name="Dummy",
namespace="Testing",
generate_template=True,
attributes=[
AttributeSchema(name="name", kind="Text", unique=True),
],
)

SIMPLE_DEVICE = NodeSchema(
name="Device",
namespace="Testing",
generate_template=True,
generate_profile=True,
attributes=[
AttributeSchema(name="name", kind="Text", unique=True, order_weight=500),
AttributeSchema(name="manufacturer", kind="Text", order_weight=500),
AttributeSchema(name="height", kind="Number", order_weight=300),
AttributeSchema(name="weight", kind="Number", order_weight=1000),
AttributeSchema(
name="airflow",
kind="Text",
enum=["Front to rear", "Rear to front"],
optional=True,
),
],
relationships=[
RelationshipSchema(
name="dummy",
peer="TestingDummy",
cardinality=RelationshipCardinality.ONE,
order_weight=5000,
optional=True,
)
],
)
registry.schema.set(name=DUMMY.kind, schema=DUMMY, branch=default_branch.name)
registry.schema.set(name=SIMPLE_DEVICE.kind, schema=SIMPLE_DEVICE, branch=default_branch.name)
registry.schema.process_schema_branch(name=default_branch.name)

template_schema = registry.schema.get(name=f"Template{SIMPLE_DEVICE.kind}", branch=default_branch.name)
node_schema = registry.schema.get(name=SIMPLE_DEVICE.kind, branch=default_branch.name)

# Validate that the attributes respect the order_weight defined on the original schema
template_weights = {
attr.name: attr.order_weight for attr in template_schema.attributes + template_schema.relationships
}

assert "name" not in template_weights
assert template_weights["manufacturer"] == 10500
assert template_weights["dummy"] == 15000

profile_schema = registry.schema.get(name=f"Profile{SIMPLE_DEVICE.kind}", branch=default_branch.name)

profile = await Node.init(db=db, branch=default_branch.name, schema=profile_schema)
await profile.new(
db=db,
profile_name="Airflow Rear to Front",
airflow="Rear to front",
)
await profile.save(db=db)

template = await Node.init(db=db, schema=template_schema)
await template.new(
db=db,
template_name="Juniper MX204",
manufacturer="Juniper",
height=1,
weight=8,
)
await template.save(db=db)
# TODO: Fix profile assignment
await template.profiles.update(db=db, data=[profile])
template_profiles = await template.profiles.get_peers(db=db)
assert len(template_profiles) == 1

from infrahub.profiles.node_applier import NodeProfilesApplier

node_applier = NodeProfilesApplier(db=db, branch=default_branch)
await node_applier.apply_profiles(node=template)
await template.save(db=db)
assert template.airflow.value == "Rear to front"
assert template.airflow.source_id == profile.id

device: Node = await Node.init(db=db, schema=node_schema)
await device.new(db=db, name="par-th2-br01", object_template={"id": template.id})
await device.save(db=db)

assert device.id
assert device.db_id
assert device.name.value == device.node_changelog.attributes["name"].value == "par-th2-br01"
assert device.node_changelog.attributes["name"].value_update_status == DiffAction.ADDED
assert "source" not in device.node_changelog.attributes["name"].properties
assert device.manufacturer.value == device.node_changelog.attributes["manufacturer"].value == "Juniper"
assert device.node_changelog.attributes["manufacturer"].value_update_status == DiffAction.ADDED
assert (
device.manufacturer.source_id
== device.node_changelog.attributes["manufacturer"].properties["source"].value
== template.id
)
assert device.height.value == device.node_changelog.attributes["height"].value == 1
assert device.node_changelog.attributes["height"].value_update_status == DiffAction.ADDED
assert (
device.height.source_id == device.node_changelog.attributes["height"].properties["source"].value == template.id
)
assert device.weight.value == device.node_changelog.attributes["weight"].value == 8
assert device.node_changelog.attributes["weight"].value_update_status == DiffAction.ADDED
assert (
device.weight.source_id == device.node_changelog.attributes["weight"].properties["source"].value == template.id
)
assert device.airflow.value.value == device.node_changelog.attributes["airflow"].value.value == "Rear to front"
assert device.node_changelog.attributes["airflow"].value_update_status == DiffAction.ADDED
assert (
device.airflow.source_id == device.node_changelog.attributes["airflow"].properties["source"].value == profile.id
)


# --------------------------------------------------------------------------
# Update
# --------------------------------------------------------------------------
Expand Down
Loading