Skip to content

Commit 1919dbb

Browse files
authored
IFC-279 Allow to define same peer set contraint (#6526)
This change adds a relationship schema attribute and a constraint validator to check that peers of a relationship have the same set of related nodes on one or more of their relationships.
1 parent 040ca56 commit 1919dbb

File tree

16 files changed

+415
-0
lines changed

16 files changed

+415
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Mapping
5+
6+
from infrahub.core.constants import RelationshipCardinality
7+
from infrahub.exceptions import ValidationError
8+
9+
from .interface import RelationshipManagerConstraintInterface
10+
11+
if TYPE_CHECKING:
12+
from infrahub.core.branch import Branch
13+
from infrahub.core.node import Node
14+
from infrahub.core.schema import MainSchemaTypes, NonGenericSchemaTypes
15+
from infrahub.database import InfrahubDatabase
16+
17+
from ..model import RelationshipManager
18+
19+
20+
@dataclass
21+
class NodeToValidate:
22+
uuid: str
23+
relative_uuids: set[str]
24+
schema: NonGenericSchemaTypes
25+
26+
27+
class RelationshipPeerRelativesConstraint(RelationshipManagerConstraintInterface):
28+
def __init__(self, db: InfrahubDatabase, branch: Branch | None = None):
29+
self.db = db
30+
self.branch = branch
31+
32+
async def _check_relationship_peers_relatives(
33+
self,
34+
relm: RelationshipManager,
35+
node_schema: MainSchemaTypes,
36+
peers: Mapping[str, Node],
37+
relationship_name: str,
38+
) -> None:
39+
"""Validate that all peers of a given `relm` have the same set of relatives (aka peers) for the given `relationship_name`."""
40+
nodes_to_validate: list[NodeToValidate] = []
41+
42+
for peer in peers.values():
43+
peer_schema = peer.get_schema()
44+
peer_relm: RelationshipManager = getattr(peer, relationship_name)
45+
peer_relm_peers = await peer_relm.get_peers(db=self.db)
46+
47+
nodes_to_validate.append(
48+
NodeToValidate(
49+
uuid=peer.id, relative_uuids={n.id for n in peer_relm_peers.values()}, schema=peer_schema
50+
)
51+
)
52+
53+
relative_uuids = nodes_to_validate[0].relative_uuids
54+
for node in nodes_to_validate[1:]:
55+
if node.relative_uuids != relative_uuids:
56+
raise ValidationError(
57+
f"All the elements of the '{relm.name}' relationship on node {node.uuid} ({node_schema.kind}) must have the same set of peers "
58+
f"for their '{node.schema.kind}.{relationship_name}' relationship"
59+
)
60+
61+
async def check(self, relm: RelationshipManager, node_schema: MainSchemaTypes) -> None:
62+
if relm.schema.cardinality != RelationshipCardinality.MANY or not relm.schema.common_relatives:
63+
return
64+
65+
peers = await relm.get_peers(db=self.db)
66+
if not peers:
67+
return
68+
69+
for rel_name in relm.schema.common_relatives:
70+
await self._check_relationship_peers_relatives(
71+
relm=relm, node_schema=node_schema, peers=peers, relationship_name=rel_name
72+
)

backend/infrahub/core/schema/definitions/internal.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,14 @@ def to_dict(self) -> dict[str, Any]:
754754
optional=True,
755755
extra={"update": UpdateSupport.VALIDATE_CONSTRAINT},
756756
),
757+
SchemaAttribute(
758+
name="common_relatives",
759+
kind="List",
760+
internal_kind=str,
761+
optional=True,
762+
description="List of relationship names on the peer schema for which all objects must share the same set of peers.",
763+
extra={"update": UpdateSupport.VALIDATE_CONSTRAINT},
764+
),
757765
SchemaAttribute(
758766
name="order_weight",
759767
kind="Number",

backend/infrahub/core/schema/generated/relationship_schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ class GeneratedRelationshipSchema(HashableModel):
7373
description="Defines the maximum objects allowed on the other side of the relationship.",
7474
json_schema_extra={"update": "validate_constraint"},
7575
)
76+
common_relatives: list[str] | None = Field(
77+
default=None,
78+
description="List of relationship names on the peer schema for which all objects must share the same set of peers.",
79+
json_schema_extra={"update": "validate_constraint"},
80+
)
7681
order_weight: int | None = Field(
7782
default=None,
7883
description="Number used to order the relationship in the frontend (table and view). Lowest value will be ordered first.",

backend/infrahub/core/schema/schema_branch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,13 @@ def validate_kinds(self) -> None:
997997
raise ValueError(
998998
f"{node.kind}: Relationship {rel.name!r} is referring an invalid peer {rel.peer!r}"
999999
) from None
1000+
if rel.common_relatives:
1001+
peer_schema = self.get(name=rel.peer, duplicate=False)
1002+
for common_relatives_rel_name in rel.common_relatives:
1003+
if common_relatives_rel_name not in peer_schema.relationship_names:
1004+
raise ValueError(
1005+
f"{node.kind}: Relationship {rel.name!r} set 'common_relatives' with invalid relationship from '{rel.peer}'"
1006+
) from None
10001007

10011008
def validate_attribute_parameters(self) -> None:
10021009
for name in self.generics.keys():

backend/infrahub/dependencies/builder/constraint/grouped/node_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..node.grouped_uniqueness import NodeGroupedUniquenessConstraintDependency
55
from ..relationship_manager.count import RelationshipCountConstraintDependency
66
from ..relationship_manager.peer_kind import RelationshipPeerKindConstraintDependency
7+
from ..relationship_manager.peer_relatives import RelationshipPeerRelativesConstraintDependency
78
from ..relationship_manager.profiles_kind import RelationshipProfilesKindConstraintDependency
89

910

@@ -18,5 +19,6 @@ def build(cls, context: DependencyBuilderContext) -> NodeConstraintRunner:
1819
RelationshipPeerKindConstraintDependency.build(context=context),
1920
RelationshipCountConstraintDependency.build(context=context),
2021
RelationshipProfilesKindConstraintDependency.build(context=context),
22+
RelationshipPeerRelativesConstraintDependency.build(context=context),
2123
],
2224
)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from infrahub.core.relationship.constraints.peer_relatives import RelationshipPeerRelativesConstraint
2+
from infrahub.dependencies.interface import DependencyBuilder, DependencyBuilderContext
3+
4+
5+
class RelationshipPeerRelativesConstraintDependency(DependencyBuilder[RelationshipPeerRelativesConstraint]):
6+
@classmethod
7+
def build(cls, context: DependencyBuilderContext) -> RelationshipPeerRelativesConstraint:
8+
return RelationshipPeerRelativesConstraint(db=context.db, branch=context.branch)

backend/infrahub/dependencies/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .builder.constraint.node.uniqueness import NodeAttributeUniquenessConstraintDependency
44
from .builder.constraint.relationship_manager.count import RelationshipCountConstraintDependency
55
from .builder.constraint.relationship_manager.peer_kind import RelationshipPeerKindConstraintDependency
6+
from .builder.constraint.relationship_manager.peer_relatives import RelationshipPeerRelativesConstraintDependency
67
from .builder.constraint.relationship_manager.profiles_kind import RelationshipProfilesKindConstraintDependency
78
from .builder.constraint.schema.aggregated import AggregatedSchemaConstraintsDependency
89
from .builder.constraint.schema.attribute_regex import SchemaAttributeRegexConstraintDependency
@@ -37,6 +38,7 @@ def build_component_registry() -> ComponentDependencyRegistry:
3738
component_registry.track_dependency(RelationshipCountConstraintDependency)
3839
component_registry.track_dependency(RelationshipProfilesKindConstraintDependency)
3940
component_registry.track_dependency(RelationshipPeerKindConstraintDependency)
41+
component_registry.track_dependency(RelationshipPeerRelativesConstraintDependency)
4042
component_registry.track_dependency(NodeConstraintRunnerDependency)
4143
component_registry.track_dependency(NodeDeleteValidatorDependency)
4244
component_registry.track_dependency(IpamKindsGetterDependency)

backend/tests/constants/kind.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
DEVICE = "TestingDevice"
77
INTERFACE = "TestingInterface"
88
INTERFACE_HOLDER = "TestingInterfaceHolder"
9+
LAG_INTERFACE = "TestingLinkAggegrationInterface"
910
LOCATION = "TestingLocation"
1011
MANUFACTURER = "TestingManufacturer"
1112
PERSON = "TestingPerson"

backend/tests/functional/constraints/__init__.py

Whitespace-only changes.
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import annotations
2+
3+
import copy
4+
from typing import TYPE_CHECKING
5+
6+
import pytest
7+
from infrahub_sdk.exceptions import GraphQLError
8+
9+
from infrahub.core.node import Node
10+
from tests.constants import TestKind
11+
from tests.helpers.schema import DEVICE_SCHEMA, load_schema
12+
from tests.helpers.schema.device import LAG_INTERFACE
13+
from tests.helpers.test_app import TestInfrahubApp
14+
15+
if TYPE_CHECKING:
16+
from infrahub_sdk import InfrahubClient
17+
18+
from infrahub.core.branch import Branch
19+
from infrahub.core.schema import SchemaRoot
20+
from infrahub.core.schema.schema_branch import SchemaBranch
21+
from infrahub.database import InfrahubDatabase
22+
23+
24+
class TestPeerRelativesConstraint(TestInfrahubApp):
25+
@pytest.fixture(scope="class", autouse=True)
26+
def schema(self, default_branch: Branch, register_internal_schema: SchemaBranch) -> SchemaRoot:
27+
schema_with_lag = copy.deepcopy(DEVICE_SCHEMA)
28+
schema_with_lag.nodes[0].generate_template = False
29+
schema_with_lag.nodes.append(LAG_INTERFACE)
30+
return schema_with_lag
31+
32+
@pytest.fixture(scope="class")
33+
async def data(
34+
self,
35+
db: InfrahubDatabase,
36+
initialize_registry: None,
37+
client: InfrahubClient,
38+
default_branch: Branch,
39+
schema: SchemaRoot,
40+
) -> dict[str, Node]:
41+
await load_schema(db, schema=schema, update_db=True)
42+
43+
device_1 = await Node.init(db=db, schema=TestKind.DEVICE)
44+
await device_1.new(db=db, name="Foo", manufacturer="Foo Inc.", weight=10, airflow="Front to rear")
45+
46+
interfaces_1: list[Node] = []
47+
interfaces_1_ids: list[str] = []
48+
for if_name in ["et-0/0/0", "et-0/0/1", "et-0/0/2", "et-0/0/3"]:
49+
interface = await Node.init(db=db, schema=TestKind.PHYSICAL_INTERFACE)
50+
await interface.new(db=db, name=if_name, phys_type="QSFP28 (100GE)", device=device_1)
51+
await interface.save(db=db)
52+
interfaces_1.append(interface)
53+
interfaces_1_ids.append(interface.id)
54+
55+
await device_1.interfaces.update(db=db, data=interfaces_1) # type: ignore[attr-defined]
56+
await device_1.save(db=db)
57+
58+
device_2 = await Node.init(db=db, schema=TestKind.DEVICE)
59+
await device_2.new(db=db, name="Bar", manufacturer="Bar Inc.", weight=10, airflow="Front to rear")
60+
61+
interfaces_2: list[Node] = []
62+
interfaces_2_ids: list[str] = []
63+
for if_name in ["et-0/0/0", "et-0/0/1", "et-0/0/2", "et-0/0/3"]:
64+
interface = await Node.init(db=db, schema=TestKind.PHYSICAL_INTERFACE)
65+
await interface.new(db=db, name=if_name, phys_type="QSFP28 (100GE)", device=device_2)
66+
await interface.save(db=db)
67+
interfaces_2.append(interface)
68+
interfaces_2_ids.append(interface.id)
69+
70+
await device_2.interfaces.update(db=db, data=interfaces_2) # type: ignore[attr-defined]
71+
await device_2.save(db=db)
72+
73+
return {"device_1": device_1, "device_2": device_2}
74+
75+
async def test_create_lag_main(
76+
self, db: InfrahubDatabase, data: dict[str, Node], client: InfrahubClient, default_branch: Branch
77+
) -> None:
78+
device = await client.get(kind=TestKind.DEVICE, id=data["device_1"].id, branch=default_branch.name)
79+
await device.interfaces.fetch()
80+
81+
lag = await client.create(
82+
kind=TestKind.LAG_INTERFACE,
83+
name="ae0",
84+
device=device,
85+
members=[i.peer for i in device.interfaces], # type: ignore[attr-defined]
86+
branch=default_branch.name,
87+
)
88+
await lag.save()
89+
90+
assert len(lag.members.peers) == 4
91+
assert sorted(lag.members.peer_ids) == sorted([i.id for i in device.interfaces])
92+
93+
async def test_create_incorrect_lag_main(
94+
self, db: InfrahubDatabase, data: dict[str, Node], client: InfrahubClient, default_branch: Branch
95+
) -> None:
96+
device_1 = await client.get(kind=TestKind.DEVICE, id=data["device_1"].id, branch=default_branch.name)
97+
await device_1.interfaces.fetch()
98+
device_2 = await client.get(kind=TestKind.DEVICE, id=data["device_2"].id, branch=default_branch.name)
99+
await device_2.interfaces.fetch()
100+
101+
lag = await client.create(
102+
kind=TestKind.LAG_INTERFACE,
103+
name="ae1",
104+
device=device_1,
105+
members=[i.peer for i in list(device_1.interfaces)[:-1] + list(device_2.interfaces)], # type: ignore[attr-defined]
106+
branch=default_branch.name,
107+
)
108+
109+
with pytest.raises(GraphQLError) as exc:
110+
await lag.save()
111+
assert (
112+
"must have the same set of peers for their 'TestingPhysicalInterface.device' relationship"
113+
in exc.value.errors[0]["message"]
114+
)
115+
116+
async def test_create_lag_branch(
117+
self, db: InfrahubDatabase, data: dict[str, Node], client: InfrahubClient, default_branch: Branch
118+
) -> None:
119+
branch = await client.branch.create(branch_name="test-lag")
120+
121+
device = await client.get(kind=TestKind.DEVICE, id=data["device_2"].id, branch=branch.name)
122+
await device.interfaces.fetch()
123+
124+
lag = await client.create(
125+
kind=TestKind.LAG_INTERFACE,
126+
name="ae0",
127+
device=device,
128+
members=[i.peer for i in device.interfaces], # type: ignore[attr-defined]
129+
branch=branch.name,
130+
)
131+
await lag.save()
132+
133+
assert len(lag.members.peers) == 4
134+
assert sorted(lag.members.peer_ids) == sorted([i.id for i in device.interfaces])
135+
136+
async def test_create_incorrect_lag_branch(
137+
self, db: InfrahubDatabase, data: dict[str, Node], client: InfrahubClient, default_branch: Branch
138+
) -> None:
139+
branch = await client.branch.create(branch_name="test-lag-incorrect")
140+
141+
device_1 = await client.get(kind=TestKind.DEVICE, id=data["device_1"].id, branch=branch.name)
142+
await device_1.interfaces.fetch()
143+
device_2 = await client.get(kind=TestKind.DEVICE, id=data["device_2"].id, branch=branch.name)
144+
await device_2.interfaces.fetch()
145+
146+
lag = await client.create(
147+
kind=TestKind.LAG_INTERFACE,
148+
name="ae1",
149+
device=device_2,
150+
members=[i.peer for i in list(device_1.interfaces)[:-1] + list(device_2.interfaces)[:-1]], # type: ignore[attr-defined]
151+
branch=branch.name,
152+
)
153+
154+
with pytest.raises(GraphQLError) as exc:
155+
await lag.save()
156+
assert (
157+
"must have the same set of peers for their 'TestingPhysicalInterface.device' relationship"
158+
in exc.value.errors[0]["message"]
159+
)

0 commit comments

Comments
 (0)