Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 24 additions & 58 deletions backend/infrahub/core/convert_object_type/object_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Self, assert_never
from typing import Any, assert_never

from pydantic import BaseModel, model_validator
from infrahub_sdk.convert_object_type import ConversionFieldInput, ConversionFieldValue

from infrahub.core.attribute import BaseAttribute
from infrahub.core.branch import Branch
Expand All @@ -20,52 +20,14 @@
from infrahub.workers.dependencies import get_message_bus


class InputDataForDestField(BaseModel): # Only one of these fields can be not None
attribute_value: Any | None = None
peer_id: str | None = None
peers_ids: list[str] | None = None

@model_validator(mode="after")
def check_only_one_field(self) -> Self:
fields = [self.attribute_value, self.peer_id, self.peers_ids]
set_fields = [f for f in fields if f is not None]
if len(set_fields) != 1:
raise ValueError("Exactly one of attribute_value, peer_id, or peers_ids must be set")
return self

@property
def value(self) -> Any:
if self.attribute_value is not None:
return self.attribute_value
if self.peer_id is not None:
return self.peer_id
if self.peers_ids is not None:
return self.peers_ids

raise ValueError(
"Exactly one of attribute_value, peer_id, or peers_ids must be set, model has not been validated correctly."
)


class InputForDestField(BaseModel): # Only one of these fields can be not None
source_field: str | None = None
data: InputDataForDestField | None = None

@model_validator(mode="after")
def check_only_one_field(self) -> Self:
if self.source_field is not None and self.data is not None:
raise ValueError("Only one of source_field or data can be set")
if self.source_field is None and self.data is None:
raise ValueError("Either source_field or data must be set")
return self

@property
def value(self) -> Any:
if self.source_field is not None:
return self.source_field
if self.data is not None:
return self.data
raise ValueError("Either source_field or data must be set, model has not been validated correctly.")
def _get_conversion_field_raw_value(conv_field_value: ConversionFieldValue) -> Any:
if conv_field_value.attribute_value is not None:
return conv_field_value.attribute_value
if conv_field_value.peer_id is not None:
return conv_field_value.peer_id
if conv_field_value.peers_ids is not None:
return conv_field_value.peers_ids
raise ValueError("ConversionFieldValue has not been validated correctly.")
Comment on lines +23 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add docstring for _get_conversion_field_raw_value.
New helpers must ship with Google-style docstrings per our Python guidelines.

Apply this diff:

-def _get_conversion_field_raw_value(conv_field_value: ConversionFieldValue) -> Any:
+def _get_conversion_field_raw_value(conv_field_value: ConversionFieldValue) -> Any:
+    """Extract the raw payload carried by a ConversionFieldValue.
+
+    Args:
+        conv_field_value: Conversion data for the destination field.
+
+    Returns:
+        The attribute value, peer identifier, or list of peer identifiers.
+
+    Raises:
+        ValueError: If the conversion input does not define any payload.
+    """
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _get_conversion_field_raw_value(conv_field_value: ConversionFieldValue) -> Any:
if conv_field_value.attribute_value is not None:
return conv_field_value.attribute_value
if conv_field_value.peer_id is not None:
return conv_field_value.peer_id
if conv_field_value.peers_ids is not None:
return conv_field_value.peers_ids
raise ValueError("ConversionFieldValue has not been validated correctly.")
def _get_conversion_field_raw_value(conv_field_value: ConversionFieldValue) -> Any:
"""Extract the raw payload carried by a ConversionFieldValue.
Args:
conv_field_value: Conversion data for the destination field.
Returns:
The attribute value, peer identifier, or list of peer identifiers.
Raises:
ValueError: If the conversion input does not define any payload.
"""
if conv_field_value.attribute_value is not None:
return conv_field_value.attribute_value
if conv_field_value.peer_id is not None:
return conv_field_value.peer_id
if conv_field_value.peers_ids is not None:
return conv_field_value.peers_ids
raise ValueError("ConversionFieldValue has not been validated correctly.")
🤖 Prompt for AI Agents
In backend/infrahub/core/convert_object_type/object_conversion.py around lines
23 to 30, the helper function _get_conversion_field_raw_value lacks a
Google-style docstring; add a docstring immediately below the def that documents
the function purpose, the conv_field_value parameter (type
ConversionFieldValue), the return value (Any) describing possible returned
fields (attribute_value, peer_id, peers_ids), and the ValueError raised when
none are set, following our Google-style format (Args, Returns, Raises) and
keeping it concise.



async def get_out_rels_peers_ids(node: Node, db: InfrahubDatabase, at: Timestamp) -> list[str]:
Expand All @@ -77,15 +39,15 @@ async def get_out_rels_peers_ids(node: Node, db: InfrahubDatabase, at: Timestamp
return all_peers_ids


async def build_data_new_node(db: InfrahubDatabase, mapping: dict[str, InputForDestField], node: Node) -> dict:
async def build_data_new_node(db: InfrahubDatabase, mapping: dict[str, ConversionFieldInput], node: Node) -> dict:
"""Value of a given field on the target kind to convert is either an input source attribute/relationship of the source node,
or a raw value."""

data = {}
for dest_field_name, input_for_dest_field in mapping.items():
value = input_for_dest_field.value
if isinstance(value, str): # source_field
item = getattr(node, value)
for dest_field_name, conv_field_input in mapping.items():
if conv_field_input.source_field is not None:
# Fetch the value of the corresponding field from the node being converted.
item = getattr(node, conv_field_input.source_field)
if isinstance(item, BaseAttribute):
data[dest_field_name] = item.value
elif isinstance(item, RelationshipManager):
Expand All @@ -98,8 +60,12 @@ async def build_data_new_node(db: InfrahubDatabase, mapping: dict[str, InputForD
data[dest_field_name] = [{"id": peer.id} for _, peer in (await item.get_peers(db=db)).items()]
else:
assert_never(item.schema.cardinality)
else: # user input data
data[dest_field_name] = value.value
elif conv_field_input.data is not None:
data[dest_field_name] = _get_conversion_field_raw_value(conv_field_input.data)
elif conv_field_input.use_default_value is True:
pass # default value will be used automatically when creating the node
else:
raise ValueError("ConversionFieldInput has not been validated correctly.")
return data


Expand Down Expand Up @@ -129,7 +95,7 @@ async def _get_other_active_branches(db: InfrahubDatabase) -> list[Branch]:
return [branch for branch in branches if not (branch.is_global or branch.is_default)]


def _has_pass_thru_aware_attributes(node_schema: NodeSchema, mapping: dict[str, InputForDestField]) -> bool:
def _has_pass_thru_aware_attributes(node_schema: NodeSchema, mapping: dict[str, ConversionFieldInput]) -> bool:
aware_attributes = [attr for attr in node_schema.attributes if attr.branch != BranchSupportType.AGNOSTIC]
aware_attributes_pass_thru = [
attr.name for attr in aware_attributes if attr.name in mapping and mapping[attr.name].source_field is not None
Expand Down Expand Up @@ -157,7 +123,7 @@ async def validate_conversion(
async def convert_and_validate_object_type(
node: Node,
target_schema: NodeSchema,
mapping: dict[str, InputForDestField],
mapping: dict[str, ConversionFieldInput],
branch: Branch,
db: InfrahubDatabase,
) -> Node:
Expand All @@ -180,7 +146,7 @@ async def convert_and_validate_object_type(
async def convert_object_type(
node: Node,
target_schema: NodeSchema,
mapping: dict[str, InputForDestField],
mapping: dict[str, ConversionFieldInput],
branch: Branch,
db: InfrahubDatabase,
) -> Node:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from infrahub.core.branch import Branch
from infrahub.core.constants.infrahubkind import REPOSITORYVALIDATOR, USERVALIDATOR
from infrahub.core.convert_object_type.object_conversion import (
InputForDestField,
ConversionFieldInput,
convert_object_type,
validate_conversion,
)
Expand All @@ -20,7 +20,7 @@
async def convert_repository_type(
repository: CoreRepository | CoreReadOnlyRepository,
target_schema: NodeSchema,
mapping: dict[str, InputForDestField],
mapping: dict[str, ConversionFieldInput],
branch: Branch,
db: InfrahubDatabase,
repository_post_creator: RepositoryFinalizer,
Expand Down
8 changes: 4 additions & 4 deletions backend/infrahub/graphql/mutations/convert_object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from infrahub.core import registry
from infrahub.core.constants.infrahubkind import READONLYREPOSITORY, REPOSITORY
from infrahub.core.convert_object_type.object_conversion import InputForDestField, convert_and_validate_object_type
from infrahub.core.convert_object_type.object_conversion import ConversionFieldInput, convert_and_validate_object_type
from infrahub.core.convert_object_type.repository_conversion import convert_repository_type
from infrahub.core.convert_object_type.schema_mapping import get_schema_mapping
from infrahub.core.manager import NodeManager
Expand Down Expand Up @@ -47,12 +47,12 @@ async def mutate(
source_schema = registry.get_node_schema(name=node_to_convert.get_kind(), branch=graphql_context.branch)
target_schema = registry.get_node_schema(name=str(data.target_kind), branch=graphql_context.branch)

fields_mapping: dict[str, InputForDestField] = {}
fields_mapping: dict[str, ConversionFieldInput] = {}
if not isinstance(data.fields_mapping, dict):
raise ValueError(f"Expected `fields_mapping` to be a `dict`, got {type(data.fields_mapping)}")

for field_name, input_for_dest_field_str in data.fields_mapping.items():
fields_mapping[field_name] = InputForDestField(**input_for_dest_field_str)
fields_mapping[field_name] = ConversionFieldInput(**input_for_dest_field_str)

node_to_convert = await NodeManager.get_one(
id=str(data.node_id), db=graphql_context.db, branch=graphql_context.branch
Expand All @@ -62,7 +62,7 @@ async def mutate(
mapping = get_schema_mapping(source_schema=source_schema, target_schema=target_schema)
for field_name, mapping_value in mapping.items():
if mapping_value.source_field_name is not None and field_name not in fields_mapping:
fields_mapping[field_name] = InputForDestField(source_field=mapping_value.source_field_name)
fields_mapping[field_name] = ConversionFieldInput(source_field=mapping_value.source_field_name)

if target_schema.kind in [REPOSITORY, READONLYREPOSITORY]:
new_node = await convert_repository_type(
Expand Down
1 change: 1 addition & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,7 @@ async def schemas_conversion(db: InfrahubDatabase, node_group_schema, data_schem
"attributes": [
{"name": "name", "kind": "Text", "unique": True},
{"name": "height", "kind": "Number", "optional": True},
{"name": "favorite_color", "kind": "Text", "optional": True, "default_value": "blue"},
],
"relationships": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import TYPE_CHECKING

import pytest
from infrahub_sdk.convert_object_type import ConversionFieldInput, ConversionFieldValue

from infrahub.core.constants.infrahubkind import NUMBERPOOL
from infrahub.core.convert_object_type.object_conversion import InputDataForDestField, InputForDestField
from infrahub.core.query.resource_manager import NumberPoolGetReserved
from infrahub.core.schema import AttributeSchema, GenericSchema, NodeSchema, SchemaRoot
from tests.helpers.test_app import TestInfrahubApp
Expand Down Expand Up @@ -59,6 +59,11 @@ async def test_get_fields_mapping(self, client: InfrahubClient, schemas_conversi
"age": {"is_mandatory": True, "source_field_name": None, "relationship_cardinality": None},
"name": {"is_mandatory": True, "source_field_name": "name", "relationship_cardinality": None},
"height": {"is_mandatory": False, "source_field_name": "height", "relationship_cardinality": None},
"favorite_color": {
"is_mandatory": False,
"source_field_name": "favorite_color",
"relationship_cardinality": None,
},
"subscriber_of_groups": {
"is_mandatory": False,
"source_field_name": "subscriber_of_groups",
Expand Down Expand Up @@ -112,18 +117,20 @@ async def test_convert_object_type(self, client: InfrahubClient, schemas_convers
kind="TestconvPerson1",
name="Jack",
height=170,
favorite_color="green",
favorite_car=car_1,
fastest_cars=[car_1, car_2],
)
await jack_1.save()

mapping = {
"name": InputForDestField(source_field="name"),
"age": InputForDestField(data=InputDataForDestField(attribute_value=25)),
"worst_car": InputForDestField(data=InputDataForDestField(peer_id=car_1.id)),
"fastest_cars": InputForDestField(source_field="fastest_cars"),
"slowest_cars": InputForDestField(data=InputDataForDestField(peers_ids=[car_1.id])),
"bags": InputForDestField(data=InputDataForDestField(peers_ids=[])),
"name": ConversionFieldInput(source_field="name"),
"age": ConversionFieldInput(data=ConversionFieldValue(attribute_value=25)),
"worst_car": ConversionFieldInput(data=ConversionFieldValue(peer_id=car_1.id)),
"fastest_cars": ConversionFieldInput(source_field="fastest_cars"),
"slowest_cars": ConversionFieldInput(data=ConversionFieldValue(peers_ids=[car_1.id])),
"bags": ConversionFieldInput(data=ConversionFieldValue(peers_ids=[])),
"favorite_color": ConversionFieldInput(use_default_value=True),
}

mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in mapping.items()}
Expand All @@ -143,6 +150,7 @@ async def test_convert_object_type(self, client: InfrahubClient, schemas_convers
assert res_node["age"]["value"] == 25
assert res_node["name"]["value"] == "Jack"
assert res_node["height"]["value"] == 170
assert res_node["favorite_color"]["value"] == "blue"


class TestConvertObjectTypeResourcePool(TestInfrahubApp):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from infrahub.core.branch.enums import BranchStatus
from infrahub.core.branch.models import Branch
from infrahub.core.convert_object_type.object_conversion import InputDataForDestField, InputForDestField
from infrahub.core.convert_object_type.object_conversion import ConversionFieldInput, ConversionFieldValue
from infrahub.core.initialization import create_branch
from infrahub.core.node import Node
from infrahub.core.query.delete import DeleteAfterTimeQuery
Expand Down Expand Up @@ -210,11 +210,11 @@ async def test_convert_repo_to_read_only(
mapping = {}
for field_name, field_infos in conversion_response["FieldsMappingTypeConversion"]["mapping"].items():
if field_infos["source_field_name"] is not None:
mapping[field_name] = InputForDestField(source_field=field_infos["source_field_name"])
mapping[field_name] = ConversionFieldInput(source_field=field_infos["source_field_name"])
else:
assert field_name == "ref"

mapping["ref"] = InputForDestField(data=InputDataForDestField(attribute_value=repository.commit.value))
mapping["ref"] = ConversionFieldInput(data=ConversionFieldValue(attribute_value=repository.commit.value))
mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in mapping.items()}

with patch("infrahub.git.tasks.lock"):
Expand Down Expand Up @@ -356,11 +356,11 @@ async def test_convert_read_only_to_read_write(
mapping = {}
for field_name, field_infos in conversion_response["FieldsMappingTypeConversion"]["mapping"].items():
if field_infos["source_field_name"] is not None:
mapping[field_name] = InputForDestField(source_field=field_infos["source_field_name"])
mapping[field_name] = ConversionFieldInput(source_field=field_infos["source_field_name"])
else:
assert field_name == "default_branch"

mapping["default_branch"] = InputForDestField(data=InputDataForDestField(attribute_value=default_branch.name))
mapping["default_branch"] = ConversionFieldInput(data=ConversionFieldValue(attribute_value=default_branch.name))
mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in mapping.items()}

with patch("infrahub.git.tasks.lock"):
Expand Down Expand Up @@ -495,11 +495,11 @@ async def test_convert_to_read_write_on_main_create_branch_before(
mapping = {}
for field_name, field_infos in conversion_response["FieldsMappingTypeConversion"]["mapping"].items():
if field_infos["source_field_name"] is not None:
mapping[field_name] = InputForDestField(source_field=field_infos["source_field_name"])
mapping[field_name] = ConversionFieldInput(source_field=field_infos["source_field_name"])
else:
assert field_name == "default_branch"

mapping["default_branch"] = InputForDestField(data=InputDataForDestField(attribute_value=default_branch.name))
mapping["default_branch"] = ConversionFieldInput(data=ConversionFieldValue(attribute_value=default_branch.name))
mapping_dict = {field_name: model.model_dump(mode="json") for field_name, model in mapping.items()}

with patch("infrahub.git.tasks.lock"):
Expand Down
Loading