Skip to content

Commit 81d2a3b

Browse files
authored
Add support for converting object type (#6343)
1 parent 55ad3bc commit 81d2a3b

File tree

19 files changed

+1267
-190
lines changed

19 files changed

+1267
-190
lines changed

backend/infrahub/core/convert_object_type/__init__.py

Whitespace-only changes.
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from typing import Any
2+
3+
from pydantic import BaseModel
4+
5+
from infrahub.core.attribute import BaseAttribute
6+
from infrahub.core.branch import Branch
7+
from infrahub.core.constants import RelationshipCardinality
8+
from infrahub.core.manager import NodeManager
9+
from infrahub.core.node import Node
10+
from infrahub.core.node.create import create_node
11+
from infrahub.core.query.relationship import GetAllPeersIds
12+
from infrahub.core.relationship import RelationshipManager
13+
from infrahub.core.schema import NodeSchema
14+
from infrahub.database import InfrahubDatabase
15+
16+
17+
class InputDataForDestField(BaseModel): # Only one of these fields can be not None
18+
attribute_value: Any | None = None
19+
peer_id: str | None = None
20+
peers_ids: list[str] | None = None
21+
22+
@property
23+
def value(self) -> Any:
24+
fields = [self.attribute_value, self.peer_id, self.peers_ids]
25+
set_fields = [f for f in fields if f is not None]
26+
if len(set_fields) != 1:
27+
raise ValueError("Exactly one of attribute_value, peer_id, or peers_ids must be set")
28+
return set_fields[0]
29+
30+
31+
class InputForDestField(BaseModel): # Only one of these fields can be not None
32+
source_field: str | None = None
33+
data: InputDataForDestField | None = None
34+
35+
@property
36+
def value(self) -> Any:
37+
if self.source_field is not None and self.data is not None:
38+
raise ValueError("Only one of source_field or data can be set")
39+
if self.source_field is None and self.data is None:
40+
raise ValueError("Either source_field or data must be set")
41+
return self.source_field if self.source_field is not None else self.data
42+
43+
44+
async def get_out_rels_peers_ids(node: Node, db: InfrahubDatabase) -> list[str]:
45+
all_peers: list[Node] = []
46+
for name in node._relationships:
47+
relm: RelationshipManager = getattr(node, name)
48+
peers = await relm.get_peers(db=db)
49+
all_peers.extend(peers.values())
50+
return [peer.id for peer in all_peers]
51+
52+
53+
async def build_data_new_node(db: InfrahubDatabase, mapping: dict[str, InputForDestField], node: Node) -> dict:
54+
"""Value of a given field on the target kind to convert is either an input source attribute/relationship of the source node,
55+
or a raw value."""
56+
57+
data = {}
58+
for dest_field_name, input_for_dest_field in mapping.items():
59+
value = input_for_dest_field.value
60+
if isinstance(value, str): # source_field
61+
item = getattr(node, value)
62+
if isinstance(item, BaseAttribute):
63+
data[dest_field_name] = item.value
64+
elif isinstance(item, RelationshipManager):
65+
if item.schema.cardinality == RelationshipCardinality.ONE:
66+
peer = await item.get_peer(db=db)
67+
if peer is not None:
68+
data[dest_field_name] = {"id": peer.id}
69+
# else, relationship is optional, and if the target relationship is mandatory an error will be raised during creation
70+
elif item.schema.cardinality == RelationshipCardinality.MANY:
71+
data[dest_field_name] = [{"id": peer.id} for _, peer in (await item.get_peers(db=db)).items()]
72+
else:
73+
raise ValueError(f"Unknown cardinality {item.schema.cardinality=}")
74+
else: # user input data
75+
data[dest_field_name] = value.value
76+
return data
77+
78+
79+
async def get_unidirectional_rels_peers_ids(node: Node, branch: Branch, db: InfrahubDatabase) -> list[str]:
80+
"""
81+
Returns peers ids of nodes connected to input `node` through an incoming unidirectional relationship.
82+
"""
83+
84+
out_rels_identifier = [rel.identifier for rel in node.get_schema().relationships]
85+
query = await GetAllPeersIds.init(db=db, node_id=node.id, branch=branch, exclude_identifiers=out_rels_identifier)
86+
await query.execute(db=db)
87+
return query.get_peers_uuids()
88+
89+
90+
async def convert_object_type(
91+
node: Node, target_schema: NodeSchema, mapping: dict[str, InputForDestField], branch: Branch, db: InfrahubDatabase
92+
) -> Node:
93+
"""Delete the node and return the new created one. If creation fails, the node is not deleted, and raise an error.
94+
An extra check is performed on input node peers relationships to make sure they are still valid."""
95+
96+
node_schema = node.get_schema()
97+
if not isinstance(node_schema, NodeSchema):
98+
raise ValueError(f"Only a node with a NodeSchema can be converted, got {type(node_schema)}")
99+
100+
async with db.start_transaction() as dbt: # noqa: PLR1702
101+
deleted_node_out_rels_peer_ids = await get_out_rels_peers_ids(node=node, db=dbt)
102+
deleted_node_unidir_rels_peer_ids = await get_unidirectional_rels_peers_ids(node=node, db=dbt, branch=branch)
103+
104+
deleted_nodes = await NodeManager.delete(db=dbt, branch=branch, nodes=[node], cascade_delete=False)
105+
if len(deleted_nodes) != 1:
106+
raise ValueError(f"Deleted {len(deleted_nodes)} nodes instead of 1")
107+
108+
data_new_node = await build_data_new_node(dbt, mapping, node)
109+
node_created = await create_node(
110+
data=data_new_node,
111+
db=dbt,
112+
branch=branch,
113+
schema=target_schema,
114+
)
115+
116+
# Make sure relationships with constraints are not broken by retrieving them
117+
peers_ids = deleted_node_out_rels_peer_ids + deleted_node_unidir_rels_peer_ids
118+
peers = await NodeManager.get_many(ids=peers_ids, db=dbt, prefetch_relationships=True, branch=branch)
119+
for peer in peers.values():
120+
peer.validate_relationships()
121+
122+
return node_created
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from pydantic import BaseModel
2+
3+
from infrahub.core.constants import RelationshipCardinality
4+
from infrahub.core.schema import NodeSchema
5+
6+
7+
class SchemaMappingValue(BaseModel):
8+
is_mandatory: bool
9+
source_field_name: str | None = None # None means there is no corresponding source field name
10+
relationship_cardinality: RelationshipCardinality | None = None
11+
12+
13+
SchemaMapping = dict[str, SchemaMappingValue]
14+
15+
16+
def get_schema_mapping(source_schema: NodeSchema, target_schema: NodeSchema) -> SchemaMapping:
17+
"""
18+
Return fields mapping meant to be used for converting a node from `source_kind` to `target_kind`.
19+
For any field of the target kind, field of the source kind will be matched if:
20+
- It's an attribute with identical name and type.
21+
- It's a relationship with identical peer kind and cardinality.
22+
If there is no match, the mapping will only indicate whether the field is mandatory or not.
23+
"""
24+
25+
target_field_to_source_field = {}
26+
27+
# Create lookup dictionaries for source attributes and relationships
28+
source_attrs = {attr.name: attr for attr in source_schema.attributes}
29+
source_rels = {rel.name: rel for rel in source_schema.relationships}
30+
31+
# Process attributes
32+
for target_attr in target_schema.attributes:
33+
source_attr = source_attrs.get(target_attr.name)
34+
if source_attr and source_attr.kind == target_attr.kind:
35+
target_field_to_source_field[target_attr.name] = SchemaMappingValue(
36+
source_field_name=source_attr.name, is_mandatory=not target_attr.optional
37+
)
38+
else:
39+
target_field_to_source_field[target_attr.name] = SchemaMappingValue(is_mandatory=not target_attr.optional)
40+
41+
# Process relationships
42+
for target_rel in target_schema.relationships:
43+
source_rel = source_rels.get(target_rel.name)
44+
if source_rel and source_rel.peer == target_rel.peer and source_rel.cardinality == target_rel.cardinality:
45+
target_field_to_source_field[target_rel.name] = SchemaMappingValue(
46+
source_field_name=source_rel.name,
47+
is_mandatory=not target_rel.optional,
48+
relationship_cardinality=target_rel.cardinality,
49+
)
50+
else:
51+
target_field_to_source_field[target_rel.name] = SchemaMappingValue(
52+
is_mandatory=not target_rel.optional,
53+
relationship_cardinality=target_rel.cardinality,
54+
)
55+
56+
return target_field_to_source_field

backend/infrahub/core/manager.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from copy import copy
34
from functools import reduce
45
from typing import TYPE_CHECKING, Any, Iterable, Literal, TypeVar, overload
56

@@ -1339,22 +1340,24 @@ async def delete(
13391340
nodes: list[Node],
13401341
branch: Branch | str | None = None,
13411342
at: Timestamp | str | None = None,
1343+
cascade_delete: bool = True,
13421344
) -> list[Node]:
13431345
"""Returns list of deleted nodes because of cascading deletes"""
13441346
branch = await registry.get_branch(branch=branch, db=db)
1345-
node_delete_validator = NodeDeleteValidator(db=db, branch=branch)
1346-
ids_to_delete = await node_delete_validator.get_ids_to_delete(nodes=nodes, at=at)
1347-
node_ids = {node.get_id() for node in nodes}
1348-
missing_ids_to_delete = ids_to_delete - node_ids
1349-
if missing_ids_to_delete:
1350-
node_map = await cls.get_many(db=db, ids=list(missing_ids_to_delete), branch=branch, at=at)
1351-
nodes += list(node_map.values())
1352-
deleted_nodes = []
1353-
for node in nodes:
1347+
nodes_to_delete = copy(nodes)
1348+
if cascade_delete:
1349+
node_delete_validator = NodeDeleteValidator(db=db, branch=branch)
1350+
ids_to_delete = await node_delete_validator.get_ids_to_delete(nodes=nodes, at=at)
1351+
node_ids = {node.get_id() for node in nodes}
1352+
missing_ids_to_delete = ids_to_delete - node_ids
1353+
if missing_ids_to_delete:
1354+
node_map = await cls.get_many(db=db, ids=list(missing_ids_to_delete), branch=branch, at=at)
1355+
nodes_to_delete += list(node_map.values())
1356+
1357+
for node in nodes_to_delete:
13541358
await node.delete(db=db, at=at)
1355-
deleted_nodes.append(node)
13561359

1357-
return deleted_nodes
1360+
return nodes_to_delete
13581361

13591362

13601363
def _get_cardinality_one_identifiers_by_kind(

backend/infrahub/core/node/__init__.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@
2222
from infrahub.core.constants.schema import SchemaElementPathType
2323
from infrahub.core.protocols import CoreNumberPool, CoreObjectTemplate
2424
from infrahub.core.query.node import NodeCheckIDQuery, NodeCreateAllQuery, NodeDeleteQuery, NodeGetListQuery
25-
from infrahub.core.schema import AttributeSchema, NodeSchema, ProfileSchema, RelationshipSchema, TemplateSchema
25+
from infrahub.core.schema import (
26+
AttributeSchema,
27+
NodeSchema,
28+
NonGenericSchemaTypes,
29+
ProfileSchema,
30+
RelationshipSchema,
31+
TemplateSchema,
32+
)
2633
from infrahub.core.timestamp import Timestamp
2734
from infrahub.exceptions import InitializationError, NodeNotFoundError, PoolExhaustedError, ValidationError
2835
from infrahub.types import ATTRIBUTE_TYPES
@@ -66,7 +73,7 @@ def __init_subclass_with_meta__(cls, _meta=None, default_filter=None, **options)
6673
_meta.default_filter = default_filter
6774
super().__init_subclass_with_meta__(_meta=_meta, **options)
6875

69-
def get_schema(self) -> NodeSchema | ProfileSchema | TemplateSchema:
76+
def get_schema(self) -> NonGenericSchemaTypes:
7077
return self._schema
7178

7279
def get_kind(self) -> str:
@@ -872,7 +879,7 @@ def _get_parent_relationship_name(self) -> str | None:
872879
if relationship.kind == RelationshipKind.PARENT:
873880
return relationship.name
874881

875-
async def get_object_template(self, db: InfrahubDatabase) -> Node | None:
882+
async def get_object_template(self, db: InfrahubDatabase) -> CoreObjectTemplate | None:
876883
object_template: RelationshipManager = getattr(self, OBJECT_TEMPLATE_RELATIONSHIP_NAME, None)
877884
return await object_template.get_peer(db=db) if object_template is not None else None
878885

@@ -888,3 +895,8 @@ def get_relationships(
888895
for relationship in self.get_schema().relationships
889896
if relationship.name not in exclude and relationship.kind == kind
890897
]
898+
899+
def validate_relationships(self) -> None:
900+
for name in self._relationships:
901+
relm: RelationshipManager = getattr(self, name)
902+
relm.validate()

0 commit comments

Comments
 (0)