Skip to content

Commit 5f5f766

Browse files
committed
Rework get_flat_value to fix node relationships
1 parent 8d89883 commit 5f5f766

File tree

4 files changed

+123
-48
lines changed

4 files changed

+123
-48
lines changed

infrahub_sdk/node/node.py

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..exceptions import FeatureNotSupportedError, NodeNotFoundError, ResourceNotDefinedError, SchemaNotFoundError
99
from ..graphql import Mutation, Query
1010
from ..schema import GenericSchemaAPI, RelationshipCardinality, RelationshipKind
11-
from ..utils import compare_lists, generate_short_id, get_flat_value
11+
from ..utils import compare_lists, generate_short_id
1212
from .attribute import Attribute
1313
from .constants import (
1414
ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE,
@@ -418,14 +418,6 @@ def generate_query_data_init(
418418

419419
return data
420420

421-
def extract(self, params: dict[str, str]) -> dict[str, Any]:
422-
"""Extract some datapoints defined in a flat notation."""
423-
result: dict[str, Any] = {}
424-
for key, value in params.items():
425-
result[key] = get_flat_value(self, key=value)
426-
427-
return result
428-
429421
def __hash__(self) -> int:
430422
return hash(self.id)
431423

@@ -1036,6 +1028,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode:
10361028

10371029
raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}")
10381030

1031+
async def get_flat_value(self, key: str, separator: str = "__") -> Any:
1032+
"""Query recursively an value defined in a flat notation (string), on a hierarchy of objects
1033+
1034+
Examples:
1035+
name__value
1036+
module.object.value
1037+
"""
1038+
if separator not in key:
1039+
return getattr(self, key)
1040+
1041+
first, remaining = key.split(separator, maxsplit=1)
1042+
1043+
if first in self._schema.attribute_names:
1044+
attr = getattr(self, first)
1045+
for part in remaining.split(separator):
1046+
attr = getattr(attr, part)
1047+
return attr
1048+
1049+
try:
1050+
rel = self._schema.get_relationship(name=first)
1051+
except ValueError as exc:
1052+
raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc
1053+
1054+
if rel.cardinality != RelationshipCardinality.ONE:
1055+
raise ValueError(
1056+
f"Unable to lookup flat value for relationship of cardinality {RelationshipCardinality.MANY}"
1057+
)
1058+
1059+
related_node: RelatedNode = getattr(self, first)
1060+
await related_node.fetch()
1061+
return await related_node.peer.get_flat_value(key=remaining, separator=separator)
1062+
1063+
async def extract(self, params: dict[str, str]) -> dict[str, Any]:
1064+
"""Extract some datapoints defined in a flat notation."""
1065+
result: dict[str, Any] = {}
1066+
for key, value in params.items():
1067+
result[key] = await self.get_flat_value(key=value)
1068+
1069+
return result
1070+
10391071
def __dir__(self) -> Iterable[str]:
10401072
base = list(super().__dir__())
10411073
return sorted(
@@ -1622,6 +1654,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode | RelatedNodeSync:
16221654

16231655
raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}")
16241656

1657+
def get_flat_value(self, key: str, separator: str = "__") -> Any:
1658+
"""Query recursively an value defined in a flat notation (string), on a hierarchy of objects
1659+
1660+
Examples:
1661+
name__value
1662+
module.object.value
1663+
"""
1664+
if separator not in key:
1665+
return getattr(self, key)
1666+
1667+
first, remaining = key.split(separator, maxsplit=1)
1668+
1669+
if first in self._schema.attribute_names:
1670+
attr = getattr(self, first)
1671+
for part in remaining.split(separator):
1672+
attr = getattr(attr, part)
1673+
return attr
1674+
1675+
try:
1676+
rel = self._schema.get_relationship(name=first)
1677+
except ValueError as exc:
1678+
raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc
1679+
1680+
if rel.cardinality != RelationshipCardinality.ONE:
1681+
raise ValueError(
1682+
f"Unable to lookup flat value for relationship of cardinality {RelationshipCardinality.MANY}"
1683+
)
1684+
1685+
related_node: RelatedNodeSync = getattr(self, first)
1686+
related_node.fetch()
1687+
return related_node.peer.get_flat_value(key=remaining, separator=separator)
1688+
1689+
def extract(self, params: dict[str, str]) -> dict[str, Any]:
1690+
"""Extract some datapoints defined in a flat notation."""
1691+
result: dict[str, Any] = {}
1692+
for key, value in params.items():
1693+
result[key] = self.get_flat_value(key=value)
1694+
1695+
return result
1696+
16251697
def __dir__(self) -> Iterable[str]:
16261698
base = list(super().__dir__())
16271699
return sorted(

infrahub_sdk/utils.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -190,23 +190,6 @@ def str_to_bool(value: str) -> bool:
190190
raise ValueError(f"{value} can not be converted into a boolean") from exc
191191

192192

193-
def get_flat_value(obj: Any, key: str, separator: str = "__") -> Any:
194-
"""Query recursively an value defined in a flat notation (string), on a hierarchy of objects
195-
196-
Examples:
197-
name__value
198-
module.object.value
199-
"""
200-
if separator not in key:
201-
return getattr(obj, key)
202-
203-
first_part, remaining_part = key.split(separator, maxsplit=1)
204-
sub_obj = getattr(obj, first_part)
205-
if not sub_obj:
206-
return None
207-
return get_flat_value(obj=sub_obj, key=remaining_part, separator=separator)
208-
209-
210193
def generate_request_filename(request: httpx.Request) -> str:
211194
"""Return a filename for a request sent to the Infrahub API
212195

tests/unit/sdk/test_node.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,23 +1947,52 @@ async def test_node_IPNetwork_deserialization(client, ipnetwork_schema, client_t
19471947

19481948

19491949
@pytest.mark.parametrize("client_type", client_types)
1950-
async def test_node_extract(client, location_schema, location_data01, client_type):
1950+
async def test_get_flat_value(
1951+
httpx_mock: HTTPXMock, mock_schema_query_01, clients, location_schema, location_data01, client_type
1952+
):
1953+
httpx_mock.add_response(
1954+
method="POST",
1955+
json={"data": {"BuiltinTag": {"count": 1, "edges": [location_data01["node"]["primary_tag"]]}}},
1956+
match_headers={"X-Infrahub-Tracker": "query-builtintag-page1"},
1957+
is_reusable=True,
1958+
)
1959+
19511960
if client_type == "standard":
1952-
node = InfrahubNode(client=client, schema=location_schema, data=location_data01)
1961+
tag = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01)
1962+
assert await tag.get_flat_value(key="name__value") == "DFW"
1963+
assert await tag.get_flat_value(key="primary_tag__display_label") == "red"
1964+
assert await tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red"
1965+
1966+
with pytest.raises(ValueError, match="Unable to lookup flat value for relationship of cardinality"):
1967+
assert await tag.get_flat_value(key="tags__display_label") == "red"
19531968
else:
1954-
node = InfrahubNodeSync(client=client, schema=location_schema, data=location_data01)
1969+
tag = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01)
1970+
assert tag.get_flat_value(key="name__value") == "DFW"
1971+
assert tag.get_flat_value(key="primary_tag__display_label") == "red"
1972+
assert tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red"
19551973

1956-
params = {
1957-
"identifier": "id",
1958-
"name": "name__value",
1959-
"description": "description__value",
1960-
}
1974+
with pytest.raises(ValueError, match="Unable to lookup flat value for relationship of cardinality"):
1975+
assert tag.get_flat_value(key="tags__display_label") == "red"
19611976

1962-
assert node.extract(params=params) == {
1963-
"description": None,
1964-
"identifier": "llllllll-llll-llll-llll-llllllllllll",
1965-
"name": "DFW",
1966-
}
1977+
1978+
@pytest.mark.parametrize("client_type", client_types)
1979+
async def test_node_extract(clients, location_schema, location_data01, client_type):
1980+
params = {"identifier": "id", "name": "name__value", "description": "description__value"}
1981+
if client_type == "standard":
1982+
node = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01)
1983+
assert await node.extract(params=params) == {
1984+
"description": None,
1985+
"identifier": "llllllll-llll-llll-llll-llllllllllll",
1986+
"name": "DFW",
1987+
}
1988+
1989+
else:
1990+
node = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01)
1991+
assert node.extract(params=params) == {
1992+
"description": None,
1993+
"identifier": "llllllll-llll-llll-llll-llllllllllll",
1994+
"name": "DFW",
1995+
}
19671996

19681997

19691998
@pytest.mark.parametrize("client_type", client_types)

tests/unit/sdk/test_utils.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from graphql import parse
77
from whenever import Instant
88

9-
from infrahub_sdk.node import InfrahubNode
109
from infrahub_sdk.utils import (
1110
base16decode,
1211
base16encode,
@@ -19,7 +18,6 @@
1918
duplicates,
2019
extract_fields,
2120
generate_short_id,
22-
get_flat_value,
2321
is_valid_url,
2422
is_valid_uuid,
2523
str_to_bool,
@@ -143,13 +141,6 @@ def test_base16():
143141
assert base16decode(base16encode(1412823931503067241)) == 1412823931503067241
144142

145143

146-
def test_get_flat_value(client, tag_schema, tag_green_data):
147-
tag = InfrahubNode(client=client, schema=tag_schema, data=tag_green_data)
148-
assert get_flat_value(obj=tag, key="name__value") == "green"
149-
assert get_flat_value(obj=tag, key="name__source__display_label") == "CRM"
150-
assert get_flat_value(obj=tag, key="name.source.display_label", separator=".") == "CRM"
151-
152-
153144
def test_dict_hash():
154145
assert dict_hash({"a": 1, "b": 2}) == "608de49a4600dbb5b173492759792e4a"
155146
assert dict_hash({"b": 2, "a": 1}) == "608de49a4600dbb5b173492759792e4a"

0 commit comments

Comments
 (0)