@@ -14,36 +14,36 @@ class DbConnection:
1414 _engine = None
1515
1616 def __init__ (self ) -> None :
17- self . db_uri = self . get_db_uri_from_env ()
17+ pass
1818
1919 @classmethod
2020 def get_db_uri_from_env (cls ) -> str :
2121 return f"postgresql://{ os .environ ['DB_USER' ]} :{ os .environ ['DB_PASSWORD' ]} @{ os .environ ['DB_HOST' ]} :{ os .environ ['DB_PORT' ]} /{ os .environ ['DB_NAME' ]} "
2222
2323 def init_connection (self ) -> None :
2424 """Initialize the database connection."""
25- if self ._engine is None :
25+ if DbConnection ._engine is None :
2626 self ._create_engine ()
2727
2828 def close_connection (self ) -> None :
2929 """Close the database connection."""
30- if self ._engine is not None :
31- self ._engine .dispose ()
30+ if DbConnection ._engine is not None :
31+ DbConnection ._engine .dispose ()
32+ DbConnection ._engine = None
3233
3334 def get_connection (self ):
3435 """Get the database connection engine."""
35- if self ._engine is None :
36- msg = "DB session isn't initialized"
36+ if DbConnection ._engine is None :
37+ msg = "DB session isn't initialized. Call init_db_connection() at application startup. "
3738 raise ConnectionError (msg )
38- return self ._engine
39+ return DbConnection ._engine
3940
4041 @contextmanager
4142 def get_session (self ) -> Generator [Session ]:
4243 """Get a synchronous database session."""
43- if self ._engine is None :
44- self .init_connection ()
44+ engine = self .get_connection ()
4545
46- session = Session (self . _engine )
46+ session = Session (engine )
4747 try :
4848 yield session
4949 session .commit ()
@@ -55,9 +55,9 @@ def get_session(self) -> Generator[Session]:
5555
5656 def _create_engine (self ):
5757 """Create a synchronous SQLAlchemy engine."""
58- db_url = self .db_uri
59- self ._engine = create_engine (db_url , pool_pre_ping = True , echo = False )
60- return self ._engine
58+ db_url = self .get_db_uri_from_env ()
59+ DbConnection ._engine = create_engine (db_url , pool_pre_ping = True , echo = False )
60+ return DbConnection ._engine
6161
6262
6363class AsyncDbConnection :
@@ -66,25 +66,26 @@ class AsyncDbConnection:
6666 _engine : AsyncEngine | None = None
6767
6868 def __init__ (self ) -> None :
69- self . db_uri = self . get_db_uri_from_env ()
69+ pass
7070
7171 @classmethod
7272 def get_db_uri_from_env (cls ) -> str :
7373 return f"postgresql+asyncpg://{ os .environ ['DB_USER' ]} :{ os .environ ['DB_PASSWORD' ]} @{ os .environ ['DB_HOST' ]} :{ os .environ ['DB_PORT' ]} /{ os .environ ['DB_NAME' ]} "
7474
7575 async def init_connection (self ) -> None :
76- if self ._engine is None :
76+ if AsyncDbConnection ._engine is None :
7777 await self ._create_engine ()
7878
7979 async def close_connection (self ) -> None :
80- if self ._engine is not None :
81- await self ._engine .dispose ()
80+ if AsyncDbConnection ._engine is not None :
81+ await AsyncDbConnection ._engine .dispose ()
82+ AsyncDbConnection ._engine = None
8283
8384 def get_connection (self ) -> AsyncEngine :
84- if self ._engine is None :
85+ if AsyncDbConnection ._engine is None :
8586 msg = "Async DB session isn't initialized"
8687 raise ConnectionError (msg )
87- return self ._engine
88+ return AsyncDbConnection ._engine
8889
8990 @asynccontextmanager
9091 async def get_session (self ) -> AsyncGenerator [AsyncSession ]:
@@ -106,19 +107,21 @@ async def get_session(self) -> AsyncGenerator[AsyncSession]:
106107
107108 async def _create_engine (self ) -> AsyncEngine :
108109 db_url = self .db_uri
109- self ._engine = create_async_engine (db_url , pool_pre_ping = True , echo = False )
110- return self ._engine
110+ AsyncDbConnection ._engine = create_async_engine (db_url , pool_pre_ping = True , echo = False )
111+ return AsyncDbConnection ._engine
112+
113+
114+ db_connection = DbConnection ()
115+ async_db_connection = AsyncDbConnection ()
111116
112117
113118def init_db_connection () -> DbConnection :
114119 """Initialize a synchronous database connection."""
115- db_connection = DbConnection ()
116120 db_connection .init_connection ()
117121 return db_connection
118122
119123
120124async def init_async_db_connection () -> AsyncDbConnection :
121125 """Initialize an asynchronous database connection for future use."""
122- db_connection = AsyncDbConnection ()
123- await db_connection .init_connection ()
124- return db_connection
126+ await async_db_connection .init_connection ()
127+ return async_db_connection
0 commit comments