Skip to content

Commit 6b7567e

Browse files
Fix DB engine leak by lazy singleton and Add CORS
1 parent c5acda9 commit 6b7567e

File tree

2 files changed

+52
-25
lines changed

2 files changed

+52
-25
lines changed

mxgo/api.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import redis.asyncio as aioredis
1212
from dotenv import load_dotenv
1313
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response, UploadFile, status
14+
from fastapi.middleware.cors import CORSMiddleware
1415
from fastapi.security import APIKeyHeader, HTTPBearer
1516
from sqlalchemy import text
1617

@@ -127,6 +128,30 @@ async def lifespan(_app: FastAPI):
127128

128129

129130
app = FastAPI(lifespan=lifespan)
131+
132+
IS_PROD = os.getenv("IS_PROD", "false").lower() == "true"
133+
134+
ALLOWED_ORIGINS_PROD = [
135+
"https://mxgo.ai",
136+
]
137+
138+
139+
ALLOWED_ORIGINS_DEV = [
140+
"http://localhost",
141+
"http://localhost:8080",
142+
"http://127.0.0.1",
143+
"http://127.0.0.1:8080",
144+
]
145+
146+
app.add_middleware(
147+
CORSMiddleware,
148+
allow_origins=ALLOWED_ORIGINS_PROD if IS_PROD else ALLOWED_ORIGINS_DEV,
149+
allow_credentials=True,
150+
allow_methods=["GET", "POST"],
151+
allow_headers=["*"],
152+
)
153+
154+
130155
if os.getenv("IS_PROD", "false").lower() == "true":
131156
app.openapi_url = None
132157

mxgo/db/__init__.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,35 @@ class DbConnection:
1414
_engine = None
1515

1616
def __init__(self) -> None:
17-
self.db_uri = self.get_db_uri_from_env()
17+
pass
1818

1919
@classmethod
2020
def get_db_uri_from_env(cls) -> str:
2121
return f"postgresql://{os.environ['DB_USER']}:{os.environ['DB_PASSWORD']}@{os.environ['DB_HOST']}:{os.environ['DB_PORT']}/{os.environ['DB_NAME']}"
2222

2323
def init_connection(self) -> None:
2424
"""Initialize the database connection."""
25-
if self._engine is None:
25+
if DbConnection._engine is None:
2626
self._create_engine()
2727

2828
def close_connection(self) -> None:
2929
"""Close the database connection."""
30-
if self._engine is not None:
31-
self._engine.dispose()
30+
if DbConnection._engine is not None:
31+
DbConnection._engine.dispose()
32+
DbConnection._engine = None
3233

3334
def get_connection(self):
3435
"""Get the database connection engine."""
35-
if self._engine is None:
36-
msg = "DB session isn't initialized"
37-
raise ConnectionError(msg)
38-
return self._engine
36+
if DbConnection._engine is None:
37+
self.init_connection()
38+
return DbConnection._engine
3939

4040
@contextmanager
4141
def get_session(self) -> Generator[Session]:
4242
"""Get a synchronous database session."""
43-
if self._engine is None:
44-
self.init_connection()
43+
engine = self.get_connection()
4544

46-
session = Session(self._engine)
45+
session = Session(engine)
4746
try:
4847
yield session
4948
session.commit()
@@ -55,9 +54,9 @@ def get_session(self) -> Generator[Session]:
5554

5655
def _create_engine(self):
5756
"""Create a synchronous SQLAlchemy engine."""
58-
db_url = self.db_uri
59-
self._engine = create_engine(db_url, pool_pre_ping=True, echo=False)
60-
return self._engine
57+
db_url = self.get_db_uri_from_env()
58+
DbConnection._engine = create_engine(db_url, pool_pre_ping=True, echo=False)
59+
return DbConnection._engine
6160

6261

6362
class AsyncDbConnection:
@@ -73,18 +72,19 @@ def get_db_uri_from_env(cls) -> str:
7372
return f"postgresql+asyncpg://{os.environ['DB_USER']}:{os.environ['DB_PASSWORD']}@{os.environ['DB_HOST']}:{os.environ['DB_PORT']}/{os.environ['DB_NAME']}"
7473

7574
async def init_connection(self) -> None:
76-
if self._engine is None:
75+
if AsyncDbConnection._engine is None:
7776
await self._create_engine()
7877

7978
async def close_connection(self) -> None:
80-
if self._engine is not None:
81-
await self._engine.dispose()
79+
if AsyncDbConnection._engine is not None:
80+
await AsyncDbConnection._engine.dispose()
81+
AsyncDbConnection._engine = None
8282

8383
def get_connection(self) -> AsyncEngine:
84-
if self._engine is None:
84+
if AsyncDbConnection._engine is None:
8585
msg = "Async DB session isn't initialized"
8686
raise ConnectionError(msg)
87-
return self._engine
87+
return AsyncDbConnection._engine
8888

8989
@asynccontextmanager
9090
async def get_session(self) -> AsyncGenerator[AsyncSession]:
@@ -106,19 +106,21 @@ async def get_session(self) -> AsyncGenerator[AsyncSession]:
106106

107107
async def _create_engine(self) -> AsyncEngine:
108108
db_url = self.db_uri
109-
self._engine = create_async_engine(db_url, pool_pre_ping=True, echo=False)
110-
return self._engine
109+
AsyncDbConnection._engine = create_async_engine(db_url, pool_pre_ping=True, echo=False)
110+
return AsyncDbConnection._engine
111+
112+
113+
db_connection = DbConnection()
114+
async_db_connection = AsyncDbConnection()
111115

112116

113117
def init_db_connection() -> DbConnection:
114118
"""Initialize a synchronous database connection."""
115-
db_connection = DbConnection()
116119
db_connection.init_connection()
117120
return db_connection
118121

119122

120123
async def init_async_db_connection() -> AsyncDbConnection:
121124
"""Initialize an asynchronous database connection for future use."""
122-
db_connection = AsyncDbConnection()
123-
await db_connection.init_connection()
124-
return db_connection
125+
await async_db_connection.init_connection()
126+
return async_db_connection

0 commit comments

Comments
 (0)