Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ async def _process_nodes_and_relationships(
schema_kind: str,
branch: str,
prefetch_relationships: bool,
include: list[str] | None,
timeout: int | None = None,
) -> ProcessRelationsNode:
"""Processes InfrahubNode and their Relationships from the GraphQL query response.
Expand All @@ -557,16 +558,22 @@ async def _process_nodes_and_relationships(
- 'related_nodes': A list of InfrahubNode objects representing the related nodes
"""

# Ideally, include and relationships wouldn't be parameters of this method, they should only
# be used to build the request for the server, and this method would build node according to the response.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this comment and I think you can do this differently and then remove the comment

what about instead of passing include in here, you update the prefetch_relationships parameter to be a do_process_relationships Boolean flag and then update the two places where _process_nodes_and_relationships is called to set do_process_relationships based on prefetch_relationships and if include is not empty (or maybe if include actually includes the name of a relationship)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is about the fact that we should rely on these parameters only to fetch the data from the server and not while building the nodes client side.

The current logic is:

if prefetch_relationships / include has a rel_name:
   fetch_rel_from_server

if prefetch_relationships / include has a rel_name:
    process_relationships() # build peers nodes

The comment means that we should get rid of the second if condition, i.e., while building the nodes we should just build them out of the payload that we previously fetched. Does this make sense? I tried to do so but it led to some complicated bugs that should be treated out of this PR if we want to clean that.

I still made a minor change following up on your comment, so that include is now checked for whether it contains a relationship name.

Using a do_prefetch_relationships as you suggested might not fit well here because we have access to node relationships once it is built inside the _process_nodes_and_relationships method, and we need relationships names to check whether include contains some.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just suggesting making the below small changes. I didn't realize that this comment was related to existing functionality. This sounds like something that deserves its own issue in GH/Jira instead of a comment that will be ignored / forgotten

diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py
index f9a380b..cae970a 100644
--- a/infrahub_sdk/client.py
+++ b/infrahub_sdk/client.py
@@ -539,8 +539,7 @@ class InfrahubClient(BaseClient):
         response: dict[str, Any],
         schema_kind: str,
         branch: str,
-        prefetch_relationships: bool,
-        include: list[str] | None,
+        do_process_relationships: bool,
         timeout: int | None = None,
     ) -> ProcessRelationsNode:
         """Processes InfrahubNode and their Relationships from the GraphQL query response.
@@ -549,7 +548,7 @@ class InfrahubClient(BaseClient):
             response (dict[str, Any]): The response from the GraphQL query.
             schema_kind (str): The kind of schema being queried.
             branch (str): The branch name.
-            prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data.
+            do_process_relationships (bool): Flag to indicate whether to fetch relationship data.
             timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
 
         Returns:
@@ -568,7 +567,7 @@ class InfrahubClient(BaseClient):
             node = await InfrahubNode.from_graphql(client=self, branch=branch, data=item, timeout=timeout)
             nodes.append(node)
 
-            if prefetch_relationships or include is not None:
+            if do_process_relationships:
                 await node._process_relationships(
                     node_data=item,
                     branch=branch,
@@ -821,9 +820,8 @@ class InfrahubClient(BaseClient):
                 response=response,
                 schema_kind=schema.kind,
                 branch=branch,
-                prefetch_relationships=prefetch_relationships,
+                do_process_relationships=prefetch_relationships or include is not None,
                 timeout=timeout,
-                include=include,
             )
             return response, process_result

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understood that this was the change you suggested, and I think we cannot do that because
do_process_relationships=prefetch_relationships or include is not None is missing a check on relationship names, as you nicely pointed out, and we need the node being built to do that — cf. my previous comment

Using a do_prefetch_relationships as you suggested might not fit well here because we have access to node relationships once it is built inside the _process_nodes_and_relationships method, and we need relationships names to check whether include contains some.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created IHS-161


nodes: list[InfrahubNode] = []
related_nodes: list[InfrahubNode] = []

for item in response.get(schema_kind, {}).get("edges", []):
node = await InfrahubNode.from_graphql(client=self, branch=branch, data=item, timeout=timeout)
nodes.append(node)

if prefetch_relationships:
if prefetch_relationships or include is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should have an extra check here. I.e. if "include" just consists of a list of attribute names we don't really need to run the code for process_relationships?

await node._process_relationships(
node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout
node_data=item,
branch=branch,
related_nodes=related_nodes,
timeout=timeout,
)

return ProcessRelationsNode(nodes=nodes, related_nodes=related_nodes)
Expand Down Expand Up @@ -816,6 +823,7 @@ async def process_page(page_offset: int, page_number: int) -> tuple[dict, Proces
branch=branch,
prefetch_relationships=prefetch_relationships,
timeout=timeout,
include=include,
)
return response, process_result

Expand Down Expand Up @@ -1831,6 +1839,7 @@ def _process_nodes_and_relationships(
schema_kind: str,
branch: str,
prefetch_relationships: bool,
include: list[str] | None,
timeout: int | None = None,
) -> ProcessRelationsNodeSync:
"""Processes InfrahubNodeSync and their Relationships from the GraphQL query response.
Expand All @@ -1855,7 +1864,7 @@ def _process_nodes_and_relationships(
node = InfrahubNodeSync.from_graphql(client=self, branch=branch, data=item, timeout=timeout)
nodes.append(node)

if prefetch_relationships:
if prefetch_relationships or include is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above.

node._process_relationships(node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout)

return ProcessRelationsNodeSync(nodes=nodes, related_nodes=related_nodes)
Expand Down Expand Up @@ -1980,6 +1989,7 @@ def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelat
branch=branch,
prefetch_relationships=prefetch_relationships,
timeout=timeout,
include=include,
)
return response, process_result

Expand Down
14 changes: 9 additions & 5 deletions infrahub_sdk/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,12 +742,11 @@ async def generate_query_data_node(
continue

peer_data: dict[str, Any] = {}
if rel_schema and prefetch_relationships:
should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was something like this I was thinking about in the comments above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had actually addressed these changes following Aaron's comment but didn't push them

if rel_schema and should_fetch_relationship:
peer_schema = await self._client.schema.get(kind=rel_schema.peer, branch=self._branch)
peer_node = InfrahubNode(client=self._client, schema=peer_schema, branch=self._branch)
peer_data = await peer_node.generate_query_data_node(
include=include,
exclude=exclude,
property=property,
)

Comment on lines 744 to 752
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Async nested prefetch should propagate include/exclude (filtered) to match sync semantics and avoid over/under-fetch.

Currently only property is passed; this diverges from sync and can omit requested nested fields.

Apply:

-            peer_data = await peer_node.generate_query_data_node(
-                property=property,
-            )
+            # Propagate only relationship names relevant to the peer schema
+            peer_rel_names = getattr(peer_schema, "relationship_names", [r.name for r in peer_schema.relationships])
+            child_include = [n for n in (include or []) if n in peer_rel_names] or None
+            child_exclude = [n for n in (exclude or []) if n in peer_rel_names] or None
+            peer_data = await peer_node.generate_query_data_node(
+                include=child_include,
+                exclude=child_exclude,
+                property=property,
+            )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
peer_data: dict[str, Any] = {}
if rel_schema and prefetch_relationships:
should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include)
if rel_schema and should_fetch_relationship:
peer_schema = await self._client.schema.get(kind=rel_schema.peer, branch=self._branch)
peer_node = InfrahubNode(client=self._client, schema=peer_schema, branch=self._branch)
peer_data = await peer_node.generate_query_data_node(
include=include,
exclude=exclude,
property=property,
)
peer_data: dict[str, Any] = {}
should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include)
if rel_schema and should_fetch_relationship:
peer_schema = await self._client.schema.get(kind=rel_schema.peer, branch=self._branch)
peer_node = InfrahubNode(client=self._client, schema=peer_schema, branch=self._branch)
# Propagate only relationship names relevant to the peer schema
peer_rel_names = getattr(peer_schema, "relationship_names", [r.name for r in peer_schema.relationships])
child_include = [n for n in (include or []) if n in peer_rel_names] or None
child_exclude = [n for n in (exclude or []) if n in peer_rel_names] or None
peer_data = await peer_node.generate_query_data_node(
include=child_include,
exclude=child_exclude,
property=property,
)

Expand Down Expand Up @@ -886,7 +885,11 @@ async def update(
await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)

async def _process_relationships(
self, node_data: dict[str, Any], branch: str, related_nodes: list[InfrahubNode], timeout: int | None = None
self,
node_data: dict[str, Any],
branch: str,
related_nodes: list[InfrahubNode],
timeout: int | None = None,
) -> None:
"""Processes the Relationships of a InfrahubNode and add Related Nodes to a list.

Expand Down Expand Up @@ -1363,7 +1366,8 @@ def generate_query_data_node(
continue

peer_data: dict[str, Any] = {}
if rel_schema and prefetch_relationships:
should_fetch_relationship = prefetch_relationships or (include is not None and rel_name in include)
if rel_schema and should_fetch_relationship:
peer_schema = self._client.schema.get(kind=rel_schema.peer, branch=self._branch)
peer_node = InfrahubNodeSync(client=self._client, schema=peer_schema, branch=self._branch)
peer_data = peer_node.generate_query_data_node(include=include, exclude=exclude, property=property)
Expand Down
34 changes: 34 additions & 0 deletions tests/integration/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,40 @@ async def test_node_create_with_relationships_using_related_node(
assert node_after.owner.peer.id == person_joe.id
assert node_after.owner.peer.typename == "TestingPerson"

async def test_node_filters_include(
self,
default_branch: str,
client: InfrahubClient,
initial_schema: None,
manufacturer_mercedes,
person_joe,
tag_red,
) -> None:
car = await client.create(
kind=TESTING_CAR,
name="Tiguan2",
color="Black",
manufacturer=manufacturer_mercedes,
owner=person_joe,
tags=[tag_red],
)
await car.save(allow_upsert=True)
assert car.id is not None

# Clear store, as when we call `owner.peer`, we actually rely on the peer having been stored in the store.
client.store._branches = {}
node_after = await client.get(kind=TESTING_CAR, id=car.id)

with pytest.raises(NodeNotFoundError, match=f"Unable to find the node '{person_joe.id}' in the store"):
_ = node_after.owner.peer

assert len(node_after.tags.peers) == 0

# Test both one and many relationships
node_after = await client.get(kind=TESTING_CAR, id=car.id, include=["tags", "owner"])
assert [tag.id for tag in node_after.tags.peers] == [tag_red.id]
assert node_after.owner.peer.id == person_joe.id, f"{person_joe.id=}"

async def test_node_update_with_original_data(
self,
default_branch: str,
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ async def client() -> InfrahubClient:
return InfrahubClient(config=Config(address="http://mock", insert_tracker=True, pagination_size=3))


@pytest.fixture
async def client_sync() -> InfrahubClientSync:
return InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True, pagination_size=3))


@pytest.fixture
async def clients() -> BothClients:
both = BothClients(
Expand Down Expand Up @@ -2641,3 +2646,20 @@ async def mock_query_tasks_05(httpx_mock: HTTPXMock) -> HTTPXMock:
is_reusable=True,
)
return httpx_mock


async def set_builtin_tag_schema_cache(client) -> None:
# Set tag schema in cache to avoid needing to request the server.
builtin_tag_schema = {
"version": "1.0",
"nodes": [
{
"name": "Tag",
"namespace": "Builtin",
"default_filter": "name__value",
"display_label": "name__value",
"branch": "aware",
}
],
}
client.schema.set_cache(builtin_tag_schema)
23 changes: 19 additions & 4 deletions tests/unit/sdk/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from infrahub_sdk.node.constants import SAFE_VALUE
from infrahub_sdk.node.related_node import RelatedNode, RelatedNodeSync
from infrahub_sdk.schema import GenericSchema, NodeSchemaAPI
from tests.unit.sdk.conftest import set_builtin_tag_schema_cache

if TYPE_CHECKING:
from infrahub_sdk.client import InfrahubClient, InfrahubClientSync
Expand Down Expand Up @@ -1055,12 +1056,19 @@ async def test_query_data_generic_fragment(clients, mock_schema_query_02, client


@pytest.mark.parametrize("client_type", client_types)
async def test_query_data_include_property(client, location_schema: NodeSchemaAPI, client_type) -> None:
async def test_query_data_include_property(
client,
client_sync,
location_schema: NodeSchemaAPI,
client_type,
) -> None:
if client_type == "standard":
await set_builtin_tag_schema_cache(client)
node = InfrahubNode(client=client, schema=location_schema)
data = await node.generate_query_data(include=["tags"], property=True)
else:
node = InfrahubNodeSync(client=client, schema=location_schema)
await set_builtin_tag_schema_cache(client_sync)
node = InfrahubNodeSync(client=client_sync, schema=location_schema)
data = node.generate_query_data(include=["tags"], property=True)

assert data == {
Expand Down Expand Up @@ -1178,12 +1186,19 @@ async def test_query_data_include_property(client, location_schema: NodeSchemaAP


@pytest.mark.parametrize("client_type", client_types)
async def test_query_data_include(client, location_schema: NodeSchemaAPI, client_type) -> None:
async def test_query_data_include(
client,
client_sync,
location_schema: NodeSchemaAPI,
client_type,
) -> None:
if client_type == "standard":
await set_builtin_tag_schema_cache(client)
node = InfrahubNode(client=client, schema=location_schema)
data = await node.generate_query_data(include=["tags"])
else:
node = InfrahubNodeSync(client=client, schema=location_schema)
await set_builtin_tag_schema_cache(client_sync)
node = InfrahubNodeSync(client=client_sync, schema=location_schema)
data = node.generate_query_data(include=["tags"])

assert data == {
Expand Down