11from  __future__ import  annotations 
22
33import  warnings 
4- from  collections  import  defaultdict 
54from  typing  import  TYPE_CHECKING , Literal , overload 
65
76from  .exceptions  import  NodeInvalidError , NodeNotFoundError 
@@ -28,7 +27,7 @@ def __init__(self, name: str) -> None:
2827        self .branch_name  =  name 
2928
3029        self ._objs : dict [str , InfrahubNode  |  InfrahubNodeSync  |  CoreNode  |  CoreNodeSync ] =  {}
31-         self ._hfids : dict [str , dict [tuple , str ]] =  defaultdict ( dict ) 
30+         self ._hfids : dict [str , dict [tuple , str ]] =  {} 
3231        self ._keys : dict [str , str ] =  {}
3332        self ._uuids : dict [str , str ] =  {}
3433
@@ -45,9 +44,12 @@ def set(self, node: InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync, k
4544            self ._uuids [node .id ] =  node ._internal_id 
4645
4746        if  hfid  :=  node .get_human_friendly_id ():
48-             self ._hfids [node .get_kind ()][tuple (hfid )] =  node ._internal_id 
47+             for  kind  in  node .get_all_kinds ():
48+                 if  kind  not  in self ._hfids :
49+                     self ._hfids [kind ] =  {}
50+                 self ._hfids [kind ][tuple (hfid )] =  node ._internal_id 
4951
50-     def  get (   # type: ignore[no-untyped-def] 
52+     def  get (
5153        self ,
5254        key : str ,
5355        kind : type [SchemaType  |  SchemaTypeSync ] |  str  |  None  =  None ,
@@ -88,13 +90,11 @@ def get(  # type: ignore[no-untyped-def]
8890
8991        if  kind  and  found_invalid :
9092            raise  NodeInvalidError (
91-                 node_type = "n/a" ,
9293                identifier = {"key" : [key ]},
93-                 message = f"Found a node of a differentkind  instead of { kind } { key !r} { self .branch_name }  ,
94+                 message = f"Found a node of a different kind  instead of { kind } { key !r} { self .branch_name }  ,
9495            )
9596
9697        raise  NodeNotFoundError (
97-             node_type = "n/a" ,
9898            identifier = {"key" : [key ]},
9999            message = f"Unable to find the node { key !r} { self .branch_name }  ,
100100        )
@@ -104,7 +104,6 @@ def _get_by_internal_id(
104104    ) ->  InfrahubNode  |  InfrahubNodeSync  |  CoreNode  |  CoreNodeSync :
105105        if  internal_id  not  in self ._objs :
106106            raise  NodeNotFoundError (
107-                 node_type = "n/a" ,
108107                identifier = {"internal_id" : [internal_id ]},
109108                message = f"Unable to find the node { internal_id !r} { self .branch_name }  ,
110109            )
@@ -124,7 +123,6 @@ def _get_by_key(
124123    ) ->  InfrahubNode  |  InfrahubNodeSync  |  CoreNode  |  CoreNodeSync :
125124        if  key  not  in self ._keys :
126125            raise  NodeNotFoundError (
127-                 node_type = "n/a" ,
128126                identifier = {"key" : [key ]},
129127                message = f"Unable to find the node { key !r} { self .branch_name }  ,
130128            )
@@ -143,7 +141,6 @@ def _get_by_key(
143141    def  _get_by_id (self , id : str , kind : str  |  None  =  None ) ->  InfrahubNode  |  InfrahubNodeSync  |  CoreNode  |  CoreNodeSync :
144142        if  id  not  in self ._uuids :
145143            raise  NodeNotFoundError (
146-                 node_type = "n/a" ,
147144                identifier = {"id" : [id ]},
148145                message = f"Unable to find the node { id !r} { self .branch_name }  ,
149146            )
@@ -170,7 +167,7 @@ def _get_by_hfid(
170167            node_hfid  =  [hfid ]
171168
172169        exception_to_raise_if_not_found  =  NodeNotFoundError (
173-             node_type = node_kind   or   "unknown" ,
170+             node_type = node_kind ,
174171            identifier = {"hfid" : node_hfid },
175172            message = f"Unable to find the node { hfid !r} { self .branch_name }  ,
176173        )
@@ -192,17 +189,27 @@ class NodeStoreBase:
192189    we need to save them in order to reuse them later to associate them with another node for example. 
193190    """ 
194191
195-     def  __init__ (self , default_branch : str  =   "main" ) ->  None :
192+     def  __init__ (self , default_branch : str  |   None   =   None ) ->  None :
196193        self ._branches : dict [str , NodeStoreBranch ] =  {}
197194        self ._default_branch  =  default_branch 
198195
196+     def  _get_branch (self , branch : str  |  None  =  None ) ->  str :
197+         branch  =  branch  or  self ._default_branch 
198+ 
199+         if  branch  is  None :
200+             raise  ValueError (
201+                 "A Branch must be provided to use the store, either as a parameter or by setting the default branch on the store" 
202+             )
203+ 
204+         return  branch 
205+ 
199206    def  _set (
200207        self ,
201208        node : InfrahubNode  |  InfrahubNodeSync  |  SchemaType  |  SchemaTypeSync ,
202209        key : str  |  None  =  None ,
203210        branch : str  |  None  =  None ,
204211    ) ->  None :
205-         branch  =  branch  or  node .get_branch ()  or   self . _default_branch 
212+         branch  =  self . _get_branch ( branch  or  node .get_branch ()) 
206213
207214        if  branch  not  in self ._branches :
208215            self ._branches [branch ] =  NodeStoreBranch (name = branch )
@@ -216,15 +223,15 @@ def _get(  # type: ignore[no-untyped-def]
216223        raise_when_missing : bool  =  True ,
217224        branch : str  |  None  =  None ,
218225    ):
219-         branch  =  branch   or   self ._default_branch 
226+         branch  =  self ._get_branch ( branch ) 
220227
221228        if  branch  not  in self ._branches :
222229            self ._branches [branch ] =  NodeStoreBranch (name = branch )
223230
224231        return  self ._branches [branch ].get (key = key , kind = kind , raise_when_missing = raise_when_missing )
225232
226233    def  count (self , branch : str  |  None  =  None ) ->  int :
227-         branch  =  branch   or   self ._default_branch 
234+         branch  =  self ._get_branch ( branch ) 
228235
229236        if  branch  not  in self ._branches :
230237            return  0 
0 commit comments