Skip to content

Commit 57441fc

Browse files
ply lazy initialization to both sync and async DB connections and Add CORS
1 parent c5acda9 commit 57441fc

File tree

2 files changed

+53
-25
lines changed

2 files changed

+53
-25
lines changed

mxgo/api.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import redis.asyncio as aioredis
1212
from dotenv import load_dotenv
1313
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response, UploadFile, status
14+
from fastapi.middleware.cors import CORSMiddleware
1415
from fastapi.security import APIKeyHeader, HTTPBearer
1516
from sqlalchemy import text
1617

@@ -127,6 +128,30 @@ async def lifespan(_app: FastAPI):
127128

128129

129130
app = FastAPI(lifespan=lifespan)
131+
132+
IS_PROD = os.getenv("IS_PROD", "false").lower() == "true"
133+
134+
ALLOWED_ORIGINS_PROD = [
135+
"https://mxgo.ai",
136+
]
137+
138+
139+
ALLOWED_ORIGINS_DEV = [
140+
"http://localhost",
141+
"http://localhost:8080",
142+
"http://127.0.0.1",
143+
"http://127.0.0.1:8080",
144+
]
145+
146+
app.add_middleware(
147+
CORSMiddleware,
148+
allow_origins=ALLOWED_ORIGINS_PROD if IS_PROD else ALLOWED_ORIGINS_DEV,
149+
allow_credentials=True,
150+
allow_methods=["GET", "POST"],
151+
allow_headers=["*"],
152+
)
153+
154+
130155
if os.getenv("IS_PROD", "false").lower() == "true":
131156
app.openapi_url = None
132157

mxgo/db/__init__.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,36 @@ class DbConnection:
1414
_engine = None
1515

1616
def __init__(self) -> None:
17-
self.db_uri = self.get_db_uri_from_env()
17+
pass
1818

1919
@classmethod
2020
def get_db_uri_from_env(cls) -> str:
2121
return f"postgresql://{os.environ['DB_USER']}:{os.environ['DB_PASSWORD']}@{os.environ['DB_HOST']}:{os.environ['DB_PORT']}/{os.environ['DB_NAME']}"
2222

2323
def init_connection(self) -> None:
2424
"""Initialize the database connection."""
25-
if self._engine is None:
25+
if DbConnection._engine is None:
2626
self._create_engine()
2727

2828
def close_connection(self) -> None:
2929
"""Close the database connection."""
30-
if self._engine is not None:
31-
self._engine.dispose()
30+
if DbConnection._engine is not None:
31+
DbConnection._engine.dispose()
32+
DbConnection._engine = None
3233

3334
def get_connection(self):
3435
"""Get the database connection engine."""
35-
if self._engine is None:
36-
msg = "DB session isn't initialized"
36+
if DbConnection._engine is None:
37+
msg = "DB session isn't initialized. Call init_db_connection() at application startup."
3738
raise ConnectionError(msg)
38-
return self._engine
39+
return DbConnection._engine
3940

4041
@contextmanager
4142
def get_session(self) -> Generator[Session]:
4243
"""Get a synchronous database session."""
43-
if self._engine is None:
44-
self.init_connection()
44+
engine = self.get_connection()
4545

46-
session = Session(self._engine)
46+
session = Session(engine)
4747
try:
4848
yield session
4949
session.commit()
@@ -55,9 +55,9 @@ def get_session(self) -> Generator[Session]:
5555

5656
def _create_engine(self):
5757
"""Create a synchronous SQLAlchemy engine."""
58-
db_url = self.db_uri
59-
self._engine = create_engine(db_url, pool_pre_ping=True, echo=False)
60-
return self._engine
58+
db_url = self.get_db_uri_from_env()
59+
DbConnection._engine = create_engine(db_url, pool_pre_ping=True, echo=False)
60+
return DbConnection._engine
6161

6262

6363
class AsyncDbConnection:
@@ -66,25 +66,26 @@ class AsyncDbConnection:
6666
_engine: AsyncEngine | None = None
6767

6868
def __init__(self) -> None:
69-
self.db_uri = self.get_db_uri_from_env()
69+
pass
7070

7171
@classmethod
7272
def get_db_uri_from_env(cls) -> str:
7373
return f"postgresql+asyncpg://{os.environ['DB_USER']}:{os.environ['DB_PASSWORD']}@{os.environ['DB_HOST']}:{os.environ['DB_PORT']}/{os.environ['DB_NAME']}"
7474

7575
async def init_connection(self) -> None:
76-
if self._engine is None:
76+
if AsyncDbConnection._engine is None:
7777
await self._create_engine()
7878

7979
async def close_connection(self) -> None:
80-
if self._engine is not None:
81-
await self._engine.dispose()
80+
if AsyncDbConnection._engine is not None:
81+
await AsyncDbConnection._engine.dispose()
82+
AsyncDbConnection._engine = None
8283

8384
def get_connection(self) -> AsyncEngine:
84-
if self._engine is None:
85+
if AsyncDbConnection._engine is None:
8586
msg = "Async DB session isn't initialized"
8687
raise ConnectionError(msg)
87-
return self._engine
88+
return AsyncDbConnection._engine
8889

8990
@asynccontextmanager
9091
async def get_session(self) -> AsyncGenerator[AsyncSession]:
@@ -106,19 +107,21 @@ async def get_session(self) -> AsyncGenerator[AsyncSession]:
106107

107108
async def _create_engine(self) -> AsyncEngine:
108109
db_url = self.db_uri
109-
self._engine = create_async_engine(db_url, pool_pre_ping=True, echo=False)
110-
return self._engine
110+
AsyncDbConnection._engine = create_async_engine(db_url, pool_pre_ping=True, echo=False)
111+
return AsyncDbConnection._engine
112+
113+
114+
db_connection = DbConnection()
115+
async_db_connection = AsyncDbConnection()
111116

112117

113118
def init_db_connection() -> DbConnection:
114119
"""Initialize a synchronous database connection."""
115-
db_connection = DbConnection()
116120
db_connection.init_connection()
117121
return db_connection
118122

119123

120124
async def init_async_db_connection() -> AsyncDbConnection:
121125
"""Initialize an asynchronous database connection for future use."""
122-
db_connection = AsyncDbConnection()
123-
await db_connection.init_connection()
124-
return db_connection
126+
await async_db_connection.init_connection()
127+
return async_db_connection

0 commit comments

Comments
 (0)