Skip to content

Commit 204c9f1

Browse files
committed
Support hfid as list
1 parent e6c0e1d commit 204c9f1

File tree

2 files changed

+118
-57
lines changed

2 files changed

+118
-57
lines changed

infrahub_sdk/node.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,16 @@
4646
HFID_STR_SEPARATOR = "__"
4747

4848

49-
def parse_human_friendly_id(hfid: str) -> tuple[str | None, list[str]]:
49+
def parse_human_friendly_id(hfid: str | list[str]) -> tuple[str | None, list[str]]:
5050
"""Parse a human friendly ID into a kind and an identifier."""
51-
hfid_parts = hfid.split(HFID_STR_SEPARATOR)
52-
if len(hfid_parts) == 1:
53-
return None, hfid_parts
54-
return hfid_parts[0], hfid_parts[1:]
51+
if isinstance(hfid, str):
52+
hfid_parts = hfid.split(HFID_STR_SEPARATOR)
53+
if len(hfid_parts) == 1:
54+
return None, hfid_parts
55+
return hfid_parts[0], hfid_parts[1:]
56+
if isinstance(hfid, list):
57+
return None, hfid
58+
raise ValueError(f"Invalid human friendly ID: {hfid}")
5559

5660

5761
class Attribute:

infrahub_sdk/store.py

Lines changed: 109 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -51,51 +51,58 @@ def set(self, node: InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync, k
5151

5252
def get(
5353
self,
54-
key: str,
54+
key: str | list[str],
5555
kind: type[SchemaType | SchemaTypeSync] | str | None = None,
5656
raise_when_missing: bool = True,
5757
) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync | None:
5858
found_invalid = False
5959

6060
kind_name = get_schema_name(schema=kind)
6161

62-
try:
63-
return self._get_by_internal_id(key, kind=kind_name)
64-
except NodeInvalidError:
65-
found_invalid = True
66-
except NodeNotFoundError:
67-
pass
68-
69-
try:
70-
return self._get_by_id(key, kind=kind_name)
71-
except NodeInvalidError:
72-
found_invalid = True
73-
except NodeNotFoundError:
74-
pass
75-
76-
try:
77-
return self._get_by_key(key, kind=kind_name)
78-
except NodeInvalidError:
79-
found_invalid = True
80-
except NodeNotFoundError:
81-
pass
82-
83-
try:
84-
return self._get_by_hfid(key, kind=kind_name)
85-
except NodeNotFoundError:
86-
pass
62+
if isinstance(key, list):
63+
try:
64+
return self._get_by_hfid(key, kind=kind_name)
65+
except NodeNotFoundError:
66+
pass
67+
68+
elif isinstance(key, str):
69+
try:
70+
return self._get_by_internal_id(key, kind=kind_name)
71+
except NodeInvalidError:
72+
found_invalid = True
73+
except NodeNotFoundError:
74+
pass
75+
76+
try:
77+
return self._get_by_id(key, kind=kind_name)
78+
except NodeInvalidError:
79+
found_invalid = True
80+
except NodeNotFoundError:
81+
pass
82+
83+
try:
84+
return self._get_by_key(key, kind=kind_name)
85+
except NodeInvalidError:
86+
found_invalid = True
87+
except NodeNotFoundError:
88+
pass
89+
90+
try:
91+
return self._get_by_hfid(key, kind=kind_name)
92+
except NodeNotFoundError:
93+
pass
8794

8895
if not raise_when_missing:
8996
return None
9097

9198
if kind and found_invalid:
9299
raise NodeInvalidError(
93-
identifier={"key": [key]},
100+
identifier={"key": [key] if isinstance(key, str) else key},
94101
message=f"Found a node of a different kind instead of {kind} for key {key!r} in the store ({self.branch_name})",
95102
)
96103

97104
raise NodeNotFoundError(
98-
identifier={"key": [key]},
105+
identifier={"key": [key] if isinstance(key, str) else key},
99106
message=f"Unable to find the node {key!r} in the store ({self.branch_name})",
100107
)
101108

@@ -156,15 +163,15 @@ def _get_by_id(self, id: str, kind: str | None = None) -> InfrahubNode | Infrahu
156163
return node
157164

158165
def _get_by_hfid(
159-
self, hfid: str, kind: str | None = None
166+
self, hfid: str | list[str], kind: str | None = None
160167
) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
161168
if not kind:
162169
node_kind, node_hfid = parse_human_friendly_id(hfid)
163-
elif kind and hfid.startswith(kind):
170+
elif kind and isinstance(hfid, str) and hfid.startswith(kind):
164171
node_kind, node_hfid = parse_human_friendly_id(hfid)
165172
else:
166173
node_kind = kind
167-
node_hfid = [hfid]
174+
node_hfid = [hfid] if isinstance(hfid, str) else hfid
168175

169176
exception_to_raise_if_not_found = NodeNotFoundError(
170177
node_type=node_kind,
@@ -218,7 +225,7 @@ def _set(
218225

219226
def _get( # type: ignore[no-untyped-def]
220227
self,
221-
key: str,
228+
key: str | list[str],
222229
kind: type[SchemaType | SchemaTypeSync] | str | None = None,
223230
raise_when_missing: bool = True,
224231
branch: str | None = None,
@@ -242,37 +249,61 @@ def count(self, branch: str | None = None) -> int:
242249
class NodeStore(NodeStoreBase):
243250
@overload
244251
def get(
245-
self, key: str, kind: type[SchemaType], raise_when_missing: Literal[True] = True, branch: str | None = ...
252+
self,
253+
key: str | list[str],
254+
kind: type[SchemaType],
255+
raise_when_missing: Literal[True] = True,
256+
branch: str | None = ...,
246257
) -> SchemaType: ...
247258

248259
@overload
249260
def get(
250-
self, key: str, kind: type[SchemaType], raise_when_missing: Literal[False] = False, branch: str | None = ...
261+
self,
262+
key: str | list[str],
263+
kind: type[SchemaType],
264+
raise_when_missing: Literal[False] = False,
265+
branch: str | None = ...,
251266
) -> SchemaType | None: ...
252267

253268
@overload
254269
def get(
255-
self, key: str, kind: type[SchemaType], raise_when_missing: bool = ..., branch: str | None = ...
270+
self,
271+
key: str | list[str],
272+
kind: type[SchemaType],
273+
raise_when_missing: bool = ...,
274+
branch: str | None = ...,
256275
) -> SchemaType: ...
257276

258277
@overload
259278
def get(
260-
self, key: str, kind: str | None = ..., raise_when_missing: Literal[True] = True, branch: str | None = ...
279+
self,
280+
key: str | list[str],
281+
kind: str | None = ...,
282+
raise_when_missing: Literal[True] = True,
283+
branch: str | None = ...,
261284
) -> InfrahubNode: ...
262285

263286
@overload
264287
def get(
265-
self, key: str, kind: str | None = ..., raise_when_missing: Literal[False] = False, branch: str | None = ...
288+
self,
289+
key: str | list[str],
290+
kind: str | None = ...,
291+
raise_when_missing: Literal[False] = False,
292+
branch: str | None = ...,
266293
) -> InfrahubNode | None: ...
267294

268295
@overload
269296
def get(
270-
self, key: str, kind: str | None = ..., raise_when_missing: bool = ..., branch: str | None = ...
297+
self,
298+
key: str | list[str],
299+
kind: str | None = ...,
300+
raise_when_missing: bool = ...,
301+
branch: str | None = ...,
271302
) -> InfrahubNode: ...
272303

273304
def get(
274305
self,
275-
key: str,
306+
key: str | list[str],
276307
kind: str | type[SchemaType] | None = None,
277308
raise_when_missing: bool = True,
278309
branch: str | None = None,
@@ -281,15 +312,17 @@ def get(
281312

282313
@overload
283314
def get_by_hfid(
284-
self, key: str, raise_when_missing: Literal[True] = True, branch: str | None = ...
315+
self, key: str | list[str], raise_when_missing: Literal[True] = True, branch: str | None = ...
285316
) -> InfrahubNode: ...
286317

287318
@overload
288319
def get_by_hfid(
289-
self, key: str, raise_when_missing: Literal[False] = False, branch: str | None = ...
320+
self, key: str | list[str], raise_when_missing: Literal[False] = False, branch: str | None = ...
290321
) -> InfrahubNode | None: ...
291322

292-
def get_by_hfid(self, key: str, raise_when_missing: bool = True, branch: str | None = None) -> InfrahubNode | None:
323+
def get_by_hfid(
324+
self, key: str | list[str], raise_when_missing: bool = True, branch: str | None = None
325+
) -> InfrahubNode | None:
293326
warnings.warn(
294327
"get_by_hfid() is deprecated and will be removed in a future version. Use get() instead.",
295328
DeprecationWarning,
@@ -304,37 +337,61 @@ def set(self, node: InfrahubNode | SchemaType, key: str | None = None, branch: s
304337
class NodeStoreSync(NodeStoreBase):
305338
@overload
306339
def get(
307-
self, key: str, kind: type[SchemaTypeSync], raise_when_missing: Literal[True] = True, branch: str | None = ...
340+
self,
341+
key: str | list[str],
342+
kind: type[SchemaTypeSync],
343+
raise_when_missing: Literal[True] = True,
344+
branch: str | None = ...,
308345
) -> SchemaTypeSync: ...
309346

310347
@overload
311348
def get(
312-
self, key: str, kind: type[SchemaTypeSync], raise_when_missing: Literal[False] = False, branch: str | None = ...
349+
self,
350+
key: str | list[str],
351+
kind: type[SchemaTypeSync],
352+
raise_when_missing: Literal[False] = False,
353+
branch: str | None = ...,
313354
) -> SchemaTypeSync | None: ...
314355

315356
@overload
316357
def get(
317-
self, key: str, kind: type[SchemaTypeSync], raise_when_missing: bool = ..., branch: str | None = ...
358+
self,
359+
key: str | list[str],
360+
kind: type[SchemaTypeSync],
361+
raise_when_missing: bool = ...,
362+
branch: str | None = ...,
318363
) -> SchemaTypeSync: ...
319364

320365
@overload
321366
def get(
322-
self, key: str, kind: str | None = ..., raise_when_missing: Literal[True] = True, branch: str | None = ...
367+
self,
368+
key: str | list[str],
369+
kind: str | None = ...,
370+
raise_when_missing: Literal[True] = True,
371+
branch: str | None = ...,
323372
) -> InfrahubNodeSync: ...
324373

325374
@overload
326375
def get(
327-
self, key: str, kind: str | None = ..., raise_when_missing: Literal[False] = False, branch: str | None = ...
376+
self,
377+
key: str | list[str],
378+
kind: str | None = ...,
379+
raise_when_missing: Literal[False] = False,
380+
branch: str | None = ...,
328381
) -> InfrahubNodeSync | None: ...
329382

330383
@overload
331384
def get(
332-
self, key: str, kind: str | None = ..., raise_when_missing: bool = ..., branch: str | None = ...
385+
self,
386+
key: str | list[str],
387+
kind: str | None = ...,
388+
raise_when_missing: bool = ...,
389+
branch: str | None = ...,
333390
) -> InfrahubNodeSync: ...
334391

335392
def get(
336393
self,
337-
key: str,
394+
key: str | list[str],
338395
kind: str | type[SchemaTypeSync] | None = None,
339396
raise_when_missing: bool = True,
340397
branch: str | None = None,
@@ -343,16 +400,16 @@ def get(
343400

344401
@overload
345402
def get_by_hfid(
346-
self, key: str, raise_when_missing: Literal[True] = True, branch: str | None = ...
403+
self, key: str | list[str], raise_when_missing: Literal[True] = True, branch: str | None = ...
347404
) -> InfrahubNodeSync: ...
348405

349406
@overload
350407
def get_by_hfid(
351-
self, key: str, raise_when_missing: Literal[False] = False, branch: str | None = ...
408+
self, key: str | list[str], raise_when_missing: Literal[False] = False, branch: str | None = ...
352409
) -> InfrahubNodeSync | None: ...
353410

354411
def get_by_hfid(
355-
self, key: str, raise_when_missing: bool = True, branch: str | None = None
412+
self, key: str | list[str], raise_when_missing: bool = True, branch: str | None = None
356413
) -> InfrahubNodeSync | None:
357414
warnings.warn(
358415
"get_by_hfid() is deprecated and will be removed in a future version. Use get() instead.",

0 commit comments

Comments
 (0)