Skip to content

Commit 17c1f6a

Browse files
authored
Fix filters include parameter (#513)
1 parent 6638c90 commit 17c1f6a

File tree

5 files changed

+93
-12
lines changed

5 files changed

+93
-12
lines changed

infrahub_sdk/client.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ async def _process_nodes_and_relationships(
539539
schema_kind: str,
540540
branch: str,
541541
prefetch_relationships: bool,
542+
include: list[str] | None,
542543
timeout: int | None = None,
543544
) -> ProcessRelationsNode:
544545
"""Processes InfrahubNode and their Relationships from the GraphQL query response.
@@ -563,9 +564,12 @@ async def _process_nodes_and_relationships(
563564
node = await InfrahubNode.from_graphql(client=self, branch=branch, data=item, timeout=timeout)
564565
nodes.append(node)
565566

566-
if prefetch_relationships:
567+
if prefetch_relationships or (include and any(rel in include for rel in node._relationships)):
567568
await node._process_relationships(
568-
node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout
569+
node_data=item,
570+
branch=branch,
571+
related_nodes=related_nodes,
572+
timeout=timeout,
569573
)
570574

571575
return ProcessRelationsNode(nodes=nodes, related_nodes=related_nodes)
@@ -815,6 +819,7 @@ async def process_page(page_offset: int, page_number: int) -> tuple[dict, Proces
815819
branch=branch,
816820
prefetch_relationships=prefetch_relationships,
817821
timeout=timeout,
822+
include=include,
818823
)
819824
return response, process_result
820825

@@ -1829,6 +1834,7 @@ def _process_nodes_and_relationships(
18291834
schema_kind: str,
18301835
branch: str,
18311836
prefetch_relationships: bool,
1837+
include: list[str] | None,
18321838
timeout: int | None = None,
18331839
) -> ProcessRelationsNodeSync:
18341840
"""Processes InfrahubNodeSync and their Relationships from the GraphQL query response.
@@ -1853,7 +1859,7 @@ def _process_nodes_and_relationships(
18531859
node = InfrahubNodeSync.from_graphql(client=self, branch=branch, data=item, timeout=timeout)
18541860
nodes.append(node)
18551861

1856-
if prefetch_relationships:
1862+
if prefetch_relationships or (include and any(rel in include for rel in node._relationships)):
18571863
node._process_relationships(node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout)
18581864

18591865
return ProcessRelationsNodeSync(nodes=nodes, related_nodes=related_nodes)
@@ -1978,6 +1984,7 @@ def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelat
19781984
branch=branch,
19791985
prefetch_relationships=prefetch_relationships,
19801986
timeout=timeout,
1987+
include=include,
19811988
)
19821989
return response, process_result
19831990

infrahub_sdk/node/node.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -748,12 +748,11 @@ async def generate_query_data_node(
748748
continue
749749

750750
peer_data: dict[str, Any] = {}
751-
if rel_schema and prefetch_relationships:
751+
should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include)
752+
if rel_schema and should_fetch_relationship:
752753
peer_schema = await self._client.schema.get(kind=rel_schema.peer, branch=self._branch)
753754
peer_node = InfrahubNode(client=self._client, schema=peer_schema, branch=self._branch)
754755
peer_data = await peer_node.generate_query_data_node(
755-
include=include,
756-
exclude=exclude,
757756
property=property,
758757
)
759758

@@ -892,7 +891,11 @@ async def update(
892891
await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
893892

894893
async def _process_relationships(
895-
self, node_data: dict[str, Any], branch: str, related_nodes: list[InfrahubNode], timeout: int | None = None
894+
self,
895+
node_data: dict[str, Any],
896+
branch: str,
897+
related_nodes: list[InfrahubNode],
898+
timeout: int | None = None,
896899
) -> None:
897900
"""Processes the Relationships of a InfrahubNode and add Related Nodes to a list.
898901
@@ -1369,7 +1372,8 @@ def generate_query_data_node(
13691372
continue
13701373

13711374
peer_data: dict[str, Any] = {}
1372-
if rel_schema and prefetch_relationships:
1375+
should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include)
1376+
if rel_schema and should_fetch_relationship:
13731377
peer_schema = self._client.schema.get(kind=rel_schema.peer, branch=self._branch)
13741378
peer_node = InfrahubNodeSync(client=self._client, schema=peer_schema, branch=self._branch)
13751379
peer_data = peer_node.generate_query_data_node(include=include, exclude=exclude, property=property)

tests/integration/test_node.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,40 @@ async def test_node_create_with_relationships_using_related_node(
8585
assert node_after.owner.peer.id == person_joe.id
8686
assert node_after.owner.peer.typename == "TestingPerson"
8787

88+
async def test_node_filters_include(
89+
self,
90+
default_branch: str,
91+
client: InfrahubClient,
92+
initial_schema: None,
93+
manufacturer_mercedes,
94+
person_joe,
95+
tag_red,
96+
) -> None:
97+
car = await client.create(
98+
kind=TESTING_CAR,
99+
name="Tiguan2",
100+
color="Black",
101+
manufacturer=manufacturer_mercedes,
102+
owner=person_joe,
103+
tags=[tag_red],
104+
)
105+
await car.save(allow_upsert=True)
106+
assert car.id is not None
107+
108+
# Clear store, as when we call `owner.peer`, we actually rely on the peer having being stored in store.
109+
client.store._branches = {}
110+
node_after = await client.get(kind=TESTING_CAR, id=car.id)
111+
112+
with pytest.raises(NodeNotFoundError, match=f"Unable to find the node '{person_joe.id}' in the store"):
113+
_ = node_after.owner.peer
114+
115+
assert len(node_after.tags.peers) == 0
116+
117+
# Test both one and many relationships
118+
node_after = await client.get(kind=TESTING_CAR, id=car.id, include=["tags", "owner"])
119+
assert [tag.id for tag in node_after.tags.peers] == [tag_red.id]
120+
assert node_after.owner.peer.id == person_joe.id, f"{person_joe.id=}"
121+
88122
async def test_node_update_with_original_data(
89123
self,
90124
default_branch: str,

tests/unit/sdk/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ async def client() -> InfrahubClient:
3333
return InfrahubClient(config=Config(address="http://mock", insert_tracker=True, pagination_size=3))
3434

3535

36+
@pytest.fixture
37+
async def client_sync() -> InfrahubClientSync:
38+
return InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True, pagination_size=3))
39+
40+
3641
@pytest.fixture
3742
async def clients() -> BothClients:
3843
both = BothClients(

tests/unit/sdk/test_node.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,23 @@
5353
]
5454

5555

56+
async def set_builtin_tag_schema_cache(client) -> None:
57+
# Set tag schema in cache to avoid needed to request the server.
58+
builtin_tag_schema = {
59+
"version": "1.0",
60+
"nodes": [
61+
{
62+
"name": "Tag",
63+
"namespace": "Builtin",
64+
"default_filter": "name__value",
65+
"display_label": "name__value",
66+
"branch": "aware",
67+
}
68+
],
69+
}
70+
client.schema.set_cache(builtin_tag_schema)
71+
72+
5673
async def test_method_sanity() -> None:
5774
"""Validate that there is at least one public method and that both clients look the same."""
5875
assert async_node_methods
@@ -1055,12 +1072,19 @@ async def test_query_data_generic_fragment(clients, mock_schema_query_02, client
10551072

10561073

10571074
@pytest.mark.parametrize("client_type", client_types)
1058-
async def test_query_data_include_property(client, location_schema: NodeSchemaAPI, client_type) -> None:
1075+
async def test_query_data_include_property(
1076+
client,
1077+
client_sync,
1078+
location_schema: NodeSchemaAPI,
1079+
client_type,
1080+
) -> None:
10591081
if client_type == "standard":
1082+
await set_builtin_tag_schema_cache(client)
10601083
node = InfrahubNode(client=client, schema=location_schema)
10611084
data = await node.generate_query_data(include=["tags"], property=True)
10621085
else:
1063-
node = InfrahubNodeSync(client=client, schema=location_schema)
1086+
await set_builtin_tag_schema_cache(client_sync)
1087+
node = InfrahubNodeSync(client=client_sync, schema=location_schema)
10641088
data = node.generate_query_data(include=["tags"], property=True)
10651089

10661090
assert data == {
@@ -1178,12 +1202,19 @@ async def test_query_data_include_property(client, location_schema: NodeSchemaAP
11781202

11791203

11801204
@pytest.mark.parametrize("client_type", client_types)
1181-
async def test_query_data_include(client, location_schema: NodeSchemaAPI, client_type) -> None:
1205+
async def test_query_data_include(
1206+
client,
1207+
client_sync,
1208+
location_schema: NodeSchemaAPI,
1209+
client_type,
1210+
) -> None:
11821211
if client_type == "standard":
1212+
await set_builtin_tag_schema_cache(client)
11831213
node = InfrahubNode(client=client, schema=location_schema)
11841214
data = await node.generate_query_data(include=["tags"])
11851215
else:
1186-
node = InfrahubNodeSync(client=client, schema=location_schema)
1216+
await set_builtin_tag_schema_cache(client_sync)
1217+
node = InfrahubNodeSync(client=client_sync, schema=location_schema)
11871218
data = node.generate_query_data(include=["tags"])
11881219

11891220
assert data == {

0 commit comments

Comments
 (0)