Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/6882.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix value lookup using a flat notation like `foo__bar__value` with relationships of cardinality one
90 changes: 81 additions & 9 deletions infrahub_sdk/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..exceptions import FeatureNotSupportedError, NodeNotFoundError, ResourceNotDefinedError, SchemaNotFoundError
from ..graphql import Mutation, Query
from ..schema import GenericSchemaAPI, RelationshipCardinality, RelationshipKind
from ..utils import compare_lists, generate_short_id, get_flat_value
from ..utils import compare_lists, generate_short_id
from .attribute import Attribute
from .constants import (
ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE,
Expand Down Expand Up @@ -418,14 +418,6 @@ def generate_query_data_init(

return data

def extract(self, params: dict[str, str]) -> dict[str, Any]:
"""Extract some datapoints defined in a flat notation."""
result: dict[str, Any] = {}
for key, value in params.items():
result[key] = get_flat_value(self, key=value)

return result

def __hash__(self) -> int:
return hash(self.id)

Expand Down Expand Up @@ -1036,6 +1028,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode:

raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}")

async def get_flat_value(self, key: str, separator: str = "__") -> Any:
"""Query recursively a value defined in a flat notation (string), on a hierarchy of objects

Examples:
name__value
module.object.value
"""
if separator not in key:
return getattr(self, key)

first, remaining = key.split(separator, maxsplit=1)

if first in self._schema.attribute_names:
attr = getattr(self, first)
for part in remaining.split(separator):
attr = getattr(attr, part)
return attr

try:
rel = self._schema.get_relationship(name=first)
except ValueError as exc:
raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc

if rel.cardinality != RelationshipCardinality.ONE:
raise ValueError(
f"Can only look up flat value for relationships of cardinality {RelationshipCardinality.ONE.value}"
)

related_node: RelatedNode = getattr(self, first)
await related_node.fetch()
return await related_node.peer.get_flat_value(key=remaining, separator=separator)

async def extract(self, params: dict[str, str]) -> dict[str, Any]:
"""Extract some datapoints defined in a flat notation."""
result: dict[str, Any] = {}
for key, value in params.items():
result[key] = await self.get_flat_value(key=value)

return result

def __dir__(self) -> Iterable[str]:
base = list(super().__dir__())
return sorted(
Expand Down Expand Up @@ -1622,6 +1654,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode | RelatedNodeSync:

raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}")

def get_flat_value(self, key: str, separator: str = "__") -> Any:
"""Query recursively a value defined in a flat notation (string), on a hierarchy of objects

Examples:
name__value
module.object.value
"""
if separator not in key:
return getattr(self, key)

first, remaining = key.split(separator, maxsplit=1)

if first in self._schema.attribute_names:
attr = getattr(self, first)
for part in remaining.split(separator):
attr = getattr(attr, part)
return attr

try:
rel = self._schema.get_relationship(name=first)
except ValueError as exc:
raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc

if rel.cardinality != RelationshipCardinality.ONE:
raise ValueError(
f"Can only look up flat value for relationships of cardinality {RelationshipCardinality.ONE.value}"
)

related_node: RelatedNodeSync = getattr(self, first)
related_node.fetch()
return related_node.peer.get_flat_value(key=remaining, separator=separator)

def extract(self, params: dict[str, str]) -> dict[str, Any]:
"""Extract some datapoints defined in a flat notation."""
result: dict[str, Any] = {}
for key, value in params.items():
result[key] = self.get_flat_value(key=value)

return result

def __dir__(self) -> Iterable[str]:
base = list(super().__dir__())
return sorted(
Expand Down
2 changes: 0 additions & 2 deletions infrahub_sdk/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ def is_resource_pool(self) -> bool: ...

def get_raw_graphql_data(self) -> dict | None: ...

def extract(self, params: dict[str, str]) -> dict[str, Any]: ...


@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
Expand Down
17 changes: 0 additions & 17 deletions infrahub_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,23 +190,6 @@ def str_to_bool(value: str) -> bool:
raise ValueError(f"{value} can not be converted into a boolean") from exc


def get_flat_value(obj: Any, key: str, separator: str = "__") -> Any:
"""Query recursively an value defined in a flat notation (string), on a hierarchy of objects

Examples:
name__value
module.object.value
"""
if separator not in key:
return getattr(obj, key)

first_part, remaining_part = key.split(separator, maxsplit=1)
sub_obj = getattr(obj, first_part)
if not sub_obj:
return None
return get_flat_value(obj=sub_obj, key=remaining_part, separator=separator)


def generate_request_filename(request: httpx.Request) -> str:
"""Return a filename for a request sent to the Infrahub API

Expand Down
55 changes: 42 additions & 13 deletions tests/unit/sdk/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1947,23 +1947,52 @@ async def test_node_IPNetwork_deserialization(client, ipnetwork_schema, client_t


@pytest.mark.parametrize("client_type", client_types)
async def test_node_extract(client, location_schema, location_data01, client_type):
async def test_get_flat_value(
httpx_mock: HTTPXMock, mock_schema_query_01, clients, location_schema, location_data01, client_type
):
httpx_mock.add_response(
method="POST",
json={"data": {"BuiltinTag": {"count": 1, "edges": [location_data01["node"]["primary_tag"]]}}},
match_headers={"X-Infrahub-Tracker": "query-builtintag-page1"},
is_reusable=True,
)

if client_type == "standard":
node = InfrahubNode(client=client, schema=location_schema, data=location_data01)
tag = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01)
assert await tag.get_flat_value(key="name__value") == "DFW"
assert await tag.get_flat_value(key="primary_tag__display_label") == "red"
assert await tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red"

with pytest.raises(ValueError, match="Can only look up flat value for relationships of cardinality one"):
assert await tag.get_flat_value(key="tags__display_label") == "red"
else:
node = InfrahubNodeSync(client=client, schema=location_schema, data=location_data01)
tag = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01)
assert tag.get_flat_value(key="name__value") == "DFW"
assert tag.get_flat_value(key="primary_tag__display_label") == "red"
assert tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red"

params = {
"identifier": "id",
"name": "name__value",
"description": "description__value",
}
with pytest.raises(ValueError, match="Can only look up flat value for relationships of cardinality one"):
assert tag.get_flat_value(key="tags__display_label") == "red"

assert node.extract(params=params) == {
"description": None,
"identifier": "llllllll-llll-llll-llll-llllllllllll",
"name": "DFW",
}

@pytest.mark.parametrize("client_type", client_types)
async def test_node_extract(clients, location_schema, location_data01, client_type):
params = {"identifier": "id", "name": "name__value", "description": "description__value"}
if client_type == "standard":
node = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01)
assert await node.extract(params=params) == {
"description": None,
"identifier": "llllllll-llll-llll-llll-llllllllllll",
"name": "DFW",
}

else:
node = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01)
assert node.extract(params=params) == {
"description": None,
"identifier": "llllllll-llll-llll-llll-llllllllllll",
"name": "DFW",
}


@pytest.mark.parametrize("client_type", client_types)
Expand Down
9 changes: 0 additions & 9 deletions tests/unit/sdk/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from graphql import parse
from whenever import Instant

from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.utils import (
base16decode,
base16encode,
Expand All @@ -19,7 +18,6 @@
duplicates,
extract_fields,
generate_short_id,
get_flat_value,
is_valid_url,
is_valid_uuid,
str_to_bool,
Expand Down Expand Up @@ -143,13 +141,6 @@ def test_base16():
assert base16decode(base16encode(1412823931503067241)) == 1412823931503067241


def test_get_flat_value(client, tag_schema, tag_green_data):
tag = InfrahubNode(client=client, schema=tag_schema, data=tag_green_data)
assert get_flat_value(obj=tag, key="name__value") == "green"
assert get_flat_value(obj=tag, key="name__source__display_label") == "CRM"
assert get_flat_value(obj=tag, key="name.source.display_label", separator=".") == "CRM"


def test_dict_hash():
assert dict_hash({"a": 1, "b": 2}) == "608de49a4600dbb5b173492759792e4a"
assert dict_hash({"b": 2, "a": 1}) == "608de49a4600dbb5b173492759792e4a"
Expand Down