11import logging
2+ from collections .abc import AsyncIterator
23from contextlib import asynccontextmanager
34
45from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
910@asynccontextmanager
1011async def get_or_create_connection (
1112 engine : AsyncEngine , connection : AsyncConnection | None = None
12- ):
13- # creator is responsible of closing connection
13+ ) -> AsyncIterator [AsyncConnection ]:
14+ # NOTE: When connection is passed, the engine is actually not needed
15+ # NOTE: Creator is responsible of closing connection
1416 is_connection_created = connection is None
1517 if is_connection_created :
1618 connection = await engine .connect ()
1719 try :
1820 yield connection
1921 finally :
20- if is_connection_created :
22+ assert connection # nosec
23+ if is_connection_created and connection :
2124 await connection .close ()
2225
2326
@@ -27,11 +30,11 @@ async def transaction_context(
2730):
2831 async with get_or_create_connection (engine , connection ) as conn :
2932 if conn .in_transaction ():
30- async with conn .begin_nested (): # savepoint
33+ async with conn .begin_nested (): # inner transaction ( savepoint)
3134 yield conn
3235 else :
3336 try :
34- async with conn .begin ():
37+ async with conn .begin (): # outer transaction (savepoint)
3538 yield conn
3639 finally :
3740 assert not conn .closed # nosec
0 commit comments