@@ -51,51 +51,58 @@ def set(self, node: InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync, k
5151
5252    def  get (
5353        self ,
54-         key : str ,
54+         key : str   |   list [ str ] ,
5555        kind : type [SchemaType  |  SchemaTypeSync ] |  str  |  None  =  None ,
5656        raise_when_missing : bool  =  True ,
5757    ) ->  InfrahubNode  |  InfrahubNodeSync  |  CoreNode  |  CoreNodeSync  |  None :
5858        found_invalid  =  False 
5959
6060        kind_name  =  get_schema_name (schema = kind )
6161
62-         try :
63-             return  self ._get_by_internal_id (key , kind = kind_name )
64-         except  NodeInvalidError :
65-             found_invalid  =  True 
66-         except  NodeNotFoundError :
67-             pass 
68- 
69-         try :
70-             return  self ._get_by_id (key , kind = kind_name )
71-         except  NodeInvalidError :
72-             found_invalid  =  True 
73-         except  NodeNotFoundError :
74-             pass 
75- 
76-         try :
77-             return  self ._get_by_key (key , kind = kind_name )
78-         except  NodeInvalidError :
79-             found_invalid  =  True 
80-         except  NodeNotFoundError :
81-             pass 
82- 
83-         try :
84-             return  self ._get_by_hfid (key , kind = kind_name )
85-         except  NodeNotFoundError :
86-             pass 
62+         if  isinstance (key , list ):
63+             try :
64+                 return  self ._get_by_hfid (key , kind = kind_name )
65+             except  NodeNotFoundError :
66+                 pass 
67+ 
68+         elif  isinstance (key , str ):
69+             try :
70+                 return  self ._get_by_internal_id (key , kind = kind_name )
71+             except  NodeInvalidError :
72+                 found_invalid  =  True 
73+             except  NodeNotFoundError :
74+                 pass 
75+ 
76+             try :
77+                 return  self ._get_by_id (key , kind = kind_name )
78+             except  NodeInvalidError :
79+                 found_invalid  =  True 
80+             except  NodeNotFoundError :
81+                 pass 
82+ 
83+             try :
84+                 return  self ._get_by_key (key , kind = kind_name )
85+             except  NodeInvalidError :
86+                 found_invalid  =  True 
87+             except  NodeNotFoundError :
88+                 pass 
89+ 
90+             try :
91+                 return  self ._get_by_hfid (key , kind = kind_name )
92+             except  NodeNotFoundError :
93+                 pass 
8794
8895        if  not  raise_when_missing :
8996            return  None 
9097
9198        if  kind  and  found_invalid :
9299            raise  NodeInvalidError (
93-                 identifier = {"key" : [key ]},
100+                 identifier = {"key" : [key ]  if   isinstance ( key ,  str )  else   key },
94101                message = f"Found a node of a different kind instead of { kind } { key !r} { self .branch_name }  ,
95102            )
96103
97104        raise  NodeNotFoundError (
98-             identifier = {"key" : [key ]},
105+             identifier = {"key" : [key ]  if   isinstance ( key ,  str )  else   key },
99106            message = f"Unable to find the node { key !r} { self .branch_name }  ,
100107        )
101108
@@ -156,15 +163,15 @@ def _get_by_id(self, id: str, kind: str | None = None) -> InfrahubNode | Infrahu
156163        return  node 
157164
158165    def  _get_by_hfid (
159-         self , hfid : str , kind : str  |  None  =  None 
166+         self , hfid : str   |   list [ str ] , kind : str  |  None  =  None 
160167    ) ->  InfrahubNode  |  InfrahubNodeSync  |  CoreNode  |  CoreNodeSync :
161168        if  not  kind :
162169            node_kind , node_hfid  =  parse_human_friendly_id (hfid )
163-         elif  kind  and  hfid .startswith (kind ):
170+         elif  kind  and  isinstance ( hfid ,  str )  and   hfid .startswith (kind ):
164171            node_kind , node_hfid  =  parse_human_friendly_id (hfid )
165172        else :
166173            node_kind  =  kind 
167-             node_hfid  =  [hfid ]
174+             node_hfid  =  [hfid ]  if   isinstance ( hfid ,  str )  else   hfid 
168175
169176        exception_to_raise_if_not_found  =  NodeNotFoundError (
170177            node_type = node_kind ,
@@ -218,7 +225,7 @@ def _set(
218225
219226    def  _get (  # type: ignore[no-untyped-def] 
220227        self ,
221-         key : str ,
228+         key : str   |   list [ str ] ,
222229        kind : type [SchemaType  |  SchemaTypeSync ] |  str  |  None  =  None ,
223230        raise_when_missing : bool  =  True ,
224231        branch : str  |  None  =  None ,
@@ -242,37 +249,61 @@ def count(self, branch: str | None = None) -> int:
242249class  NodeStore (NodeStoreBase ):
243250    @overload  
244251    def  get (
245-         self , key : str , kind : type [SchemaType ], raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
252+         self ,
253+         key : str  |  list [str ],
254+         kind : type [SchemaType ],
255+         raise_when_missing : Literal [True ] =  True ,
256+         branch : str  |  None  =  ...,
246257    ) ->  SchemaType : ...
247258
248259    @overload  
249260    def  get (
250-         self , key : str , kind : type [SchemaType ], raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
261+         self ,
262+         key : str  |  list [str ],
263+         kind : type [SchemaType ],
264+         raise_when_missing : Literal [False ] =  False ,
265+         branch : str  |  None  =  ...,
251266    ) ->  SchemaType  |  None : ...
252267
253268    @overload  
254269    def  get (
255-         self , key : str , kind : type [SchemaType ], raise_when_missing : bool  =  ..., branch : str  |  None  =  ...
270+         self ,
271+         key : str  |  list [str ],
272+         kind : type [SchemaType ],
273+         raise_when_missing : bool  =  ...,
274+         branch : str  |  None  =  ...,
256275    ) ->  SchemaType : ...
257276
258277    @overload  
259278    def  get (
260-         self , key : str , kind : str  |  None  =  ..., raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
279+         self ,
280+         key : str  |  list [str ],
281+         kind : str  |  None  =  ...,
282+         raise_when_missing : Literal [True ] =  True ,
283+         branch : str  |  None  =  ...,
261284    ) ->  InfrahubNode : ...
262285
263286    @overload  
264287    def  get (
265-         self , key : str , kind : str  |  None  =  ..., raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
288+         self ,
289+         key : str  |  list [str ],
290+         kind : str  |  None  =  ...,
291+         raise_when_missing : Literal [False ] =  False ,
292+         branch : str  |  None  =  ...,
266293    ) ->  InfrahubNode  |  None : ...
267294
268295    @overload  
269296    def  get (
270-         self , key : str , kind : str  |  None  =  ..., raise_when_missing : bool  =  ..., branch : str  |  None  =  ...
297+         self ,
298+         key : str  |  list [str ],
299+         kind : str  |  None  =  ...,
300+         raise_when_missing : bool  =  ...,
301+         branch : str  |  None  =  ...,
271302    ) ->  InfrahubNode : ...
272303
273304    def  get (
274305        self ,
275-         key : str ,
306+         key : str   |   list [ str ] ,
276307        kind : str  |  type [SchemaType ] |  None  =  None ,
277308        raise_when_missing : bool  =  True ,
278309        branch : str  |  None  =  None ,
@@ -281,15 +312,17 @@ def get(
281312
282313    @overload  
283314    def  get_by_hfid (
284-         self , key : str , raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
315+         self , key : str   |   list [ str ] , raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
285316    ) ->  InfrahubNode : ...
286317
287318    @overload  
288319    def  get_by_hfid (
289-         self , key : str , raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
320+         self , key : str   |   list [ str ] , raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
290321    ) ->  InfrahubNode  |  None : ...
291322
292-     def  get_by_hfid (self , key : str , raise_when_missing : bool  =  True , branch : str  |  None  =  None ) ->  InfrahubNode  |  None :
323+     def  get_by_hfid (
324+         self , key : str  |  list [str ], raise_when_missing : bool  =  True , branch : str  |  None  =  None 
325+     ) ->  InfrahubNode  |  None :
293326        warnings .warn (
294327            "get_by_hfid() is deprecated and will be removed in a future version. Use get() instead." ,
295328            DeprecationWarning ,
@@ -304,37 +337,61 @@ def set(self, node: InfrahubNode | SchemaType, key: str | None = None, branch: s
304337class  NodeStoreSync (NodeStoreBase ):
305338    @overload  
306339    def  get (
307-         self , key : str , kind : type [SchemaTypeSync ], raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
340+         self ,
341+         key : str  |  list [str ],
342+         kind : type [SchemaTypeSync ],
343+         raise_when_missing : Literal [True ] =  True ,
344+         branch : str  |  None  =  ...,
308345    ) ->  SchemaTypeSync : ...
309346
310347    @overload  
311348    def  get (
312-         self , key : str , kind : type [SchemaTypeSync ], raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
349+         self ,
350+         key : str  |  list [str ],
351+         kind : type [SchemaTypeSync ],
352+         raise_when_missing : Literal [False ] =  False ,
353+         branch : str  |  None  =  ...,
313354    ) ->  SchemaTypeSync  |  None : ...
314355
315356    @overload  
316357    def  get (
317-         self , key : str , kind : type [SchemaTypeSync ], raise_when_missing : bool  =  ..., branch : str  |  None  =  ...
358+         self ,
359+         key : str  |  list [str ],
360+         kind : type [SchemaTypeSync ],
361+         raise_when_missing : bool  =  ...,
362+         branch : str  |  None  =  ...,
318363    ) ->  SchemaTypeSync : ...
319364
320365    @overload  
321366    def  get (
322-         self , key : str , kind : str  |  None  =  ..., raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
367+         self ,
368+         key : str  |  list [str ],
369+         kind : str  |  None  =  ...,
370+         raise_when_missing : Literal [True ] =  True ,
371+         branch : str  |  None  =  ...,
323372    ) ->  InfrahubNodeSync : ...
324373
325374    @overload  
326375    def  get (
327-         self , key : str , kind : str  |  None  =  ..., raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
376+         self ,
377+         key : str  |  list [str ],
378+         kind : str  |  None  =  ...,
379+         raise_when_missing : Literal [False ] =  False ,
380+         branch : str  |  None  =  ...,
328381    ) ->  InfrahubNodeSync  |  None : ...
329382
330383    @overload  
331384    def  get (
332-         self , key : str , kind : str  |  None  =  ..., raise_when_missing : bool  =  ..., branch : str  |  None  =  ...
385+         self ,
386+         key : str  |  list [str ],
387+         kind : str  |  None  =  ...,
388+         raise_when_missing : bool  =  ...,
389+         branch : str  |  None  =  ...,
333390    ) ->  InfrahubNodeSync : ...
334391
335392    def  get (
336393        self ,
337-         key : str ,
394+         key : str   |   list [ str ] ,
338395        kind : str  |  type [SchemaTypeSync ] |  None  =  None ,
339396        raise_when_missing : bool  =  True ,
340397        branch : str  |  None  =  None ,
@@ -343,16 +400,16 @@ def get(
343400
344401    @overload  
345402    def  get_by_hfid (
346-         self , key : str , raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
403+         self , key : str   |   list [ str ] , raise_when_missing : Literal [True ] =  True , branch : str  |  None  =  ...
347404    ) ->  InfrahubNodeSync : ...
348405
349406    @overload  
350407    def  get_by_hfid (
351-         self , key : str , raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
408+         self , key : str   |   list [ str ] , raise_when_missing : Literal [False ] =  False , branch : str  |  None  =  ...
352409    ) ->  InfrahubNodeSync  |  None : ...
353410
354411    def  get_by_hfid (
355-         self , key : str , raise_when_missing : bool  =  True , branch : str  |  None  =  None 
412+         self , key : str   |   list [ str ] , raise_when_missing : bool  =  True , branch : str  |  None  =  None 
356413    ) ->  InfrahubNodeSync  |  None :
357414        warnings .warn (
358415            "get_by_hfid() is deprecated and will be removed in a future version. Use get() instead." ,
0 commit comments