Skip to content

Commit 13972a2

Browse files
committed
Fix MANY_OBJ_LIST_DICT
1 parent 6f708c5 commit 13972a2

File tree

3 files changed

+46
-23
lines changed

3 files changed

+46
-23
lines changed

infrahub_sdk/spec/object.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,6 @@ async def create_node(
229229
continue
230230

231231
if key in schema.relationship_names:
232-
rel_schema = schema.get_relationship(name=key)
233-
234232
rel_info = await get_relationship_info(
235233
client=client, schema=schema, key=key, value=value, branch=branch
236234
)
@@ -291,17 +289,13 @@ async def create_node(
291289
client.log.info(f"Node: {display_label}")
292290

293291
for rel in remaining_rels:
294-
# identify what is the name of the relationship on the other side
295-
rel_info = rels_info[rel]
296-
297-
if rel_schema.identifier is None:
298-
raise ValueError("identifier must be defined")
299-
300292
context = {}
293+
294+
# If there is a peer relationship, we add the node id to the context
295+
rel_info = rels_info[rel]
301296
if rel_info.peer_rel:
302297
context[rel_info.peer_rel.name] = node.id
303298

304-
# TODO need to account for the different format here
305299
await cls.create_related_nodes(
306300
client=client,
307301
rel_info=rel_info,
@@ -318,16 +312,18 @@ async def create_related_nodes(
318312
cls,
319313
client: InfrahubClient,
320314
rel_info: RelationshipInfo,
321-
data: dict,
315+
data: dict | list[dict],
322316
context: dict | None = None,
323317
branch: str | None = None,
324318
default_schema_kind: str | None = None,
325319
) -> list[InfrahubNode]:
326-
peer_schema = await client.schema.get(kind=rel_info.peer_kind, branch=branch)
327-
328320
nodes: list[InfrahubNode] = []
321+
context = context or {}
322+
323+
if isinstance(data, dict) and rel_info.format == RelationshipDataFormat.ONE_OBJ:
324+
peer_kind = data.get("kind") or rel_info.peer_kind
325+
peer_schema = await client.schema.get(kind=peer_kind, branch=branch)
329326

330-
if rel_info.format == RelationshipDataFormat.ONE_OBJ:
331327
node = await cls.create_node(
332328
client=client,
333329
schema=peer_schema,
@@ -338,8 +334,10 @@ async def create_related_nodes(
338334
)
339335
return [node]
340336

341-
if rel_info.format == RelationshipDataFormat.MANY_OBJ_DICT_LIST:
342-
context = context or {}
337+
if isinstance(data, dict) and rel_info.format == RelationshipDataFormat.MANY_OBJ_DICT_LIST:
338+
peer_kind = data.get("kind") or rel_info.peer_kind
339+
peer_schema = await client.schema.get(kind=peer_kind, branch=branch)
340+
343341
for idx, peer_data in enumerate(data["data"]):
344342
context["list_index"] = idx
345343
if isinstance(peer_data, dict):
@@ -354,8 +352,13 @@ async def create_related_nodes(
354352
nodes.append(node)
355353
return nodes
356354

357-
if rel_info.format == RelationshipDataFormat.MANY_OBJ_LIST_DICT:
358-
for item in data:
355+
if isinstance(data, list) and rel_info.format == RelationshipDataFormat.MANY_OBJ_LIST_DICT:
356+
for idx, item in enumerate(data):
357+
context["list_index"] = idx
358+
359+
peer_kind = item.get("kind") or rel_info.peer_kind
360+
peer_schema = await client.schema.get(kind=peer_kind, branch=branch)
361+
359362
node = await cls.create_node(
360363
client=client,
361364
schema=peer_schema,
@@ -365,6 +368,7 @@ async def create_related_nodes(
365368
default_schema_kind=default_schema_kind,
366369
)
367370
nodes.append(node)
371+
368372
return nodes
369373

370374
raise ValueError(

tests/fixtures/spec_objects/animal_person02.yml

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,32 @@ spec:
88
height: 180
99
tags:
1010
data:
11-
- name: Dog Lover # New node will be created if it doesn't exist
11+
- name: Dog Lover
1212

13-
animals: # Relationship of cardinality many that will create/update nodes as needed
14-
# The owner of the dog will be automatically inserted as long as their is a bidirectional relationship
15-
# between TestingPerson and TestingDog
13+
animals:
1614
kind: TestingDog
1715
data:
1816
- name: Max
1917
weight: 25
2018
breed: Golden Retriever
2119
color: "#FFD700"
2220

21+
- name: Emily Parker
22+
height: 165
23+
animals:
24+
- kind: TestingDog
25+
data:
26+
name: Max
27+
weight: 25
28+
breed: Golden Retriever
29+
color: "#FFD700"
30+
- kind: TestingCat
31+
data:
32+
name: Whiskers
33+
weight: 10
34+
breed: Siamese
35+
color: "#FFD700"
36+
2337
- name: Mike Johnson
2438
height: 175
2539
best_friends: # Relationship of cardinality many that referenced existing nodes based on their HFID

tests/integration/test_spec_object.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ async def test_load_persons02(self, client: InfrahubClient, branch_name: str, in
8989
await obj_file.validate_format(client=client, branch=branch_name)
9090

9191
# Check that the nodes are not present in the database before loading the file
92-
assert len(await client.all(kind=obj_file.spec.kind, branch=branch_name)) == 3
92+
assert len(await client.all(kind=obj_file.spec.kind, branch=branch_name)) == 4
9393

9494
await obj_file.process(client=client, branch=branch_name)
9595

9696
persons = await client.all(kind=obj_file.spec.kind, branch=branch_name)
9797
person_by_name = {"__".join(person.get_human_friendly_id()): person for person in persons}
98-
assert len(persons) == 5
98+
assert len(persons) == 6
9999

100100
# Validate that the best_friends relationship is correctly populated
101101
await person_by_name["Mike Johnson"].best_friends.fetch()
@@ -110,3 +110,8 @@ async def test_load_persons02(self, client: InfrahubClient, branch_name: str, in
110110
await person_by_name["Mike Johnson"].tags.fetch()
111111
tags_mike = [tag.hfid for tag in person_by_name["Mike Johnson"].tags.peers]
112112
assert sorted(tags_mike) == sorted([["Veterinarian"], ["Breeder"]])
113+
114+
# Validate that animals for Emily Parler have been correctly created
115+
await person_by_name["Emily Parker"].animals.fetch()
116+
animals_emily = [animal.display_label for animal in person_by_name["Emily Parker"].animals.peers]
117+
assert sorted(animals_emily) == sorted(["Max Golden Retriever", "Whiskers Siamese #FFD700"])

0 commit comments

Comments
 (0)