Skip to content

Commit 16d99f2

Browse files
authored
Add back postgres connection pool (#1334)
1 parent 365115e commit 16d99f2

File tree

1 file changed

+105
-78
lines changed
  • mcp_servers/postgres/src/postgres_mcp

1 file changed

+105
-78
lines changed

mcp_servers/postgres/src/postgres_mcp/server.py

Lines changed: 105 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class AccessMode(str, Enum):
5858

5959

6060
# Global variables
61+
db_connections: dict[str, DbConnPool] = {} # Cache connections by database URL
6162
current_access_mode = AccessMode.UNRESTRICTED
6263
shutdown_in_progress = False
6364

@@ -107,25 +108,39 @@ def get_database_url() -> str:
107108
raise RuntimeError("Database URL not found in request context")
108109

109110

110-
@contextlib.asynccontextmanager
111-
async def create_sql_driver(database_url: str) -> AsyncIterator[Union[SqlDriver, SafeSqlDriver]]:
112-
"""Create a fresh database connection for one tool execution and close it when done."""
111+
def get_database_url_or_empty() -> str:
112+
"""Get the database URL from context or return empty string."""
113+
try:
114+
return database_url_context.get()
115+
except LookupError:
116+
return ""
117+
118+
119+
async def get_db_connection(database_url: str) -> DbConnPool:
120+
"""Get or create a database connection pool for the given URL."""
121+
if database_url not in db_connections:
122+
conn = DbConnPool()
123+
await conn.pool_connect(database_url)
124+
db_connections[database_url] = conn
125+
logger.info("Created new connection pool for database")
126+
return db_connections[database_url]
127+
128+
129+
async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]:
130+
"""Get the appropriate SQL driver based on the current access mode and request context."""
131+
database_url = get_database_url()
113132
if not database_url:
114133
raise ValueError("No database URL provided in request. Please include database URL in x-auth-data header.")
115134

116-
conn = DbConnPool()
117-
await conn.pool_connect(database_url)
118-
try:
119-
base_driver = SqlDriver(conn=conn)
120-
if current_access_mode == AccessMode.RESTRICTED:
121-
logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)")
122-
yield SafeSqlDriver(sql_driver=base_driver, timeout=30)
123-
else:
124-
logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)")
125-
yield base_driver
126-
finally:
127-
await conn.close()
128-
logger.debug("Database connection closed after tool execution")
135+
db_connection = await get_db_connection(database_url)
136+
base_driver = SqlDriver(conn=db_connection)
137+
138+
if current_access_mode == AccessMode.RESTRICTED:
139+
logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)")
140+
return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout
141+
else:
142+
logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)")
143+
return base_driver
129144

130145

131146
def format_text_response(text: Any) -> ResponseType:
@@ -139,9 +154,10 @@ def format_error_response(error: str) -> ResponseType:
139154

140155

141156
# Tool implementations
142-
async def list_schemas_tool(sql_driver: Union[SqlDriver, SafeSqlDriver]) -> ResponseType:
157+
async def list_schemas_tool() -> ResponseType:
143158
"""List all schemas in the database."""
144159
try:
160+
sql_driver = await get_sql_driver()
145161
rows = await sql_driver.execute_query(
146162
"""
147163
SELECT
@@ -164,12 +180,13 @@ async def list_schemas_tool(sql_driver: Union[SqlDriver, SafeSqlDriver]) -> Resp
164180

165181

166182
async def list_objects_tool(
167-
sql_driver: Union[SqlDriver, SafeSqlDriver],
168183
schema_name: str,
169184
object_type: str = "table",
170185
) -> ResponseType:
171186
"""List objects of a given type in a schema."""
172187
try:
188+
sql_driver = await get_sql_driver()
189+
173190
if object_type in ("table", "view"):
174191
table_type = "BASE TABLE" if object_type == "table" else "VIEW"
175192
rows = await SafeSqlDriver.execute_param_query(
@@ -230,13 +247,14 @@ async def list_objects_tool(
230247

231248

232249
async def get_object_details_tool(
233-
sql_driver: Union[SqlDriver, SafeSqlDriver],
234250
schema_name: str,
235251
object_name: str,
236252
object_type: str = "table",
237253
) -> ResponseType:
238254
"""Get detailed information about a database object."""
239255
try:
256+
sql_driver = await get_sql_driver()
257+
240258
if object_type in ("table", "view"):
241259
# Get columns
242260
col_rows = await SafeSqlDriver.execute_param_query(
@@ -361,7 +379,6 @@ async def get_object_details_tool(
361379

362380

363381
async def explain_query_tool(
364-
sql_driver: Union[SqlDriver, SafeSqlDriver],
365382
sql: str,
366383
analyze: bool = False,
367384
hypothetical_indexes: list[dict[str, Any]] | None = None,
@@ -370,6 +387,7 @@ async def explain_query_tool(
370387
if hypothetical_indexes is None:
371388
hypothetical_indexes = []
372389
try:
390+
sql_driver = await get_sql_driver()
373391
explain_tool = ExplainPlanTool(sql_driver=sql_driver)
374392
result: ExplainPlanArtifact | ErrorResult | None = None
375393

@@ -417,9 +435,10 @@ async def explain_query_tool(
417435
return format_error_response(str(e))
418436

419437

420-
async def execute_sql_tool(sql_driver: Union[SqlDriver, SafeSqlDriver], sql: str) -> ResponseType:
438+
async def execute_sql_tool(sql: str) -> ResponseType:
421439
"""Executes a SQL query against the database."""
422440
try:
441+
sql_driver = await get_sql_driver()
423442
rows = await sql_driver.execute_query(sql) # type: ignore
424443
if rows is None:
425444
return format_text_response("No results")
@@ -430,12 +449,12 @@ async def execute_sql_tool(sql_driver: Union[SqlDriver, SafeSqlDriver], sql: str
430449

431450

432451
async def analyze_workload_indexes_tool(
433-
sql_driver: Union[SqlDriver, SafeSqlDriver],
434452
max_index_size_mb: int = 10000,
435453
method: Literal["dta", "llm"] = "dta",
436454
) -> ResponseType:
437455
"""Analyze frequently executed queries in the database and recommend optimal indexes."""
438456
try:
457+
sql_driver = await get_sql_driver()
439458
if method == "dta":
440459
index_tuning = DatabaseTuningAdvisor(sql_driver)
441460
else:
@@ -449,7 +468,6 @@ async def analyze_workload_indexes_tool(
449468

450469

451470
async def analyze_query_indexes_tool(
452-
sql_driver: Union[SqlDriver, SafeSqlDriver],
453471
queries: list[str],
454472
max_index_size_mb: int = 10000,
455473
method: Literal["dta", "llm"] = "dta",
@@ -461,6 +479,7 @@ async def analyze_query_indexes_tool(
461479
return format_error_response(f"Please provide a list of up to {MAX_NUM_INDEX_TUNING_QUERIES} queries to analyze.")
462480

463481
try:
482+
sql_driver = await get_sql_driver()
464483
if method == "dta":
465484
index_tuning = DatabaseTuningAdvisor(sql_driver)
466485
else:
@@ -473,20 +492,20 @@ async def analyze_query_indexes_tool(
473492
return format_error_response(str(e))
474493

475494

476-
async def analyze_db_health_tool(sql_driver: Union[SqlDriver, SafeSqlDriver], health_type: str = "all") -> ResponseType:
495+
async def analyze_db_health_tool(health_type: str = "all") -> ResponseType:
477496
"""Analyze database health for specified components."""
478-
health_tool = DatabaseHealthTool(sql_driver)
497+
health_tool = DatabaseHealthTool(await get_sql_driver())
479498
result = await health_tool.health(health_type=health_type)
480499
return format_text_response(result)
481500

482501

483502
async def get_top_queries_tool(
484-
sql_driver: Union[SqlDriver, SafeSqlDriver],
485503
sort_by: str = "resources",
486504
limit: int = 10,
487505
) -> ResponseType:
488506
"""Reports the slowest or most resource-intensive queries."""
489507
try:
508+
sql_driver = await get_sql_driver()
490509
top_queries_tool = TopQueriesCalc(sql_driver=sql_driver)
491510

492511
if sort_by == "resources":
@@ -728,58 +747,49 @@ async def call_tool(
728747
name: str, arguments: dict
729748
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
730749
try:
731-
database_url = get_database_url()
732-
async with create_sql_driver(database_url) as sql_driver:
733-
if name == "list_schemas":
734-
return await list_schemas_tool(sql_driver)
735-
elif name == "list_objects":
736-
return await list_objects_tool(
737-
sql_driver,
738-
schema_name=arguments.get("schema_name", ""),
739-
object_type=arguments.get("object_type", "table"),
740-
)
741-
elif name == "get_object_details":
742-
return await get_object_details_tool(
743-
sql_driver,
744-
schema_name=arguments.get("schema_name", ""),
745-
object_name=arguments.get("object_name", ""),
746-
object_type=arguments.get("object_type", "table"),
747-
)
748-
elif name == "explain_query":
749-
return await explain_query_tool(
750-
sql_driver,
751-
sql=arguments.get("sql", ""),
752-
analyze=arguments.get("analyze", False),
753-
hypothetical_indexes=arguments.get("hypothetical_indexes", []),
754-
)
755-
elif name == "execute_sql":
756-
return await execute_sql_tool(sql_driver, sql=arguments.get("sql", ""))
757-
elif name == "analyze_workload_indexes":
758-
return await analyze_workload_indexes_tool(
759-
sql_driver,
760-
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
761-
method=arguments.get("method", "dta"),
762-
)
763-
elif name == "analyze_query_indexes":
764-
return await analyze_query_indexes_tool(
765-
sql_driver,
766-
queries=arguments.get("queries", []),
767-
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
768-
method=arguments.get("method", "dta"),
769-
)
770-
elif name == "analyze_db_health":
771-
return await analyze_db_health_tool(
772-
sql_driver,
773-
health_type=arguments.get("health_type", "all"),
774-
)
775-
elif name == "get_top_queries":
776-
return await get_top_queries_tool(
777-
sql_driver,
778-
sort_by=arguments.get("sort_by", "resources"),
779-
limit=arguments.get("limit", 10),
780-
)
781-
else:
782-
return [types.TextContent(type="text", text=f"Unknown tool: {name}")]
750+
if name == "list_schemas":
751+
return await list_schemas_tool()
752+
elif name == "list_objects":
753+
return await list_objects_tool(
754+
schema_name=arguments.get("schema_name", ""),
755+
object_type=arguments.get("object_type", "table"),
756+
)
757+
elif name == "get_object_details":
758+
return await get_object_details_tool(
759+
schema_name=arguments.get("schema_name", ""),
760+
object_name=arguments.get("object_name", ""),
761+
object_type=arguments.get("object_type", "table"),
762+
)
763+
elif name == "explain_query":
764+
return await explain_query_tool(
765+
sql=arguments.get("sql", ""),
766+
analyze=arguments.get("analyze", False),
767+
hypothetical_indexes=arguments.get("hypothetical_indexes", []),
768+
)
769+
elif name == "execute_sql":
770+
return await execute_sql_tool(sql=arguments.get("sql", ""))
771+
elif name == "analyze_workload_indexes":
772+
return await analyze_workload_indexes_tool(
773+
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
774+
method=arguments.get("method", "dta"),
775+
)
776+
elif name == "analyze_query_indexes":
777+
return await analyze_query_indexes_tool(
778+
queries=arguments.get("queries", []),
779+
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
780+
method=arguments.get("method", "dta"),
781+
)
782+
elif name == "analyze_db_health":
783+
return await analyze_db_health_tool(
784+
health_type=arguments.get("health_type", "all"),
785+
)
786+
elif name == "get_top_queries":
787+
return await get_top_queries_tool(
788+
sort_by=arguments.get("sort_by", "resources"),
789+
limit=arguments.get("limit", 10),
790+
)
791+
else:
792+
return [types.TextContent(type="text", text=f"Unknown tool: {name}")]
783793
except Exception as e:
784794
logger.exception(f"Error executing tool {name}: {e}")
785795
return [types.TextContent(type="text", text=f"Error: {str(e)}")]
@@ -902,6 +912,14 @@ async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]:
902912
yield
903913
finally:
904914
logger.info("Application shutting down...")
915+
# Close all database connections
916+
for _, conn in db_connections.items():
917+
try:
918+
await conn.close()
919+
logger.info("Closed database connection")
920+
except Exception as e:
921+
logger.error(f"Error closing database connection: {e}")
922+
db_connections.clear()
905923

906924
# Create an ASGI application with routes for both transports
907925
starlette_app = Starlette(
@@ -938,6 +956,15 @@ async def shutdown(sig=None):
938956
if sig:
939957
logger.info(f"Received exit signal {sig.name}")
940958

959+
# Close all database connections
960+
for _, conn in db_connections.items():
961+
try:
962+
await conn.close()
963+
logger.info("Closed database connection")
964+
except Exception as e:
965+
logger.error(f"Error closing database connection: {e}")
966+
db_connections.clear()
967+
941968
sys.exit(128 + sig if sig is not None else 0)
942969

943970

0 commit comments

Comments
 (0)