Skip to content

Commit 49f2b74

Browse files
committed
Add support for alternate format for relationship many
1 parent 3e09b85 commit 49f2b74

File tree

3 files changed

+171
-26
lines changed

3 files changed

+171
-26
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ dist/*
2727
**/*.csv
2828

2929
# Generated files
30-
generated/
30+
generated/
31+
sandbox/

infrahub_sdk/spec/object.py

Lines changed: 105 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from enum import Enum
34
from typing import TYPE_CHECKING, Any
45

56
from pydantic import BaseModel, Field
@@ -14,16 +15,44 @@
1415
from ..schema import MainSchemaTypesAPI, RelationshipSchema
1516

1617

18+
def validate_list_of_scalars(value: list[Any]) -> bool:
19+
return all(isinstance(item, (str, int, float, bool)) for item in value)
20+
21+
22+
def validate_list_of_hfids(value: list[Any]) -> bool:
23+
return all(isinstance(item, (str, list)) for item in value)
24+
25+
26+
def validate_list_of_data_dicts(value: list[Any]) -> bool:
27+
return all(isinstance(item, dict) and "data" in item for item in value)
28+
29+
30+
def validate_list_of_objects(value: list[Any]) -> bool:
31+
return all(isinstance(item, dict) for item in value)
32+
33+
34+
class RelationshipDataFormat(str, Enum):
35+
UNKNOWN = "unknown"
36+
37+
ONE_REF = "one_ref"
38+
ONE_OBJ = "one_obj"
39+
40+
MANY_OBJ_DICT_LIST = "many_obj_dict_list"
41+
MANY_OBJ_LIST_DICT = "many_obj_list_dict"
42+
MANY_REF = "many_ref_list"
43+
44+
1745
class RelationshipInfo(BaseModel):
1846
name: str
1947
rel_schema: RelationshipSchema
2048
peer_kind: str
2149
peer_rel: RelationshipSchema | None = None
22-
is_reference: bool = True
2350
reason_relationship_not_valid: str | None = None
51+
format: RelationshipDataFormat = RelationshipDataFormat.UNKNOWN
2452

2553
@property
2654
def is_bidirectional(self) -> bool:
55+
"""Indicate if a relationship with the same identifier exists on the other side"""
2756
return bool(self.peer_rel)
2857

2958
@property
@@ -36,6 +65,10 @@ def is_mandatory(self) -> bool:
3665
def is_valid(self) -> bool:
3766
return not self.reason_relationship_not_valid
3867

68+
@property
69+
def is_reference(self) -> bool:
70+
return self.format in [RelationshipDataFormat.ONE_REF, RelationshipDataFormat.MANY_REF]
71+
3972

4073
async def get_relationship_info(
4174
client: InfrahubClient, schema: MainSchemaTypesAPI, key: str, value: Any, branch: str | None = None
@@ -63,10 +96,38 @@ async def get_relationship_info(
6396
except ValueError:
6497
pass
6598

66-
# Check if the content of the relationship is a reference to existing objects
67-
# or if it contains the data to create/update related objects
68-
if isinstance(value, dict) and "data" in value:
69-
info.is_reference = False
99+
if rel_schema.cardinality == "one" and isinstance(value, list):
100+
# validate the list is composed of string
101+
if validate_list_of_scalars(value):
102+
info.format = RelationshipDataFormat.ONE_REF
103+
else:
104+
info.reason_relationship_not_valid = "Too many objects provided for a relationship of cardinality one"
105+
106+
elif rel_schema.cardinality == "one" and isinstance(value, dict) and "data" in value:
107+
info.format = RelationshipDataFormat.ONE_OBJ
108+
109+
elif (
110+
rel_schema.cardinality == "many"
111+
and isinstance(value, dict)
112+
and "data" in value
113+
and validate_list_of_objects(value["data"])
114+
):
115+
# Initial format, we need to support it for backward compatibility for menu
116+
# it's helpful if there is only one type of object to manage
117+
info.format = RelationshipDataFormat.MANY_OBJ_DICT_LIST
118+
119+
elif rel_schema.cardinality == "many" and isinstance(value, dict) and "data" not in value:
120+
info.reason_relationship_not_valid = "Invalid structure for a relationship of cardinality many,"
121+
" either provide a dict with data as a list or a list of objects"
122+
123+
elif rel_schema.cardinality == "many" and isinstance(value, list):
124+
if validate_list_of_data_dicts(value):
125+
info.format = RelationshipDataFormat.MANY_OBJ_LIST_DICT
126+
elif validate_list_of_hfids(value):
127+
info.format = RelationshipDataFormat.MANY_REF
128+
else:
129+
info.reason_relationship_not_valid = "Invalid structure for a relationship of cardinality many,"
130+
" either provide a list of dict with data or a list of hfids"
70131

71132
return info
72133

@@ -100,7 +161,7 @@ async def validate_object(
100161
context = context or {}
101162

102163
# First validate if all mandatory fields are present
103-
for element in schema.mandatory_attribute_names + schema.mandatory_relationship_names:
164+
for element in schema.mandatory_input_names:
104165
if not any([element in data.keys(), element in context.keys()]):
105166
errors.append(ValidationError(identifier=element, message=f"{element} is mandatory"))
106167

@@ -162,6 +223,7 @@ async def create_node(
162223
for key, value in data.items():
163224
if key in schema.attribute_names:
164225
clean_data[key] = value
226+
continue
165227

166228
if key in schema.relationship_names:
167229
rel_schema = schema.get_relationship(name=key)
@@ -181,22 +243,31 @@ async def create_node(
181243
# - if the relationship is not bidirectional, then we need to create the related object First
182244
if rel_info.is_reference and isinstance(value, list):
183245
clean_data[key] = value
184-
elif rel_info.is_reference and rel_schema.cardinality == "one" and isinstance(value, str):
246+
elif rel_info.format == RelationshipDataFormat.ONE_REF and isinstance(value, str):
185247
clean_data[key] = [value]
186248
elif not rel_info.is_reference and rel_info.is_bidirectional and rel_info.is_mandatory:
187249
remaining_rels.append(key)
188250
elif not rel_info.is_reference and not rel_info.is_mandatory:
189-
nodes = await cls.create_related_nodes(
190-
client=client,
191-
rel_info=rel_info,
192-
data=value["data"],
193-
branch=branch,
194-
default_schema_kind=default_schema_kind,
195-
)
196-
if rel_info.rel_schema.cardinality == "one":
251+
if rel_info.format == RelationshipDataFormat.ONE_OBJ:
252+
nodes = await cls.create_related_nodes(
253+
client=client,
254+
rel_info=rel_info,
255+
data=value,
256+
branch=branch,
257+
default_schema_kind=default_schema_kind,
258+
)
197259
clean_data[key] = nodes[0]
260+
198261
else:
262+
nodes = await cls.create_related_nodes(
263+
client=client,
264+
rel_info=rel_info,
265+
data=value,
266+
branch=branch,
267+
default_schema_kind=default_schema_kind,
268+
)
199269
clean_data[key] = nodes
270+
200271
else:
201272
raise ValueError(f"Situation unaccounted for: {rel_info}")
202273

@@ -223,16 +294,15 @@ async def create_node(
223294
if rel_schema.identifier is None:
224295
raise ValueError("identifier must be defined")
225296

226-
rel_data = data[rel]["data"]
227297
context = {}
228-
229298
if rel_info.peer_rel:
230299
context[rel_info.peer_rel.name] = node.id
231300

301+
# TODO need to account for the different format here
232302
await cls.create_related_nodes(
233303
client=client,
234304
rel_info=rel_info,
235-
data=rel_data,
305+
data=data[rel],
236306
context=context,
237307
branch=branch,
238308
default_schema_kind=default_schema_kind,
@@ -254,20 +324,20 @@ async def create_related_nodes(
254324

255325
nodes: list[InfrahubNode] = []
256326

257-
if rel_info.rel_schema.cardinality == "one" and isinstance(data, dict):
327+
if rel_info.format == RelationshipDataFormat.ONE_OBJ:
258328
node = await cls.create_node(
259329
client=client,
260330
schema=peer_schema,
261-
data=data,
331+
data=data["data"],
262332
context=context,
263333
branch=branch,
264334
default_schema_kind=default_schema_kind,
265335
)
266336
return [node]
267337

268-
if rel_info.rel_schema.cardinality == "many" and isinstance(data, list):
338+
if rel_info.format == RelationshipDataFormat.MANY_OBJ_DICT_LIST:
269339
context = context or {}
270-
for idx, peer_data in enumerate(data):
340+
for idx, peer_data in enumerate(data["data"]):
271341
context["list_index"] = idx
272342
if isinstance(peer_data, dict):
273343
node = await cls.create_node(
@@ -281,6 +351,19 @@ async def create_related_nodes(
281351
nodes.append(node)
282352
return nodes
283353

354+
if rel_info.format == RelationshipDataFormat.MANY_OBJ_LIST_DICT:
355+
for item in data:
356+
node = await cls.create_node(
357+
client=client,
358+
schema=peer_schema,
359+
data=item["data"],
360+
context=context,
361+
branch=branch,
362+
default_schema_kind=default_schema_kind,
363+
)
364+
nodes.append(node)
365+
return nodes
366+
284367
raise ValueError(
285368
f"Relationship {rel_info.rel_schema.name} doesn't have the right format {rel_info.rel_schema.cardinality} / {type(data)}"
286369
)

tests/unit/sdk/spec/test_object.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
15
import pytest
2-
from pytest_httpx import HTTPXMock
36

4-
from infrahub_sdk.client import InfrahubClient
57
from infrahub_sdk.exceptions import ValidationError
6-
from infrahub_sdk.spec.object import ObjectFile
8+
from infrahub_sdk.spec.object import ObjectFile, RelationshipDataFormat, get_relationship_info
9+
10+
if TYPE_CHECKING:
11+
from pytest_httpx import HTTPXMock
12+
13+
from infrahub_sdk.client import InfrahubClient
714

815

916
@pytest.fixture
@@ -61,3 +68,57 @@ async def test_validate_object_bad_syntax02(
6168
await obj.validate_format(client=client)
6269

6370
assert "notvalidattribute" in str(exc.value)
71+
72+
73+
get_relationship_info_testdata = [
74+
pytest.param(
75+
[
76+
{"data": {"name": "Blue"}},
77+
{"data": {"name": "Red"}},
78+
],
79+
True,
80+
RelationshipDataFormat.MANY_OBJ_LIST_DICT,
81+
id="many_obj_list_dict",
82+
),
83+
pytest.param(
84+
{
85+
"data": [
86+
{"name": "Blue"},
87+
{"name": "Red"},
88+
],
89+
},
90+
True,
91+
RelationshipDataFormat.MANY_OBJ_DICT_LIST,
92+
id="many_obj_dict_list",
93+
),
94+
pytest.param(
95+
["blue", "red"],
96+
True,
97+
RelationshipDataFormat.MANY_REF,
98+
id="many_ref",
99+
),
100+
pytest.param(
101+
[
102+
{"name": "Blue"},
103+
{"name": "Red"},
104+
],
105+
False,
106+
RelationshipDataFormat.UNKNOWN,
107+
id="many_invalid_list_dict",
108+
),
109+
]
110+
111+
112+
@pytest.mark.parametrize("data,is_valid,format", get_relationship_info_testdata)
113+
async def test_get_relationship_info_tags(
114+
client: InfrahubClient,
115+
mock_schema_query_01: HTTPXMock,
116+
data: dict | list,
117+
is_valid: bool,
118+
format: RelationshipDataFormat,
119+
):
120+
location_schema = await client.schema.get(kind="BuiltinLocation")
121+
122+
rel_info = await get_relationship_info(client, location_schema, "tags", data)
123+
assert rel_info.is_valid == is_valid
124+
assert rel_info.format == format

0 commit comments

Comments
 (0)