Skip to content

Commit a789f2b

Browse files
committed
Ensure kind matches when retrieving objects from the store
1 parent 17d4a45 commit a789f2b

File tree

5 files changed

+77
-13
lines changed

5 files changed

+77
-13
lines changed

infrahub_sdk/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ def __str__(self) -> str:
8888
"""
8989

9090

91+
class NodeInvalidError(NodeNotFoundError):
92+
pass
93+
94+
9195
class ResourceNotDefinedError(Error):
9296
"""Raised when trying to access a resource that hasn't been defined."""
9397

infrahub_sdk/node.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,11 @@ def __repr__(self) -> str:
814814
def get_kind(self) -> str:
815815
return self._schema.kind
816816

817+
def get_all_kinds(self) -> list[str]:
818+
if hasattr(self._schema, "inherit_from"):
819+
return [self._schema.kind] + self._schema.inherit_from
820+
return [self._schema.kind]
821+
817822
def is_ip_prefix(self) -> bool:
818823
builtin_ipprefix_kind = "BuiltinIPPrefix"
819824
return self.get_kind() == builtin_ipprefix_kind or builtin_ipprefix_kind in self._schema.inherit_from # type: ignore[union-attr]

infrahub_sdk/protocols_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | N
160160

161161
def get_kind(self) -> str: ...
162162

163+
def get_all_kinds(self) -> list[str]: ...
164+
163165
def get_branch(self) -> str: ...
164166

165167
def is_ip_prefix(self) -> bool: ...

infrahub_sdk/store.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import defaultdict
55
from typing import TYPE_CHECKING, Literal, overload
66

7-
from .exceptions import NodeNotFoundError
7+
from .exceptions import NodeInvalidError, NodeNotFoundError
88
from .node import parse_human_friendly_id
99

1010
if TYPE_CHECKING:
@@ -53,62 +53,110 @@ def get( # type: ignore[no-untyped-def]
5353
kind: type[SchemaType | SchemaTypeSync] | str | None = None,
5454
raise_when_missing: bool = True,
5555
) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync | None:
56+
found_invalid = False
57+
58+
kind_name = get_schema_name(schema=kind)
59+
5660
try:
57-
return self._get_by_internal_id(key)
61+
return self._get_by_internal_id(key, kind=kind_name)
62+
except NodeInvalidError:
63+
found_invalid = True
5864
except NodeNotFoundError:
5965
pass
6066

6167
try:
62-
return self._get_by_id(key)
68+
return self._get_by_id(key, kind=kind_name)
69+
except NodeInvalidError:
70+
found_invalid = True
6371
except NodeNotFoundError:
6472
pass
6573

6674
try:
67-
return self._get_by_key(key)
75+
return self._get_by_key(key, kind=kind_name)
76+
except NodeInvalidError:
77+
found_invalid = True
6878
except NodeNotFoundError:
6979
pass
7080

71-
kind_name = get_schema_name(schema=kind)
72-
7381
try:
7482
return self._get_by_hfid(key, kind=kind_name)
7583
except NodeNotFoundError:
7684
pass
7785

7886
if not raise_when_missing:
7987
return None
88+
89+
if kind and found_invalid:
90+
raise NodeInvalidError(
91+
node_type="n/a",
92+
identifier={"key": [key]},
93+
message=f"Found a node of a differentkind instead of {kind} for key {key!r} in the store ({self.branch_name})",
94+
)
95+
8096
raise NodeNotFoundError(
8197
node_type="n/a",
8298
identifier={"key": [key]},
8399
message=f"Unable to find the node {key!r} in the store ({self.branch_name})",
84100
)
85101

86-
def _get_by_internal_id(self, internal_id: str) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
102+
def _get_by_internal_id(
103+
self, internal_id: str, kind: str | None = None
104+
) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
87105
if internal_id not in self._objs:
88106
raise NodeNotFoundError(
89107
node_type="n/a",
90108
identifier={"internal_id": [internal_id]},
91109
message=f"Unable to find the node {internal_id!r} in the store ({self.branch_name})",
92110
)
93-
return self._objs[internal_id]
94111

95-
def _get_by_key(self, key: str) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
112+
node = self._objs[internal_id]
113+
if kind and kind not in node.get_all_kinds():
114+
raise NodeInvalidError(
115+
node_type=kind,
116+
identifier={"internal_id": [internal_id]},
117+
message=f"Found a node of kind {node.get_kind()} instead of {kind} for internal_id {internal_id!r} in the store ({self.branch_name})",
118+
)
119+
120+
return node
121+
122+
def _get_by_key(
123+
self, key: str, kind: str | None = None
124+
) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
96125
if key not in self._keys:
97126
raise NodeNotFoundError(
98127
node_type="n/a",
99128
identifier={"key": [key]},
100129
message=f"Unable to find the node {key!r} in the store ({self.branch_name})",
101130
)
102-
return self._get_by_internal_id(self._keys[key])
103131

104-
def _get_by_id(self, id: str) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
132+
node = self._get_by_internal_id(self._keys[key])
133+
134+
if kind and node.get_kind() != kind:
135+
raise NodeInvalidError(
136+
node_type=kind,
137+
identifier={"key": [key]},
138+
message=f"Found a node of kind {node.get_kind()} instead of {kind} for key {key!r} in the store ({self.branch_name})",
139+
)
140+
141+
return node
142+
143+
def _get_by_id(self, id: str, kind: str | None = None) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
105144
if id not in self._uuids:
106145
raise NodeNotFoundError(
107146
node_type="n/a",
108147
identifier={"id": [id]},
109148
message=f"Unable to find the node {id!r} in the store ({self.branch_name})",
110149
)
111-
return self._get_by_internal_id(self._uuids[id])
150+
151+
node = self._get_by_internal_id(self._uuids[id])
152+
if kind and kind not in node.get_all_kinds():
153+
raise NodeInvalidError(
154+
node_type=kind,
155+
identifier={"id": [id]},
156+
message=f"Found a node of kind {node.get_kind()} instead of {kind} for id {id!r} in the store ({self.branch_name})",
157+
)
158+
159+
return node
112160

113161
def _get_by_hfid(
114162
self, hfid: str, kind: str | None = None

tests/unit/sdk/test_store.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from infrahub_sdk.exceptions import NodeNotFoundError
3+
from infrahub_sdk.exceptions import NodeInvalidError, NodeNotFoundError
44
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
55
from infrahub_sdk.store import NodeStore, NodeStoreSync
66

@@ -91,6 +91,8 @@ def test_node_store_get(client_type, clients, location_schema):
9191
assert store.get(kind="BuiltinLocation", key="mykey").id == node.id
9292
assert store.get(key="mykey").id == node.id
9393

94+
assert store.get(kind="BuiltinTest", key="mykey", raise_when_missing=False) is None
95+
9496
assert store.get(kind="BuiltinLocation", key="anotherkey", raise_when_missing=False) is None
9597
assert store.get(key="anotherkey", raise_when_missing=False) is None
9698

@@ -101,6 +103,9 @@ def test_node_store_get(client_type, clients, location_schema):
101103
with pytest.raises(NodeNotFoundError):
102104
store.get(key="mykey", branch="mybranch")
103105

106+
with pytest.raises(NodeInvalidError):
107+
store.get(kind="BuiltinTest", key="mykey")
108+
104109
store.set(key="mykey", node=node, branch="mybranch")
105110
assert store.get(key="mykey", branch="mybranch").id == node.id
106111

0 commit comments

Comments
 (0)