Skip to content
Merged
25 changes: 24 additions & 1 deletion backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,32 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
self.params["branch"] = self.branch.name
self.params["branch_level"] = self.branch.hierarchy_level

if self.branch.is_global or self.branch.is_default:
node_query_match = """
MATCH (n:Node { uuid: $uuid })
OPTIONAL MATCH (n)-[delete_edge:IS_PART_OF {status: "deleted", branch: $branch}]->(:Root)
WHERE delete_edge.from <= $at
WITH n WHERE delete_edge IS NULL
"""
else:
node_filter, node_filter_params = self.branch.get_query_filter_path(at=self.at, variable_name="r")
node_query_match = """
MATCH (n:Node { uuid: $uuid })
CALL {
WITH n
MATCH (n)-[r:IS_PART_OF]->(:Root)
WHERE %(node_filter)s
RETURN r.status = "active" AS is_active
ORDER BY r.from DESC
LIMIT 1
}
WITH n WHERE is_active = TRUE
""" % {"node_filter": node_filter}
self.params.update(node_filter_params)
self.add_to_query(node_query_match)

query = """
MATCH (root:Root)
MATCH (n:Node { uuid: $uuid })
CREATE (n)-[r:IS_PART_OF { branch: $branch, branch_level: $branch_level, status: "deleted", from: $at }]->(root)
"""

Expand Down
138 changes: 78 additions & 60 deletions backend/infrahub/core/query/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,50 +205,20 @@ def get_relationship_properties_dict(self, status: RelationshipStatus) -> dict[s
rel_prop_dict["hierarchy"] = self.schema.hierarchical
return rel_prop_dict


class RelationshipCreateQuery(RelationshipQuery):
name = "relationship_create"

type: QueryType = QueryType.WRITE

def __init__(
self,
destination: Node = None,
destination_id: UUID | None = None,
**kwargs,
):
if not destination and not destination_id:
raise ValueError("Either destination or destination_id must be provided.")

super().__init__(destination=destination, destination_id=destination_id, **kwargs)

async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG002
self.params["source_id"] = self.source_id
self.params["destination_id"] = self.destination_id
self.params["name"] = self.schema.identifier
self.params["branch_support"] = self.schema.branch.value

self.params["uuid"] = str(UUIDT())

self.params["branch"] = self.branch.name
self.params["branch_level"] = self.branch.hierarchy_level
self.params["at"] = self.at.to_string()

self.params["is_protected"] = self.rel.is_protected
self.params["is_visible"] = self.rel.is_visible

source_branch = self.source.get_branch_based_on_support_type()
def add_source_match_to_query(self, source_branch: Branch) -> None:
self.params["source_id"] = self.source_id or self.source.get_id()
if source_branch.is_global or source_branch.is_default:
source_query_match = """
MATCH (s:Node { uuid: $source_id })
WHERE NOT exists((s)-[:IS_PART_OF {status: "deleted", branch: $source_branch}]->(:Root))
OPTIONAL MATCH (s)-[delete_edge:IS_PART_OF {status: "deleted", branch: $source_branch}]->(:Root)
WHERE delete_edge.from <= $at
WITH *, s WHERE delete_edge IS NULL
"""
self.params["source_branch"] = source_branch.name
else:
source_filter, source_filter_params = source_branch.get_query_filter_path(
at=self.at, variable_name="r", params_prefix="src_"
)
source_query_match = """
source_filter, source_filter_params = source_branch.get_query_filter_path(
at=self.at, variable_name="r", params_prefix="src_"
)
source_query_match = """
MATCH (s:Node { uuid: $source_id })
CALL {
WITH s
Expand All @@ -258,16 +228,19 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
ORDER BY r.from DESC
LIMIT 1
}
WITH s WHERE s_is_active = TRUE
WITH *, s WHERE s_is_active = TRUE
""" % {"source_filter": source_filter}
self.params.update(source_filter_params)
self.params.update(source_filter_params)
self.add_to_query(source_query_match)

destination_branch = self.destination.get_branch_based_on_support_type()
def add_dest_match_to_query(self, destination_branch: Branch, destination_id: str) -> None:
self.params["destination_id"] = destination_id
if destination_branch.is_global or destination_branch.is_default:
destination_query_match = """
MATCH (d:Node { uuid: $destination_id })
WHERE NOT exists((d)-[:IS_PART_OF {status: "deleted", branch: $destination_branch}]->(:Root))
OPTIONAL MATCH (d)-[delete_edge:IS_PART_OF {status: "deleted", branch: $destination_branch}]->(:Root)
WHERE delete_edge.from <= $at
WITH *, d WHERE delete_edge IS NULL
"""
self.params["destination_branch"] = destination_branch.name
else:
Expand All @@ -284,11 +257,46 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
ORDER BY r.from DESC
LIMIT 1
}
WITH s, d WHERE d_is_active = TRUE
WITH *, d WHERE d_is_active = TRUE
""" % {"destination_filter": destination_filter}
self.params.update(destination_filter_params)
self.add_to_query(destination_query_match)


class RelationshipCreateQuery(RelationshipQuery):
name = "relationship_create"

type: QueryType = QueryType.WRITE

def __init__(
self,
destination: Node = None,
destination_id: UUID | None = None,
**kwargs,
):
if not destination and not destination_id:
raise ValueError("Either destination or destination_id must be provided.")

super().__init__(destination=destination, destination_id=destination_id, **kwargs)

async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG002
self.params["name"] = self.schema.identifier
self.params["branch_support"] = self.schema.branch.value

self.params["uuid"] = str(UUIDT())

self.params["branch"] = self.branch.name
self.params["branch_level"] = self.branch.hierarchy_level
self.params["at"] = self.at.to_string()

self.params["is_protected"] = self.rel.is_protected
self.params["is_visible"] = self.rel.is_visible

self.add_source_match_to_query(source_branch=self.source.get_branch_based_on_support_type())
self.add_dest_match_to_query(
destination_branch=self.destination.get_branch_based_on_support_type(),
destination_id=self.destination_id or self.destination.get_id(),
)
self.query_add_all_node_property_match()

self.params["rel_prop"] = self.get_relationship_properties_dict(status=RelationshipStatus.ACTIVE)
Expand Down Expand Up @@ -433,7 +441,6 @@ def __init__(

async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG002
self.params["source_id"] = self.source_id
self.params["destination_id"] = self.data.peer_id
self.params["rel_node_id"] = self.data.rel_node_id
self.params["name"] = self.schema.identifier
self.params["branch"] = self.branch.name
Expand All @@ -443,9 +450,10 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
# -----------------------------------------------------------------------
# Match all nodes, including properties
# -----------------------------------------------------------------------

self.add_source_match_to_query(source_branch=self.source.get_branch_based_on_support_type())
self.add_dest_match_to_query(destination_branch=self.branch, destination_id=self.data.peer_id)
query = """
MATCH (s:Node { uuid: $source_id })
MATCH (d:Node { uuid: $destination_id })
MATCH (rl:Relationship { uuid: $rel_node_id })
"""
self.add_to_query(query)
Expand Down Expand Up @@ -497,8 +505,6 @@ def __init__(self, **kwargs):

async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG002
rel_filter, rel_params = self.branch.get_query_filter_path(at=self.at, variable_name="edge")
self.params["source_id"] = self.source_id
self.params["destination_id"] = self.destination_id
self.params["rel_id"] = self.rel.id
self.params["branch"] = self.branch.name
self.params["rel_prop"] = self.get_relationship_properties_dict(status=RelationshipStatus.DELETED)
Expand All @@ -509,9 +515,14 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
r1 = f"{arrows.left.start}[r1:{self.rel_type} $rel_prop ]{arrows.left.end}"
r2 = f"{arrows.right.start}[r2:{self.rel_type} $rel_prop ]{arrows.right.end}"

self.add_source_match_to_query(source_branch=self.source.get_branch_based_on_support_type())
self.add_dest_match_to_query(
destination_branch=self.destination.get_branch_based_on_support_type(),
destination_id=self.destination_id or self.destination.get_id(),
)
query = """
MATCH (s:Node { uuid: $source_id })-[:IS_RELATED]-(rl:Relationship {uuid: $rel_id})-[:IS_RELATED]-(d:Node { uuid: $destination_id })
WITH s, rl, d
MATCH (s)-[:IS_RELATED]-(rl:Relationship {uuid: $rel_id})-[:IS_RELATED]-(d)
WITH DISTINCT s, rl, d
LIMIT 1
CREATE (s)%(r1)s(rl)
CREATE (rl)%(r2)s(d)
Expand Down Expand Up @@ -853,8 +864,6 @@ class RelationshipGetQuery(RelationshipQuery):
type: QueryType = QueryType.READ

async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG002
self.params["source_id"] = self.source_id
self.params["destination_id"] = self.destination_id
self.params["name"] = self.schema.identifier
self.params["branch"] = self.branch.name

Expand All @@ -868,9 +877,12 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
r1 = f"{arrows.left.start}[r1:{self.rel.rel_type}]{arrows.left.end}"
r2 = f"{arrows.right.start}[r2:{self.rel.rel_type}]{arrows.right.end}"

self.add_source_match_to_query(source_branch=self.source.get_branch_based_on_support_type())
self.add_dest_match_to_query(
destination_branch=self.destination.get_branch_based_on_support_type(),
destination_id=self.destination_id or self.destination.get_id(),
)
query = """
MATCH (s:Node { uuid: $source_id })
MATCH (d:Node { uuid: $destination_id })
MATCH (s)%s(rl:Relationship { name: $name })%s(d)
WHERE %s
""" % (
Expand Down Expand Up @@ -1097,7 +1109,11 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
CALL {
WITH rl
MATCH (rl)-[active_edge:IS_RELATED]->(n)
WHERE %(active_rel_filter)s AND active_edge.status ="active"
WHERE %(active_rel_filter)s
WITH rl, active_edge, n
ORDER BY %(id_func)s(rl), %(id_func)s(n), active_edge.from DESC
WITH rl, n, head(collect(active_edge)) AS active_edge
WHERE active_edge.status = "active"
CREATE (rl)-[deleted_edge:IS_RELATED $rel_prop]->(n)
SET deleted_edge.hierarchy = active_edge.hierarchy
WITH rl, active_edge, n
Expand All @@ -1113,7 +1129,11 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG

WITH rl
MATCH (rl)<-[active_edge:IS_RELATED]-(n)
WHERE %(active_rel_filter)s AND active_edge.status ="active"
WHERE %(active_rel_filter)s
WITH rl, active_edge, n
ORDER BY %(id_func)s(rl), %(id_func)s(n), active_edge.from DESC
WITH rl, n, head(collect(active_edge)) AS active_edge
WHERE active_edge.status = "active"
CREATE (rl)<-[deleted_edge:IS_RELATED $rel_prop]-(n)
SET deleted_edge.hierarchy = active_edge.hierarchy
WITH rl, active_edge, n
Expand All @@ -1126,9 +1146,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs) -> None: # noqa: ARG
"inbound" as rel_direction
}
RETURN DISTINCT uuid, kind, rel_identifier, rel_direction
""" % {
"active_rel_filter": active_rel_filter,
}
""" % {"active_rel_filter": active_rel_filter, "id_func": db.get_id_function_name()}

self.add_to_query(query)

Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/core/relationship/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ async def delete(self, db: InfrahubDatabase, at: Timestamp | None = None) -> Non
await update_relationships_to(rel_ids_to_update, to=delete_at, db=db)

delete_query = await RelationshipDeleteQuery.init(
db=db, rel=self, source_id=node.id, destination_id=peer.id, branch=branch, at=delete_at
db=db, rel=self, source=node, destination=peer, branch=branch, at=delete_at
)
await delete_query.execute(db=db)

Expand Down
28 changes: 28 additions & 0 deletions backend/tests/helpers/db_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,31 @@ async def validate_node_relationships(node: Node, branch: Branch, db: InfrahubDa
for result in query.results:
print(result)
assert len(result.data) == 1 and result.data[0] == "Edges state is correct"


async def verify_no_duplicate_paths(db: InfrahubDatabase) -> None:
"""Verify that no duplicate paths exist at the database level"""
query = """
MATCH path = (p)-[e]->(q)
WITH
%(id_func)s(p) AS node_id1,
e.branch AS branch,
e.from AS from_time,
type(e) AS edge_type,
%(id_func)s(q) AS node_id2,
path
WITH node_id1, branch, from_time, edge_type, node_id2, size(collect(path)) AS num_paths
WHERE num_paths > 1
RETURN node_id1, branch, from_time, edge_type, node_id2, num_paths
""" % {"id_func": db.get_id_function_name()}
records = await db.execute_query(query=query)
for record in records:
node_id1 = record.get("node_id1")
branch = record.get("branch")
from_time = record.get("from_time")
edge_type = record.get("edge_type")
node_id2 = record.get("node_id2")
num_paths = record.get("num_paths")
raise ValueError(
f"{num_paths} paths ({branch=},{edge_type=},{from_time=}) between nodes '{node_id1}' and '{node_id2}'"
)
24 changes: 1 addition & 23 deletions backend/tests/unit/core/diff/test_diff_and_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,13 @@
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase
from infrahub.dependencies.registry import get_component_registry
from tests.helpers.db_validation import verify_no_duplicate_paths
from tests.unit.conftest import _build_hierarchical_location_data
from tests.unit.core.test_utils import verify_all_linked_edges_deleted

from .get_one_node import get_one_diff_node


async def verify_no_duplicate_paths(db: InfrahubDatabase) -> None:
"""Verify that no duplicate paths exist at the database level"""
query = """
MATCH path = (p)-[e]->(q)
WITH COALESCE(p.uuid, p.value) AS node_id1, e.branch AS branch, e.from AS from_time, type(e) AS edge_type, COALESCE(q.uuid, q.value) AS node_id2, path
WHERE node_id1 IS NOT NULL AND node_id2 IS NOT NULL
WITH node_id1, branch, from_time, edge_type, node_id2, size(collect(path)) AS num_paths
WHERE num_paths > 1
RETURN node_id1, branch, from_time, edge_type, node_id2, num_paths
"""
records = await db.execute_query(query=query)
for record in records:
node_id1 = record.get("node_id1")
branch = record.get("branch")
from_time = record.get("from_time")
edge_type = record.get("edge_type")
node_id2 = record.get("node_id2")
num_paths = record.get("num_paths")
raise ValueError(
f"{num_paths} paths ({branch=},{edge_type=},{from_time=}) between nodes '{node_id1}' and '{node_id2}'"
)


class TestDiffAndMerge:
@pytest.fixture
async def diff_repository(self, db: InfrahubDatabase, default_branch: Branch) -> DiffRepository:
Expand Down
Loading
Loading