Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/6882.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix value lookup using a flat notation like `foo__bar__value` with relationships of cardinality one
90 changes: 81 additions & 9 deletions infrahub_sdk/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..exceptions import FeatureNotSupportedError, NodeNotFoundError, ResourceNotDefinedError, SchemaNotFoundError
from ..graphql import Mutation, Query
from ..schema import GenericSchemaAPI, RelationshipCardinality, RelationshipKind
from ..utils import compare_lists, generate_short_id, get_flat_value
from ..utils import compare_lists, generate_short_id
from .attribute import Attribute
from .constants import (
ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE,
Expand Down Expand Up @@ -418,14 +418,6 @@ def generate_query_data_init(

return data

def extract(self, params: dict[str, str]) -> dict[str, Any]:
"""Extract some datapoints defined in a flat notation."""
result: dict[str, Any] = {}
for key, value in params.items():
result[key] = get_flat_value(self, key=value)

return result

def __hash__(self) -> int:
return hash(self.id)

Expand Down Expand Up @@ -1036,6 +1028,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode:

raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}")

async def get_flat_value(self, key: str, separator: str = "__") -> Any:
"""Query recursively an value defined in a flat notation (string), on a hierarchy of objects

Examples:
name__value
module.object.value
"""
if separator not in key:
return getattr(self, key)

first, remaining = key.split(separator, maxsplit=1)

if first in self._schema.attribute_names:
attr = getattr(self, first)
for part in remaining.split(separator):
attr = getattr(attr, part)
return attr

try:
rel = self._schema.get_relationship(name=first)
except ValueError as exc:
raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc

if rel.cardinality != RelationshipCardinality.ONE:
raise ValueError(
f"Unable to lookup flat value for relationship of cardinality {RelationshipCardinality.MANY}"
)

related_node: RelatedNode = getattr(self, first)
await related_node.fetch()
return await related_node.peer.get_flat_value(key=remaining, separator=separator)

async def extract(self, params: dict[str, str]) -> dict[str, Any]:
"""Extract some datapoints defined in a flat notation."""
result: dict[str, Any] = {}
for key, value in params.items():
result[key] = await self.get_flat_value(key=value)

return result

def __dir__(self) -> Iterable[str]:
base = list(super().__dir__())
return sorted(
Expand Down Expand Up @@ -1622,6 +1654,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode | RelatedNodeSync:

raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}")

def get_flat_value(self, key: str, separator: str = "__") -> Any:
"""Query recursively an value defined in a flat notation (string), on a hierarchy of objects

Examples:
name__value
module.object.value
"""
if separator not in key:
return getattr(self, key)

first, remaining = key.split(separator, maxsplit=1)

if first in self._schema.attribute_names:
attr = getattr(self, first)
for part in remaining.split(separator):
attr = getattr(attr, part)
return attr

try:
rel = self._schema.get_relationship(name=first)
except ValueError as exc:
raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc

if rel.cardinality != RelationshipCardinality.ONE:
raise ValueError(
f"Unable to lookup flat value for relationship of cardinality {RelationshipCardinality.MANY}"
)

related_node: RelatedNodeSync = getattr(self, first)
related_node.fetch()
return related_node.peer.get_flat_value(key=remaining, separator=separator)

def extract(self, params: dict[str, str]) -> dict[str, Any]:
"""Extract some datapoints defined in a flat notation."""
result: dict[str, Any] = {}
for key, value in params.items():
result[key] = self.get_flat_value(key=value)

return result

def __dir__(self) -> Iterable[str]:
base = list(super().__dir__())
return sorted(
Expand Down
2 changes: 0 additions & 2 deletions infrahub_sdk/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ def is_resource_pool(self) -> bool: ...

def get_raw_graphql_data(self) -> dict | None: ...

def extract(self, params: dict[str, str]) -> dict[str, Any]: ...


@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
Expand Down
17 changes: 0 additions & 17 deletions infrahub_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,23 +190,6 @@ def str_to_bool(value: str) -> bool:
raise ValueError(f"{value} can not be converted into a boolean") from exc


def get_flat_value(obj: Any, key: str, separator: str = "__") -> Any:
"""Query recursively an value defined in a flat notation (string), on a hierarchy of objects

Examples:
name__value
module.object.value
"""
if separator not in key:
return getattr(obj, key)

first_part, remaining_part = key.split(separator, maxsplit=1)
sub_obj = getattr(obj, first_part)
if not sub_obj:
return None
return get_flat_value(obj=sub_obj, key=remaining_part, separator=separator)


def generate_request_filename(request: httpx.Request) -> str:
"""Return a filename for a request sent to the Infrahub API

Expand Down
55 changes: 42 additions & 13 deletions tests/unit/sdk/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1947,23 +1947,52 @@ async def test_node_IPNetwork_deserialization(client, ipnetwork_schema, client_t


@pytest.mark.parametrize("client_type", client_types)
async def test_node_extract(client, location_schema, location_data01, client_type):
async def test_get_flat_value(
httpx_mock: HTTPXMock, mock_schema_query_01, clients, location_schema, location_data01, client_type
):
httpx_mock.add_response(
method="POST",
json={"data": {"BuiltinTag": {"count": 1, "edges": [location_data01["node"]["primary_tag"]]}}},
match_headers={"X-Infrahub-Tracker": "query-builtintag-page1"},
is_reusable=True,
)

if client_type == "standard":
node = InfrahubNode(client=client, schema=location_schema, data=location_data01)
tag = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01)
assert await tag.get_flat_value(key="name__value") == "DFW"
assert await tag.get_flat_value(key="primary_tag__display_label") == "red"
assert await tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red"

with pytest.raises(ValueError, match="Unable to lookup flat value for relationship of cardinality"):
assert await tag.get_flat_value(key="tags__display_label") == "red"
else:
node = InfrahubNodeSync(client=client, schema=location_schema, data=location_data01)
tag = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01)
assert tag.get_flat_value(key="name__value") == "DFW"
assert tag.get_flat_value(key="primary_tag__display_label") == "red"
assert tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red"

params = {
"identifier": "id",
"name": "name__value",
"description": "description__value",
}
with pytest.raises(ValueError, match="Unable to lookup flat value for relationship of cardinality"):
assert tag.get_flat_value(key="tags__display_label") == "red"

assert node.extract(params=params) == {
"description": None,
"identifier": "llllllll-llll-llll-llll-llllllllllll",
"name": "DFW",
}

@pytest.mark.parametrize("client_type", client_types)
async def test_node_extract(clients, location_schema, location_data01, client_type):
params = {"identifier": "id", "name": "name__value", "description": "description__value"}
if client_type == "standard":
node = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01)
assert await node.extract(params=params) == {
"description": None,
"identifier": "llllllll-llll-llll-llll-llllllllllll",
"name": "DFW",
}

else:
node = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01)
assert node.extract(params=params) == {
"description": None,
"identifier": "llllllll-llll-llll-llll-llllllllllll",
"name": "DFW",
}


@pytest.mark.parametrize("client_type", client_types)
Expand Down
9 changes: 0 additions & 9 deletions tests/unit/sdk/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from graphql import parse
from whenever import Instant

from infrahub_sdk.node import InfrahubNode
from infrahub_sdk.utils import (
base16decode,
base16encode,
Expand All @@ -19,7 +18,6 @@
duplicates,
extract_fields,
generate_short_id,
get_flat_value,
is_valid_url,
is_valid_uuid,
str_to_bool,
Expand Down Expand Up @@ -143,13 +141,6 @@ def test_base16():
assert base16decode(base16encode(1412823931503067241)) == 1412823931503067241


def test_get_flat_value(client, tag_schema, tag_green_data):
tag = InfrahubNode(client=client, schema=tag_schema, data=tag_green_data)
assert get_flat_value(obj=tag, key="name__value") == "green"
assert get_flat_value(obj=tag, key="name__source__display_label") == "CRM"
assert get_flat_value(obj=tag, key="name.source.display_label", separator=".") == "CRM"


def test_dict_hash():
assert dict_hash({"a": 1, "b": 2}) == "608de49a4600dbb5b173492759792e4a"
assert dict_hash({"b": 2, "a": 1}) == "608de49a4600dbb5b173492759792e4a"
Expand Down