|
| 1 | +from typing import Any |
| 2 | + |
| 3 | +from pydantic import BaseModel |
| 4 | + |
| 5 | +from infrahub.core.attribute import BaseAttribute |
| 6 | +from infrahub.core.branch import Branch |
| 7 | +from infrahub.core.constants import RelationshipCardinality |
| 8 | +from infrahub.core.manager import NodeManager |
| 9 | +from infrahub.core.node import Node |
| 10 | +from infrahub.core.node.create import create_node |
| 11 | +from infrahub.core.query.relationship import GetAllPeersIds |
| 12 | +from infrahub.core.relationship import RelationshipManager |
| 13 | +from infrahub.core.schema import NodeSchema |
| 14 | +from infrahub.database import InfrahubDatabase |
| 15 | + |
| 16 | + |
| 17 | +class InputDataForDestField(BaseModel): # Only one of these fields can be not None |
| 18 | + attribute_value: Any | None = None |
| 19 | + peer_id: str | None = None |
| 20 | + peers_ids: list[str] | None = None |
| 21 | + |
| 22 | + @property |
| 23 | + def value(self) -> Any: |
| 24 | + fields = [self.attribute_value, self.peer_id, self.peers_ids] |
| 25 | + set_fields = [f for f in fields if f is not None] |
| 26 | + if len(set_fields) != 1: |
| 27 | + raise ValueError("Exactly one of attribute_value, peer_id, or peers_ids must be set") |
| 28 | + return set_fields[0] |
| 29 | + |
| 30 | + |
| 31 | +class InputForDestField(BaseModel): # Only one of these fields can be not None |
| 32 | + source_field: str | None = None |
| 33 | + data: InputDataForDestField | None = None |
| 34 | + |
| 35 | + @property |
| 36 | + def value(self) -> Any: |
| 37 | + if self.source_field is not None and self.data is not None: |
| 38 | + raise ValueError("Only one of source_field or data can be set") |
| 39 | + if self.source_field is None and self.data is None: |
| 40 | + raise ValueError("Either source_field or data must be set") |
| 41 | + return self.source_field if self.source_field is not None else self.data |
| 42 | + |
| 43 | + |
| 44 | +async def get_out_rels_peers_ids(node: Node, db: InfrahubDatabase) -> list[str]: |
| 45 | + all_peers: list[Node] = [] |
| 46 | + for name in node._relationships: |
| 47 | + relm: RelationshipManager = getattr(node, name) |
| 48 | + peers = await relm.get_peers(db=db) |
| 49 | + all_peers.extend(peers.values()) |
| 50 | + return [peer.id for peer in all_peers] |
| 51 | + |
| 52 | + |
| 53 | +async def build_data_new_node(db: InfrahubDatabase, mapping: dict[str, InputForDestField], node: Node) -> dict: |
| 54 | + """Value of a given field on the target kind to convert is either an input source attribute/relationship of the source node, |
| 55 | + or a raw value.""" |
| 56 | + |
| 57 | + data = {} |
| 58 | + for dest_field_name, input_for_dest_field in mapping.items(): |
| 59 | + value = input_for_dest_field.value |
| 60 | + if isinstance(value, str): # source_field |
| 61 | + item = getattr(node, value) |
| 62 | + if isinstance(item, BaseAttribute): |
| 63 | + data[dest_field_name] = item.value |
| 64 | + elif isinstance(item, RelationshipManager): |
| 65 | + if item.schema.cardinality == RelationshipCardinality.ONE: |
| 66 | + peer = await item.get_peer(db=db) |
| 67 | + if peer is not None: |
| 68 | + data[dest_field_name] = {"id": peer.id} |
| 69 | + # else, relationship is optional, and if the target relationship is mandatory an error will be raised during creation |
| 70 | + elif item.schema.cardinality == RelationshipCardinality.MANY: |
| 71 | + data[dest_field_name] = [{"id": peer.id} for _, peer in (await item.get_peers(db=db)).items()] |
| 72 | + else: |
| 73 | + raise ValueError(f"Unknown cardinality {item.schema.cardinality=}") |
| 74 | + else: # user input data |
| 75 | + data[dest_field_name] = value.value |
| 76 | + return data |
| 77 | + |
| 78 | + |
| 79 | +async def get_unidirectional_rels_peers_ids(node: Node, branch: Branch, db: InfrahubDatabase) -> list[str]: |
| 80 | + """ |
| 81 | + Returns peers ids of nodes connected to input `node` through an incoming unidirectional relationship. |
| 82 | + """ |
| 83 | + |
| 84 | + out_rels_identifier = [rel.identifier for rel in node.get_schema().relationships] |
| 85 | + query = await GetAllPeersIds.init(db=db, node_id=node.id, branch=branch, exclude_identifiers=out_rels_identifier) |
| 86 | + await query.execute(db=db) |
| 87 | + return query.get_peers_uuids() |
| 88 | + |
| 89 | + |
| 90 | +async def convert_object_type( |
| 91 | + node: Node, target_schema: NodeSchema, mapping: dict[str, InputForDestField], branch: Branch, db: InfrahubDatabase |
| 92 | +) -> Node: |
| 93 | + """Delete the node and return the new created one. If creation fails, the node is not deleted, and raise an error. |
| 94 | + An extra check is performed on input node peers relationships to make sure they are still valid.""" |
| 95 | + |
| 96 | + node_schema = node.get_schema() |
| 97 | + if not isinstance(node_schema, NodeSchema): |
| 98 | + raise ValueError(f"Only a node with a NodeSchema can be converted, got {type(node_schema)}") |
| 99 | + |
| 100 | + async with db.start_transaction() as dbt: # noqa: PLR1702 |
| 101 | + deleted_node_out_rels_peer_ids = await get_out_rels_peers_ids(node=node, db=dbt) |
| 102 | + deleted_node_unidir_rels_peer_ids = await get_unidirectional_rels_peers_ids(node=node, db=dbt, branch=branch) |
| 103 | + |
| 104 | + deleted_nodes = await NodeManager.delete(db=dbt, branch=branch, nodes=[node], cascade_delete=False) |
| 105 | + if len(deleted_nodes) != 1: |
| 106 | + raise ValueError(f"Deleted {len(deleted_nodes)} nodes instead of 1") |
| 107 | + |
| 108 | + data_new_node = await build_data_new_node(dbt, mapping, node) |
| 109 | + node_created = await create_node( |
| 110 | + data=data_new_node, |
| 111 | + db=dbt, |
| 112 | + branch=branch, |
| 113 | + schema=target_schema, |
| 114 | + ) |
| 115 | + |
| 116 | + # Make sure relationships with constraints are not broken by retrieving them |
| 117 | + peers_ids = deleted_node_out_rels_peer_ids + deleted_node_unidir_rels_peer_ids |
| 118 | + peers = await NodeManager.get_many(ids=peers_ids, db=dbt, prefetch_relationships=True, branch=branch) |
| 119 | + for peer in peers.values(): |
| 120 | + peer.validate_relationships() |
| 121 | + |
| 122 | + return node_created |
0 commit comments