diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 16c1c73a..b301f3b0 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -540,6 +540,7 @@ async def _process_nodes_and_relationships( schema_kind: str, branch: str, prefetch_relationships: bool, + include: list[str] | None, timeout: int | None = None, ) -> ProcessRelationsNode: """Processes InfrahubNode and their Relationships from the GraphQL query response. @@ -564,9 +565,12 @@ async def _process_nodes_and_relationships( node = await InfrahubNode.from_graphql(client=self, branch=branch, data=item, timeout=timeout) nodes.append(node) - if prefetch_relationships: + if prefetch_relationships or (include and any(rel in include for rel in node._relationships)): await node._process_relationships( - node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout + node_data=item, + branch=branch, + related_nodes=related_nodes, + timeout=timeout, ) return ProcessRelationsNode(nodes=nodes, related_nodes=related_nodes) @@ -816,6 +820,7 @@ async def process_page(page_offset: int, page_number: int) -> tuple[dict, Proces branch=branch, prefetch_relationships=prefetch_relationships, timeout=timeout, + include=include, ) return response, process_result @@ -1831,6 +1836,7 @@ def _process_nodes_and_relationships( schema_kind: str, branch: str, prefetch_relationships: bool, + include: list[str] | None, timeout: int | None = None, ) -> ProcessRelationsNodeSync: """Processes InfrahubNodeSync and their Relationships from the GraphQL query response. @@ -1855,7 +1861,7 @@ def _process_nodes_and_relationships( node = InfrahubNodeSync.from_graphql(client=self, branch=branch, data=item, timeout=timeout) nodes.append(node) - if prefetch_relationships: + if prefetch_relationships or (include and any(rel in include for rel in node._relationships)): node._process_relationships(node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout) return ProcessRelationsNodeSync(nodes=nodes, related_nodes=related_nodes) @@ -1980,6 +1986,7 @@ def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelat branch=branch, prefetch_relationships=prefetch_relationships, timeout=timeout, + include=include, ) return response, process_result diff --git a/infrahub_sdk/node/node.py b/infrahub_sdk/node/node.py index 242281b5..84d774e6 100644 --- a/infrahub_sdk/node/node.py +++ b/infrahub_sdk/node/node.py @@ -742,12 +742,11 @@ async def generate_query_data_node( continue peer_data: dict[str, Any] = {} - if rel_schema and prefetch_relationships: + should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include) + if rel_schema and should_fetch_relationship: peer_schema = await self._client.schema.get(kind=rel_schema.peer, branch=self._branch) peer_node = InfrahubNode(client=self._client, schema=peer_schema, branch=self._branch) peer_data = await peer_node.generate_query_data_node( - include=include, - exclude=exclude, property=property, ) @@ -886,7 +885,11 @@ async def update( await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) async def _process_relationships( - self, node_data: dict[str, Any], branch: str, related_nodes: list[InfrahubNode], timeout: int | None = None + self, + node_data: dict[str, Any], + branch: str, + related_nodes: list[InfrahubNode], + timeout: int | None = None, ) -> None: """Processes the Relationships of a InfrahubNode and add Related Nodes to a list. @@ -1363,7 +1366,8 @@ def generate_query_data_node( continue peer_data: dict[str, Any] = {} - if rel_schema and prefetch_relationships: + should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include) + if rel_schema and should_fetch_relationship: peer_schema = self._client.schema.get(kind=rel_schema.peer, branch=self._branch) peer_node = InfrahubNodeSync(client=self._client, schema=peer_schema, branch=self._branch) peer_data = peer_node.generate_query_data_node(include=include, exclude=exclude, property=property) diff --git a/tests/integration/test_node.py b/tests/integration/test_node.py index 3716dc1a..652e7203 100644 --- a/tests/integration/test_node.py +++ b/tests/integration/test_node.py @@ -85,6 +85,40 @@ async def test_node_create_with_relationships_using_related_node( assert node_after.owner.peer.id == person_joe.id assert node_after.owner.peer.typename == "TestingPerson" + async def test_node_filters_include( + self, + default_branch: str, + client: InfrahubClient, + initial_schema: None, + manufacturer_mercedes, + person_joe, + tag_red, + ) -> None: + car = await client.create( + kind=TESTING_CAR, + name="Tiguan2", + color="Black", + manufacturer=manufacturer_mercedes, + owner=person_joe, + tags=[tag_red], + ) + await car.save(allow_upsert=True) + assert car.id is not None + + # Clear store, as when we call `owner.peer`, we actually rely on the peer having being stored in store. + client.store._branches = {} + node_after = await client.get(kind=TESTING_CAR, id=car.id) + + with pytest.raises(NodeNotFoundError, match=f"Unable to find the node '{person_joe.id}' in the store"): + _ = node_after.owner.peer + + assert len(node_after.tags.peers) == 0 + + # Test both one and many relationships + node_after = await client.get(kind=TESTING_CAR, id=car.id, include=["tags", "owner"]) + assert [tag.id for tag in node_after.tags.peers] == [tag_red.id] + assert node_after.owner.peer.id == person_joe.id, f"{person_joe.id=}" + async def test_node_update_with_original_data( self, default_branch: str, diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py index 6943900e..5f0d7c2a 100644 --- a/tests/unit/sdk/conftest.py +++ b/tests/unit/sdk/conftest.py @@ -33,6 +33,11 @@ async def client() -> InfrahubClient: return InfrahubClient(config=Config(address="http://mock", insert_tracker=True, pagination_size=3)) +@pytest.fixture +async def client_sync() -> InfrahubClientSync: + return InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True, pagination_size=3)) + + @pytest.fixture async def clients() -> BothClients: both = BothClients( diff --git a/tests/unit/sdk/test_node.py b/tests/unit/sdk/test_node.py index 5dd7730b..3cae8b56 100644 --- a/tests/unit/sdk/test_node.py +++ b/tests/unit/sdk/test_node.py @@ -53,6 +53,23 @@ ] +async def set_builtin_tag_schema_cache(client) -> None: + # Set tag schema in cache to avoid needed to request the server. + builtin_tag_schema = { + "version": "1.0", + "nodes": [ + { + "name": "Tag", + "namespace": "Builtin", + "default_filter": "name__value", + "display_label": "name__value", + "branch": "aware", + } + ], + } + client.schema.set_cache(builtin_tag_schema) + + async def test_method_sanity() -> None: """Validate that there is at least one public method and that both clients look the same.""" assert async_node_methods @@ -1055,12 +1072,19 @@ async def test_query_data_generic_fragment(clients, mock_schema_query_02, client @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_include_property(client, location_schema: NodeSchemaAPI, client_type) -> None: +async def test_query_data_include_property( + client, + client_sync, + location_schema: NodeSchemaAPI, + client_type, +) -> None: if client_type == "standard": + await set_builtin_tag_schema_cache(client) node = InfrahubNode(client=client, schema=location_schema) data = await node.generate_query_data(include=["tags"], property=True) else: - node = InfrahubNodeSync(client=client, schema=location_schema) + await set_builtin_tag_schema_cache(client_sync) + node = InfrahubNodeSync(client=client_sync, schema=location_schema) data = node.generate_query_data(include=["tags"], property=True) assert data == { @@ -1178,12 +1202,19 @@ async def test_query_data_include_property(client, location_schema: NodeSchemaAP @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_include(client, location_schema: NodeSchemaAPI, client_type) -> None: +async def test_query_data_include( + client, + client_sync, + location_schema: NodeSchemaAPI, + client_type, +) -> None: if client_type == "standard": + await set_builtin_tag_schema_cache(client) node = InfrahubNode(client=client, schema=location_schema) data = await node.generate_query_data(include=["tags"]) else: - node = InfrahubNodeSync(client=client, schema=location_schema) + await set_builtin_tag_schema_cache(client_sync) + node = InfrahubNodeSync(client=client_sync, schema=location_schema) data = node.generate_query_data(include=["tags"]) assert data == {