Skip to content

Commit f344e76

Browse files
committed
fix test include
1 parent 1cadd6f commit f344e76

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

infrahub_sdk/client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,9 @@ async def _process_nodes_and_relationships(
558558
- 'related_nodes': A list of InfrahubNode objects representing the related nodes
559559
"""
560560

561+
# Ideally, include and relationships wouldn't be parameters of this method, they should only
562+
# be used to build the request for the server, and this method would build node according to the response.
563+
561564
nodes: list[InfrahubNode] = []
562565
related_nodes: list[InfrahubNode] = []
563566

@@ -571,7 +574,6 @@ async def _process_nodes_and_relationships(
571574
branch=branch,
572575
related_nodes=related_nodes,
573576
timeout=timeout,
574-
include=include,
575577
)
576578

577579
return ProcessRelationsNode(nodes=nodes, related_nodes=related_nodes)
@@ -1837,6 +1839,7 @@ def _process_nodes_and_relationships(
18371839
schema_kind: str,
18381840
branch: str,
18391841
prefetch_relationships: bool,
1842+
include: list[str] | None,
18401843
timeout: int | None = None,
18411844
) -> ProcessRelationsNodeSync:
18421845
"""Processes InfrahubNodeSync and their Relationships from the GraphQL query response.
@@ -1861,7 +1864,7 @@ def _process_nodes_and_relationships(
18611864
node = InfrahubNodeSync.from_graphql(client=self, branch=branch, data=item, timeout=timeout)
18621865
nodes.append(node)
18631866

1864-
if prefetch_relationships:
1867+
if prefetch_relationships or include is not None:
18651868
node._process_relationships(node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout)
18661869

18671870
return ProcessRelationsNodeSync(nodes=nodes, related_nodes=related_nodes)
@@ -1986,6 +1989,7 @@ def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelat
19861989
branch=branch,
19871990
prefetch_relationships=prefetch_relationships,
19881991
timeout=timeout,
1992+
include=include,
19891993
)
19901994
return response, process_result
19911995

infrahub_sdk/node/node.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,6 @@ async def _process_relationships(
890890
branch: str,
891891
related_nodes: list[InfrahubNode],
892892
timeout: int | None = None,
893-
include: list[str] | None = None,
894893
) -> None:
895894
"""Processes the Relationships of a InfrahubNode and add Related Nodes to a list.
896895
@@ -901,8 +900,6 @@ async def _process_relationships(
901900
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
902901
"""
903902
for rel_name in self._relationships:
904-
if include is not None and rel_name not in include:
905-
continue
906903
rel = getattr(self, rel_name)
907904
if rel and isinstance(rel, RelatedNode):
908905
relation = node_data["node"].get(rel_name, None)
@@ -1369,7 +1366,8 @@ def generate_query_data_node(
13691366
continue
13701367

13711368
peer_data: dict[str, Any] = {}
1372-
if rel_schema and prefetch_relationships:
1369+
should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include)
1370+
if rel_schema and should_fetch_relationship:
13731371
peer_schema = self._client.schema.get(kind=rel_schema.peer, branch=self._branch)
13741372
peer_node = InfrahubNodeSync(client=self._client, schema=peer_schema, branch=self._branch)
13751373
peer_data = peer_node.generate_query_data_node(include=include, exclude=exclude, property=property)

tests/unit/sdk/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ async def client() -> InfrahubClient:
3333
return InfrahubClient(config=Config(address="http://mock", insert_tracker=True, pagination_size=3))
3434

3535

36+
@pytest.fixture
37+
async def client_sync() -> InfrahubClientSync:
38+
return InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True, pagination_size=3))
39+
40+
3641
@pytest.fixture
3742
async def clients() -> BothClients:
3843
both = BothClients(
@@ -2641,3 +2646,20 @@ async def mock_query_tasks_05(httpx_mock: HTTPXMock) -> HTTPXMock:
26412646
is_reusable=True,
26422647
)
26432648
return httpx_mock
2649+
2650+
2651+
async def set_builtin_tag_schema_cache(client) -> None:
2652+
# Set tag schema in cache to avoid needed to request the server.
2653+
builtin_tag_schema = {
2654+
"version": "1.0",
2655+
"nodes": [
2656+
{
2657+
"name": "Tag",
2658+
"namespace": "Builtin",
2659+
"default_filter": "name__value",
2660+
"display_label": "name__value",
2661+
"branch": "aware",
2662+
}
2663+
],
2664+
}
2665+
client.schema.set_cache(builtin_tag_schema)

tests/unit/sdk/test_node.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from infrahub_sdk.node.constants import SAFE_VALUE
1818
from infrahub_sdk.node.related_node import RelatedNode, RelatedNodeSync
1919
from infrahub_sdk.schema import GenericSchema, NodeSchemaAPI
20+
from tests.unit.sdk.conftest import set_builtin_tag_schema_cache
2021

2122
if TYPE_CHECKING:
2223
from infrahub_sdk.client import InfrahubClient, InfrahubClientSync
@@ -1055,12 +1056,19 @@ async def test_query_data_generic_fragment(clients, mock_schema_query_02, client
10551056

10561057

10571058
@pytest.mark.parametrize("client_type", client_types)
1058-
async def test_query_data_include_property(client, location_schema: NodeSchemaAPI, client_type) -> None:
1059+
async def test_query_data_include_property(
1060+
client,
1061+
client_sync,
1062+
location_schema: NodeSchemaAPI,
1063+
client_type,
1064+
) -> None:
10591065
if client_type == "standard":
1066+
await set_builtin_tag_schema_cache(client)
10601067
node = InfrahubNode(client=client, schema=location_schema)
10611068
data = await node.generate_query_data(include=["tags"], property=True)
10621069
else:
1063-
node = InfrahubNodeSync(client=client, schema=location_schema)
1070+
await set_builtin_tag_schema_cache(client_sync)
1071+
node = InfrahubNodeSync(client=client_sync, schema=location_schema)
10641072
data = node.generate_query_data(include=["tags"], property=True)
10651073

10661074
assert data == {
@@ -1178,12 +1186,19 @@ async def test_query_data_include_property(client, location_schema: NodeSchemaAP
11781186

11791187

11801188
@pytest.mark.parametrize("client_type", client_types)
1181-
async def test_query_data_include(client, location_schema: NodeSchemaAPI, client_type) -> None:
1189+
async def test_query_data_include(
1190+
client,
1191+
client_sync,
1192+
location_schema: NodeSchemaAPI,
1193+
client_type,
1194+
) -> None:
11821195
if client_type == "standard":
1196+
await set_builtin_tag_schema_cache(client)
11831197
node = InfrahubNode(client=client, schema=location_schema)
11841198
data = await node.generate_query_data(include=["tags"])
11851199
else:
1186-
node = InfrahubNodeSync(client=client, schema=location_schema)
1200+
await set_builtin_tag_schema_cache(client_sync)
1201+
node = InfrahubNodeSync(client=client_sync, schema=location_schema)
11871202
data = node.generate_query_data(include=["tags"])
11881203

11891204
assert data == {

0 commit comments

Comments
 (0)