33import asyncio
44import random
55from dataclasses import dataclass
6- from typing import TYPE_CHECKING , Any , Callable , Coroutine , Optional , TypeVar , Union
6+ from typing import TYPE_CHECKING , Any , Callable , Coroutine , TypeVar
77
88from neo4j import (
99 READ_ACCESS ,
@@ -69,7 +69,7 @@ class InfrahubDatabaseSessionMode(InfrahubStringEnum):
6969 WRITE = "write"
7070
7171
72- def get_branch_name (branch : Optional [ Union [ Branch , str ]] = None ) -> str :
72+ def get_branch_name (branch : Branch | str | None = None ) -> str :
7373 if not branch :
7474 return registry .default_branch
7575 if isinstance (branch , str ):
@@ -82,43 +82,39 @@ class DatabaseSchemaManager:
8282 def __init__ (self , db : InfrahubDatabase ) -> None :
8383 self ._db = db
8484
85- def get (self , name : str , branch : Optional [ Union [ Branch , str ]] = None , duplicate : bool = True ) -> MainSchemaTypes :
85+ def get (self , name : str , branch : Branch | str | None = None , duplicate : bool = True ) -> MainSchemaTypes :
8686 branch_name = get_branch_name (branch = branch )
8787 if branch_name not in self ._db ._schemas :
8888 return registry .schema .get (name = name , branch = branch , duplicate = duplicate )
8989 return self ._db ._schemas [branch_name ].get (name = name , duplicate = duplicate )
9090
91- def get_node_schema (
92- self , name : str , branch : Optional [Union [Branch , str ]] = None , duplicate : bool = True
93- ) -> NodeSchema :
91+ def get_node_schema (self , name : str , branch : Branch | str | None = None , duplicate : bool = True ) -> NodeSchema :
9492 schema = self .get (name = name , branch = branch , duplicate = duplicate )
9593 if schema .is_node_schema :
9694 return schema
9795
9896 raise ValueError ("The selected node is not of type NodeSchema" )
9997
100- def set (self , name : str , schema : MainSchemaTypes , branch : Optional [ str ] = None ) -> int :
98+ def set (self , name : str , schema : MainSchemaTypes , branch : str | None = None ) -> int :
10199 branch_name = get_branch_name (branch = branch )
102100 if branch_name not in self ._db ._schemas :
103101 return registry .schema .set (name = name , schema = schema , branch = branch )
104102 return self ._db ._schemas [branch_name ].set (name = name , schema = schema )
105103
106- def has (self , name : str , branch : Optional [ Union [ Branch , str ]] = None ) -> bool :
104+ def has (self , name : str , branch : Branch | str | None = None ) -> bool :
107105 branch_name = get_branch_name (branch = branch )
108106 if branch_name not in self ._db ._schemas :
109107 return registry .schema .has (name = name , branch = branch )
110108 return self ._db ._schemas [branch_name ].has (name = name )
111109
112- def get_full (
113- self , branch : Optional [Union [Branch , str ]] = None , duplicate : bool = True
114- ) -> dict [str , MainSchemaTypes ]:
110+ def get_full (self , branch : Branch | str | None = None , duplicate : bool = True ) -> dict [str , MainSchemaTypes ]:
115111 branch_name = get_branch_name (branch = branch )
116112 if branch_name not in self ._db ._schemas :
117113 return registry .schema .get_full (branch = branch )
118114 return self ._db ._schemas [branch_name ].get_all (duplicate = duplicate )
119115
120116 async def get_full_safe (
121- self , branch : Optional [ Union [ Branch , str ]] = None , duplicate : bool = True
117+ self , branch : Branch | str | None = None , duplicate : bool = True
122118 ) -> dict [str , MainSchemaTypes ]:
123119 await lock .registry .local_schema_wait ()
124120 return self .get_full (branch = branch , duplicate = duplicate )
@@ -206,10 +202,10 @@ def get_context(self) -> dict[str, Any]:
206202
207203 return {}
208204
209- def add_schema (self , schema : SchemaBranch , name : Optional [ str ] = None ) -> None :
205+ def add_schema (self , schema : SchemaBranch , name : str | None = None ) -> None :
210206 self ._schemas [name or schema .name ] = schema
211207
212- def start_session (self , read_only : bool = False , schemas : Optional [ list [SchemaBranch ]] = None ) -> InfrahubDatabase :
208+ def start_session (self , read_only : bool = False , schemas : list [SchemaBranch ] | None = None ) -> InfrahubDatabase :
213209 """Create a new InfrahubDatabase object in Session mode."""
214210 session_mode = InfrahubDatabaseSessionMode .WRITE
215211 if read_only :
@@ -229,7 +225,7 @@ def start_session(self, read_only: bool = False, schemas: Optional[list[SchemaBr
229225 ** context ,
230226 )
231227
232- def start_transaction (self , schemas : Optional [ list [SchemaBranch ]] = None ) -> InfrahubDatabase :
228+ def start_transaction (self , schemas : list [SchemaBranch ] | None = None ) -> InfrahubDatabase :
233229 context = self .get_context ()
234230
235231 return self .__class__ (
@@ -261,7 +257,7 @@ async def session(self) -> AsyncSession:
261257 self ._is_session_local = True
262258 return self ._session
263259
264- async def transaction (self , name : Optional [ str ] ) -> AsyncTransaction :
260+ async def transaction (self , name : str | None ) -> AsyncTransaction :
265261 if self ._transaction :
266262 return self ._transaction
267263
@@ -290,9 +286,9 @@ async def __aenter__(self) -> Self:
290286
291287 async def __aexit__ (
292288 self ,
293- exc_type : Optional [ type [BaseException ]] ,
294- exc_value : Optional [ BaseException ] ,
295- traceback : Optional [ TracebackType ] ,
289+ exc_type : type [BaseException ] | None ,
290+ exc_value : BaseException | None ,
291+ traceback : TracebackType | None ,
296292 ):
297293 if self ._mode == InfrahubDatabaseMode .SESSION :
298294 return await self ._session .close ()
@@ -385,9 +381,9 @@ async def execute_query_with_metadata(
385381 return results , response ._metadata or {}
386382
387383 async def run_query (
388- self , query : str , params : Optional [ dict [str , Any ]] = None , name : Optional [ str ] = "undefined"
384+ self , query : str , params : dict [str , Any ] | None = None , name : str | None = "undefined"
389385 ) -> AsyncResult :
390- _query : Union [ str | Query ] = query
386+ _query : str | Query = query
391387 if self .is_transaction :
392388 execution_method = await self .transaction (name = name )
393389 else :
0 commit comments