Skip to content

Commit d7da083

Browse files
authored
add uniqueness constraints for SchemaAttribute/Relationship (#7348)
* add uniqueness constraints for SchemaAttribute/Relationship * add generics back into TestingPerson schema * CLI command to clean up duplicated attrs and rels on schemas on main * add changelog * fix changelog * update docs * test cleanup
1 parent 38b9f96 commit d7da083

File tree

9 files changed

+740
-3
lines changed

9 files changed

+740
-3
lines changed

backend/infrahub/cli/db.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
5656
from .db_commands.check_inheritance import check_inheritance
57+
from .db_commands.clean_duplicate_schema_fields import clean_duplicate_schema_fields
5758
from .patch import patch_app
5859

5960

@@ -200,6 +201,29 @@ async def check_inheritance_cmd(
200201
await dbdriver.close()
201202

202203

204+
@app.command(name="check-duplicate-schema-fields")
205+
async def check_duplicate_schema_fields_cmd(
206+
ctx: typer.Context,
207+
fix: bool = typer.Option(False, help="Fix the duplicate schema fields on the default branch."),
208+
config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
209+
) -> None:
210+
"""Check for any duplicate schema attributes or relationships on the default branch"""
211+
logging.getLogger("infrahub").setLevel(logging.WARNING)
212+
logging.getLogger("neo4j").setLevel(logging.ERROR)
213+
logging.getLogger("prefect").setLevel(logging.ERROR)
214+
215+
config.load_and_exit(config_file_name=config_file)
216+
217+
context: CliContext = ctx.obj
218+
dbdriver = await context.init_db(retry=1)
219+
220+
success = await clean_duplicate_schema_fields(db=dbdriver, fix=fix)
221+
if not success:
222+
raise typer.Exit(code=1)
223+
224+
await dbdriver.close()
225+
226+
203227
@app.command(name="update-core-schema")
204228
async def update_core_schema_cmd(
205229
ctx: typer.Context,
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import Any
4+
5+
from rich import print as rprint
6+
from rich.console import Console
7+
from rich.table import Table
8+
9+
from infrahub.cli.constants import FAILED_BADGE, SUCCESS_BADGE
10+
from infrahub.core.query import Query, QueryType
11+
from infrahub.database import InfrahubDatabase
12+
13+
14+
class SchemaFieldType(str, Enum):
15+
ATTRIBUTE = "attribute"
16+
RELATIONSHIP = "relationship"
17+
18+
19+
@dataclass
20+
class SchemaFieldDetails:
21+
schema_kind: str
22+
schema_uuid: str
23+
field_type: SchemaFieldType
24+
field_name: str
25+
26+
27+
class DuplicateSchemaFields(Query):
28+
async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
29+
query = """
30+
MATCH (root:Root)
31+
LIMIT 1
32+
WITH root.default_branch AS default_branch
33+
MATCH (field:SchemaAttribute|SchemaRelationship)
34+
CALL (default_branch, field) {
35+
MATCH (field)-[is_part_of:IS_PART_OF]->(:Root)
36+
WHERE is_part_of.branch = default_branch
37+
ORDER BY is_part_of.from DESC
38+
RETURN is_part_of
39+
LIMIT 1
40+
}
41+
WITH default_branch, field, CASE
42+
WHEN is_part_of.status = "active" AND is_part_of.to IS NULL THEN is_part_of.from
43+
ELSE NULL
44+
END AS active_from
45+
WHERE active_from IS NOT NULL
46+
WITH default_branch, field, active_from, "SchemaAttribute" IN labels(field) AS is_attribute
47+
CALL (field, default_branch) {
48+
MATCH (field)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "name"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
49+
WHERE r1.branch = default_branch AND r2.branch = default_branch
50+
AND r1.status = "active" AND r2.status = "active"
51+
AND r1.to IS NULL AND r2.to IS NULL
52+
ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
53+
LIMIT 1
54+
RETURN name_value.value AS field_name
55+
}
56+
CALL (field, default_branch) {
57+
MATCH (field)-[r1:IS_RELATED]-(rel:Relationship)-[r2:IS_RELATED]-(peer:SchemaNode|SchemaGeneric)
58+
WHERE rel.name IN ["schema__node__relationships", "schema__node__attributes"]
59+
AND r1.branch = default_branch AND r2.branch = default_branch
60+
AND r1.status = "active" AND r2.status = "active"
61+
AND r1.to IS NULL AND r2.to IS NULL
62+
ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
63+
LIMIT 1
64+
RETURN peer AS schema_vertex
65+
}
66+
WITH default_branch, field, field_name, is_attribute, active_from, schema_vertex
67+
ORDER BY active_from DESC
68+
WITH default_branch, field_name, is_attribute, schema_vertex, collect(field) AS fields_reverse_chron
69+
WHERE size(fields_reverse_chron) > 1
70+
"""
71+
self.add_to_query(query)
72+
73+
74+
class GetDuplicateSchemaFields(DuplicateSchemaFields):
75+
"""
76+
Get the kind, field type, and field name for any duplicated attributes or relationships on a given schema
77+
on the default branch
78+
"""
79+
80+
name = "get_duplicate_schema_fields"
81+
type = QueryType.READ
82+
insert_return = False
83+
84+
async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:
85+
await super().query_init(db=db, **kwargs)
86+
query = """
87+
CALL (schema_vertex, default_branch) {
88+
MATCH (schema_vertex)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "namespace"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
89+
WHERE r1.branch = default_branch AND r2.branch = default_branch
90+
ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
91+
LIMIT 1
92+
RETURN name_value.value AS schema_namespace
93+
}
94+
CALL (schema_vertex, default_branch) {
95+
MATCH (schema_vertex)-[r1:HAS_ATTRIBUTE]->(:Attribute {name: "name"})-[r2:HAS_VALUE]->(name_value:AttributeValue)
96+
WHERE r1.branch = default_branch AND r2.branch = default_branch
97+
ORDER BY r1.from DESC, r1.status ASC, r2.from DESC, r2.status ASC
98+
LIMIT 1
99+
RETURN name_value.value AS schema_name
100+
}
101+
RETURN schema_namespace + schema_name AS schema_kind, schema_vertex.uuid AS schema_uuid, field_name, is_attribute
102+
ORDER BY schema_kind ASC, is_attribute DESC, field_name ASC
103+
"""
104+
self.return_labels = ["schema_kind", "schema_uuid", "field_name", "is_attribute"]
105+
self.add_to_query(query)
106+
107+
def get_schema_field_details(self) -> list[SchemaFieldDetails]:
108+
schema_field_details: list[SchemaFieldDetails] = []
109+
for result in self.results:
110+
schema_kind = result.get_as_type(label="schema_kind", return_type=str)
111+
schema_uuid = result.get_as_type(label="schema_uuid", return_type=str)
112+
field_name = result.get_as_type(label="field_name", return_type=str)
113+
is_attribute = result.get_as_type(label="is_attribute", return_type=bool)
114+
schema_field_details.append(
115+
SchemaFieldDetails(
116+
schema_kind=schema_kind,
117+
schema_uuid=schema_uuid,
118+
field_name=field_name,
119+
field_type=SchemaFieldType.ATTRIBUTE if is_attribute else SchemaFieldType.RELATIONSHIP,
120+
)
121+
)
122+
return schema_field_details
123+
124+
125+
class FixDuplicateSchemaFields(DuplicateSchemaFields):
126+
"""
127+
Fix the duplicate schema fields by hard deleting the earlier duplicate(s)
128+
"""
129+
130+
name = "fix_duplicate_schema_fields"
131+
type = QueryType.WRITE
132+
insert_return = False
133+
134+
async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:
135+
await super().query_init(db=db, **kwargs)
136+
query = """
137+
WITH default_branch, tail(fields_reverse_chron) AS fields_to_delete
138+
UNWIND fields_to_delete AS field_to_delete
139+
CALL (field_to_delete, default_branch) {
140+
MATCH (field_to_delete)-[r:IS_PART_OF {branch: default_branch}]-()
141+
DELETE r
142+
WITH field_to_delete
143+
MATCH (field_to_delete)-[:IS_RELATED {branch: default_branch}]-(rel:Relationship)
144+
WITH DISTINCT field_to_delete, rel
145+
MATCH (rel)-[r {branch: default_branch}]-()
146+
DELETE r
147+
WITH field_to_delete, rel
148+
OPTIONAL MATCH (rel)
149+
WHERE NOT exists((rel)--())
150+
DELETE rel
151+
WITH DISTINCT field_to_delete
152+
MATCH (field_to_delete)-[:HAS_ATTRIBUTE {branch: default_branch}]->(attr:Attribute)
153+
MATCH (attr)-[r {branch: default_branch}]-()
154+
DELETE r
155+
WITH field_to_delete, attr
156+
OPTIONAL MATCH (attr)
157+
WHERE NOT exists((attr)--())
158+
DELETE attr
159+
WITH DISTINCT field_to_delete
160+
OPTIONAL MATCH (field_to_delete)
161+
WHERE NOT exists((field_to_delete)--())
162+
DELETE field_to_delete
163+
}
164+
"""
165+
self.add_to_query(query)
166+
167+
168+
def display_duplicate_schema_fields(duplicate_schema_fields: list[SchemaFieldDetails]) -> None:
169+
console = Console()
170+
171+
table = Table(title="Duplicate Schema Fields on Default Branch")
172+
173+
table.add_column("Schema Kind")
174+
table.add_column("Schema UUID")
175+
table.add_column("Field Name")
176+
table.add_column("Field Type")
177+
178+
for duplicate_schema_field in duplicate_schema_fields:
179+
table.add_row(
180+
duplicate_schema_field.schema_kind,
181+
duplicate_schema_field.schema_uuid,
182+
duplicate_schema_field.field_name,
183+
duplicate_schema_field.field_type.value,
184+
)
185+
186+
console.print(table)
187+
188+
189+
async def clean_duplicate_schema_fields(db: InfrahubDatabase, fix: bool = False) -> bool:
190+
"""
191+
Identify any attributes or relationships that are duplicated in a schema on the default branch
192+
If fix is True, runs cypher queries to hard delete the earlier duplicate
193+
"""
194+
195+
duplicate_schema_fields_query = await GetDuplicateSchemaFields.init(db=db)
196+
await duplicate_schema_fields_query.execute(db=db)
197+
duplicate_schema_fields = duplicate_schema_fields_query.get_schema_field_details()
198+
199+
if not duplicate_schema_fields:
200+
rprint(f"{SUCCESS_BADGE} No duplicate schema fields found")
201+
return True
202+
203+
display_duplicate_schema_fields(duplicate_schema_fields)
204+
205+
if not fix:
206+
rprint(f"{FAILED_BADGE} Use the --fix flag to fix the duplicate schema fields")
207+
return False
208+
209+
fix_duplicate_schema_fields_query = await FixDuplicateSchemaFields.init(db=db)
210+
await fix_duplicate_schema_fields_query.execute(db=db)
211+
rprint(f"{SUCCESS_BADGE} Duplicate schema fields deleted from the default branch")
212+
return True

backend/infrahub/core/schema/definitions/internal.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ class SchemaNode(BaseModel):
180180
attributes: list[SchemaAttribute]
181181
relationships: list[SchemaRelationship]
182182
display_labels: list[str]
183+
uniqueness_constraints: list[list[str]] | None = None
183184

184185
def to_dict(self) -> dict[str, Any]:
185186
return {
@@ -195,6 +196,7 @@ def to_dict(self) -> dict[str, Any]:
195196
],
196197
"relationships": [relationship.to_dict() for relationship in self.relationships],
197198
"display_labels": self.display_labels,
199+
"uniqueness_constraints": self.uniqueness_constraints,
198200
}
199201

200202
def without_duplicates(self, other: SchemaNode) -> SchemaNode:
@@ -465,6 +467,7 @@ def to_dict(self) -> dict[str, Any]:
465467
include_in_menu=False,
466468
default_filter=None,
467469
display_labels=["name__value"],
470+
uniqueness_constraints=[["name__value", "node"]],
468471
attributes=[
469472
SchemaAttribute(
470473
name="id",
@@ -669,6 +672,7 @@ def to_dict(self) -> dict[str, Any]:
669672
include_in_menu=False,
670673
default_filter=None,
671674
display_labels=["name__value"],
675+
uniqueness_constraints=[["name__value", "node"]],
672676
attributes=[
673677
SchemaAttribute(
674678
name="id",

backend/infrahub/core/validators/determiner.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from infrahub.core.schema.relationship_schema import RelationshipSchema
1111
from infrahub.core.schema.schema_branch import SchemaBranch
1212
from infrahub.core.validators import CONSTRAINT_VALIDATOR_MAP
13+
from infrahub.exceptions import SchemaNotFoundError
1314
from infrahub.log import get_logger
1415

1516
if TYPE_CHECKING:
@@ -81,7 +82,17 @@ async def _get_constraints_for_one_schema(self, schema: MainSchemaTypes) -> list
8182

8283
async def _get_all_property_constraints(self) -> list[SchemaUpdateConstraintInfo]:
8384
constraints: list[SchemaUpdateConstraintInfo] = []
84-
for schema in self.schema_branch.get_all().values():
85+
schemas = list(self.schema_branch.get_all(duplicate=False).values())
86+
# added here to check their uniqueness constraints
87+
try:
88+
schemas.append(self.schema_branch.get_node(name="SchemaAttribute", duplicate=False))
89+
except SchemaNotFoundError:
90+
pass
91+
try:
92+
schemas.append(self.schema_branch.get_node(name="SchemaRelationship", duplicate=False))
93+
except SchemaNotFoundError:
94+
pass
95+
for schema in schemas:
8596
constraints.extend(await self._get_property_constraints_for_one_schema(schema=schema))
8697
return constraints
8798

0 commit comments

Comments
 (0)