11from __future__ import annotations
22
33import warnings
4- from collections import defaultdict
54from typing import TYPE_CHECKING , Literal , overload
65
76from .exceptions import NodeInvalidError , NodeNotFoundError
@@ -28,7 +27,7 @@ def __init__(self, name: str) -> None:
2827 self .branch_name = name
2928
3029 self ._objs : dict [str , InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync ] = {}
31- self ._hfids : dict [str , dict [tuple , str ]] = defaultdict ( dict )
30+ self ._hfids : dict [str , dict [tuple , str ]] = {}
3231 self ._keys : dict [str , str ] = {}
3332 self ._uuids : dict [str , str ] = {}
3433
@@ -45,9 +44,12 @@ def set(self, node: InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync, k
4544 self ._uuids [node .id ] = node ._internal_id
4645
4746 if hfid := node .get_human_friendly_id ():
48- self ._hfids [node .get_kind ()][tuple (hfid )] = node ._internal_id
47+ for kind in node .get_all_kinds ():
48+ if kind not in self ._hfids :
49+ self ._hfids [kind ] = {}
50+ self ._hfids [kind ][tuple (hfid )] = node ._internal_id
4951
50- def get ( # type: ignore[no-untyped-def]
52+ def get (
5153 self ,
5254 key : str ,
5355 kind : type [SchemaType | SchemaTypeSync ] | str | None = None ,
@@ -88,13 +90,11 @@ def get( # type: ignore[no-untyped-def]
8890
8991 if kind and found_invalid :
9092 raise NodeInvalidError (
91- node_type = "n/a" ,
9293 identifier = {"key" : [key ]},
93- message = f"Found a node of a differentkind instead of { kind } for key { key !r} in the store ({ self .branch_name } )" ,
94+ message = f"Found a node of a different kind instead of { kind } for key { key !r} in the store ({ self .branch_name } )" ,
9495 )
9596
9697 raise NodeNotFoundError (
97- node_type = "n/a" ,
9898 identifier = {"key" : [key ]},
9999 message = f"Unable to find the node { key !r} in the store ({ self .branch_name } )" ,
100100 )
@@ -104,7 +104,6 @@ def _get_by_internal_id(
104104 ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync :
105105 if internal_id not in self ._objs :
106106 raise NodeNotFoundError (
107- node_type = "n/a" ,
108107 identifier = {"internal_id" : [internal_id ]},
109108 message = f"Unable to find the node { internal_id !r} in the store ({ self .branch_name } )" ,
110109 )
@@ -124,7 +123,6 @@ def _get_by_key(
124123 ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync :
125124 if key not in self ._keys :
126125 raise NodeNotFoundError (
127- node_type = "n/a" ,
128126 identifier = {"key" : [key ]},
129127 message = f"Unable to find the node { key !r} in the store ({ self .branch_name } )" ,
130128 )
@@ -143,7 +141,6 @@ def _get_by_key(
143141 def _get_by_id (self , id : str , kind : str | None = None ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync :
144142 if id not in self ._uuids :
145143 raise NodeNotFoundError (
146- node_type = "n/a" ,
147144 identifier = {"id" : [id ]},
148145 message = f"Unable to find the node { id !r} in the store ({ self .branch_name } )" ,
149146 )
@@ -170,7 +167,7 @@ def _get_by_hfid(
170167 node_hfid = [hfid ]
171168
172169 exception_to_raise_if_not_found = NodeNotFoundError (
173- node_type = node_kind or "unknown" ,
170+ node_type = node_kind ,
174171 identifier = {"hfid" : node_hfid },
175172 message = f"Unable to find the node { hfid !r} in the store ({ self .branch_name } )" ,
176173 )
@@ -192,17 +189,27 @@ class NodeStoreBase:
192189 we need to save them in order to reuse them later to associate them with another node for example.
193190 """
194191
195- def __init__ (self , default_branch : str = "main" ) -> None :
192+ def __init__ (self , default_branch : str | None = None ) -> None :
196193 self ._branches : dict [str , NodeStoreBranch ] = {}
197194 self ._default_branch = default_branch
198195
196+ def _get_branch (self , branch : str | None = None ) -> str :
197+ branch = branch or self ._default_branch
198+
199+ if branch is None :
200+ raise ValueError (
201+ "A Branch must be provided to use the store, either as a parameter or by setting the default branch on the store"
202+ )
203+
204+ return branch
205+
199206 def _set (
200207 self ,
201208 node : InfrahubNode | InfrahubNodeSync | SchemaType | SchemaTypeSync ,
202209 key : str | None = None ,
203210 branch : str | None = None ,
204211 ) -> None :
205- branch = branch or node .get_branch () or self . _default_branch
212+ branch = self . _get_branch ( branch or node .get_branch ())
206213
207214 if branch not in self ._branches :
208215 self ._branches [branch ] = NodeStoreBranch (name = branch )
@@ -216,15 +223,15 @@ def _get( # type: ignore[no-untyped-def]
216223 raise_when_missing : bool = True ,
217224 branch : str | None = None ,
218225 ):
219- branch = branch or self ._default_branch
226+ branch = self ._get_branch ( branch )
220227
221228 if branch not in self ._branches :
222229 self ._branches [branch ] = NodeStoreBranch (name = branch )
223230
224231 return self ._branches [branch ].get (key = key , kind = kind , raise_when_missing = raise_when_missing )
225232
226233 def count (self , branch : str | None = None ) -> int :
227- branch = branch or self ._default_branch
234+ branch = self ._get_branch ( branch )
228235
229236 if branch not in self ._branches :
230237 return 0
0 commit comments