3939
4040
4141class SQLAlchemyConfig (BaseModel ):
42+ """
43+ Configuration for synchronous engines
44+ """
45+
4246 engine_url : str
4347 engine_options : Union [dict , None ] = None
4448 session_options : Union [dict , None ] = None
4549
4650
4751class SQLAlchemyAsyncConfig (BaseModel ):
52+ """
53+ Configuration for asynchronous engines
54+ """
55+
4856 engine_url : str
4957 engine_options : Union [dict , None ] = None
5058 session_options : Union [dict , None ] = None
@@ -68,11 +76,6 @@ class SQLAlchemyAsyncBind(BaseModel):
6876 model_config = ConfigDict (arbitrary_types_allowed = True )
6977
7078
71- _SQLAlchemyConfig = Union [
72- Mapping [str , Union [SQLAlchemyConfig , SQLAlchemyAsyncConfig ]],
73- SQLAlchemyConfig ,
74- SQLAlchemyAsyncConfig ,
75- ]
7679DEFAULT_BIND_NAME = "default"
7780
7881
@@ -81,7 +84,11 @@ class SQLAlchemyBindManager:
8184
8285 def __init__ (
8386 self ,
84- config : _SQLAlchemyConfig ,
87+ config : Union [
88+ Mapping [str , Union [SQLAlchemyConfig , SQLAlchemyAsyncConfig ]],
89+ SQLAlchemyConfig ,
90+ SQLAlchemyAsyncConfig ,
91+ ],
8592 ) -> None :
8693 self .__binds = {}
8794 if isinstance (config , Mapping ):
@@ -162,31 +169,54 @@ def __build_async_bind(
162169 declarative_base = registry_mapper .generate_base (),
163170 )
164171
165- def get_binds (self ) -> Mapping [str , Union [SQLAlchemyBind , SQLAlchemyAsyncBind ]]:
166- return self .__binds
167-
168172 def get_bind_mappers_metadata (self ) -> Mapping [str , MetaData ]:
169173 """
170- Returns the mappers metadata in a format that can be used
171- in Alembic configuration
174+ Returns the registered mappers metadata in a format
175+ that can be used in Alembic configuration
172176
173177 :returns: mappers metadata
174- :rtype: dict
175178 """
176179 return {k : b .registry_mapper .metadata for k , b in self .__binds .items ()}
177180
178181 def get_bind (
179182 self , bind_name : str = DEFAULT_BIND_NAME
180183 ) -> Union [SQLAlchemyBind , SQLAlchemyAsyncBind ]:
184+ """
185+ Returns a bind object by name.
186+
187+ :param bind_name: A registered bind name
188+ :return: a bind object
189+ """
181190 try :
182191 return self .__binds [bind_name ]
183192 except KeyError :
184193 raise NotInitializedBindError ("Bind not initialized" )
185194
195+ def get_binds (self ) -> Mapping [str , Union [SQLAlchemyBind , SQLAlchemyAsyncBind ]]:
196+ """
197+ Returns all the registered bind objects.
198+
199+ :returns: A mapping containing the registered binds
200+ """
201+ return self .__binds
202+
203+ def get_mapper (self , bind_name : str = DEFAULT_BIND_NAME ) -> registry :
204+ """
205+ Returns the registered SQLAlchemy registry_mapper for the given bind name
206+
207+ :param bind_name: A registered bind name
208+ :return: the registered registry_mapper
209+ """
210+ return self .get_bind (bind_name ).registry_mapper
211+
186212 def get_session (
187213 self , bind_name : str = DEFAULT_BIND_NAME
188214 ) -> Union [Session , AsyncSession ]:
189- return self .get_bind (bind_name ).session_class ()
215+ """
216+ Returns a SQLAlchemy Session object, ready to be used either
217+ directly or as a context manager
190218
191- def get_mapper (self , bind_name : str = DEFAULT_BIND_NAME ) -> registry :
192- return self .get_bind (bind_name ).registry_mapper
219+ :param bind_name: A registered bind name
220+ :return: The SQLAlchemy Session object
221+ """
222+ return self .get_bind (bind_name ).session_class ()
0 commit comments