Skip to content

Commit 365115e

Browse files
authored
Remove connection cache for postgres (#1333)
1 parent a1efdea commit 365115e

File tree

1 file changed

+78
-105
lines changed
  • mcp_servers/postgres/src/postgres_mcp

1 file changed

+78
-105
lines changed

mcp_servers/postgres/src/postgres_mcp/server.py

Lines changed: 78 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ class AccessMode(str, Enum):
5858

5959

6060
# Global variables
61-
db_connections: dict[str, DbConnPool] = {} # Cache connections by database URL
6261
current_access_mode = AccessMode.UNRESTRICTED
6362
shutdown_in_progress = False
6463

@@ -108,39 +107,25 @@ def get_database_url() -> str:
108107
raise RuntimeError("Database URL not found in request context")
109108

110109

111-
def get_database_url_or_empty() -> str:
112-
"""Get the database URL from context or return empty string."""
113-
try:
114-
return database_url_context.get()
115-
except LookupError:
116-
return ""
117-
118-
119-
async def get_db_connection(database_url: str) -> DbConnPool:
120-
"""Get or create a database connection pool for the given URL."""
121-
if database_url not in db_connections:
122-
conn = DbConnPool()
123-
await conn.pool_connect(database_url)
124-
db_connections[database_url] = conn
125-
logger.info("Created new connection pool for database")
126-
return db_connections[database_url]
127-
128-
129-
async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]:
130-
"""Get the appropriate SQL driver based on the current access mode and request context."""
131-
database_url = get_database_url()
110+
@contextlib.asynccontextmanager
111+
async def create_sql_driver(database_url: str) -> AsyncIterator[Union[SqlDriver, SafeSqlDriver]]:
112+
"""Create a fresh database connection for one tool execution and close it when done."""
132113
if not database_url:
133114
raise ValueError("No database URL provided in request. Please include database URL in x-auth-data header.")
134115

135-
db_connection = await get_db_connection(database_url)
136-
base_driver = SqlDriver(conn=db_connection)
137-
138-
if current_access_mode == AccessMode.RESTRICTED:
139-
logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)")
140-
return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout
141-
else:
142-
logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)")
143-
return base_driver
116+
conn = DbConnPool()
117+
await conn.pool_connect(database_url)
118+
try:
119+
base_driver = SqlDriver(conn=conn)
120+
if current_access_mode == AccessMode.RESTRICTED:
121+
logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)")
122+
yield SafeSqlDriver(sql_driver=base_driver, timeout=30)
123+
else:
124+
logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)")
125+
yield base_driver
126+
finally:
127+
await conn.close()
128+
logger.debug("Database connection closed after tool execution")
144129

145130

146131
def format_text_response(text: Any) -> ResponseType:
@@ -154,10 +139,9 @@ def format_error_response(error: str) -> ResponseType:
154139

155140

156141
# Tool implementations
157-
async def list_schemas_tool() -> ResponseType:
142+
async def list_schemas_tool(sql_driver: Union[SqlDriver, SafeSqlDriver]) -> ResponseType:
158143
"""List all schemas in the database."""
159144
try:
160-
sql_driver = await get_sql_driver()
161145
rows = await sql_driver.execute_query(
162146
"""
163147
SELECT
@@ -180,13 +164,12 @@ async def list_schemas_tool() -> ResponseType:
180164

181165

182166
async def list_objects_tool(
167+
sql_driver: Union[SqlDriver, SafeSqlDriver],
183168
schema_name: str,
184169
object_type: str = "table",
185170
) -> ResponseType:
186171
"""List objects of a given type in a schema."""
187172
try:
188-
sql_driver = await get_sql_driver()
189-
190173
if object_type in ("table", "view"):
191174
table_type = "BASE TABLE" if object_type == "table" else "VIEW"
192175
rows = await SafeSqlDriver.execute_param_query(
@@ -247,14 +230,13 @@ async def list_objects_tool(
247230

248231

249232
async def get_object_details_tool(
233+
sql_driver: Union[SqlDriver, SafeSqlDriver],
250234
schema_name: str,
251235
object_name: str,
252236
object_type: str = "table",
253237
) -> ResponseType:
254238
"""Get detailed information about a database object."""
255239
try:
256-
sql_driver = await get_sql_driver()
257-
258240
if object_type in ("table", "view"):
259241
# Get columns
260242
col_rows = await SafeSqlDriver.execute_param_query(
@@ -379,6 +361,7 @@ async def get_object_details_tool(
379361

380362

381363
async def explain_query_tool(
364+
sql_driver: Union[SqlDriver, SafeSqlDriver],
382365
sql: str,
383366
analyze: bool = False,
384367
hypothetical_indexes: list[dict[str, Any]] | None = None,
@@ -387,7 +370,6 @@ async def explain_query_tool(
387370
if hypothetical_indexes is None:
388371
hypothetical_indexes = []
389372
try:
390-
sql_driver = await get_sql_driver()
391373
explain_tool = ExplainPlanTool(sql_driver=sql_driver)
392374
result: ExplainPlanArtifact | ErrorResult | None = None
393375

@@ -435,10 +417,9 @@ async def explain_query_tool(
435417
return format_error_response(str(e))
436418

437419

438-
async def execute_sql_tool(sql: str) -> ResponseType:
420+
async def execute_sql_tool(sql_driver: Union[SqlDriver, SafeSqlDriver], sql: str) -> ResponseType:
439421
"""Executes a SQL query against the database."""
440422
try:
441-
sql_driver = await get_sql_driver()
442423
rows = await sql_driver.execute_query(sql) # type: ignore
443424
if rows is None:
444425
return format_text_response("No results")
@@ -449,12 +430,12 @@ async def execute_sql_tool(sql: str) -> ResponseType:
449430

450431

451432
async def analyze_workload_indexes_tool(
433+
sql_driver: Union[SqlDriver, SafeSqlDriver],
452434
max_index_size_mb: int = 10000,
453435
method: Literal["dta", "llm"] = "dta",
454436
) -> ResponseType:
455437
"""Analyze frequently executed queries in the database and recommend optimal indexes."""
456438
try:
457-
sql_driver = await get_sql_driver()
458439
if method == "dta":
459440
index_tuning = DatabaseTuningAdvisor(sql_driver)
460441
else:
@@ -468,6 +449,7 @@ async def analyze_workload_indexes_tool(
468449

469450

470451
async def analyze_query_indexes_tool(
452+
sql_driver: Union[SqlDriver, SafeSqlDriver],
471453
queries: list[str],
472454
max_index_size_mb: int = 10000,
473455
method: Literal["dta", "llm"] = "dta",
@@ -479,7 +461,6 @@ async def analyze_query_indexes_tool(
479461
return format_error_response(f"Please provide a list of up to {MAX_NUM_INDEX_TUNING_QUERIES} queries to analyze.")
480462

481463
try:
482-
sql_driver = await get_sql_driver()
483464
if method == "dta":
484465
index_tuning = DatabaseTuningAdvisor(sql_driver)
485466
else:
@@ -492,20 +473,20 @@ async def analyze_query_indexes_tool(
492473
return format_error_response(str(e))
493474

494475

495-
async def analyze_db_health_tool(health_type: str = "all") -> ResponseType:
476+
async def analyze_db_health_tool(sql_driver: Union[SqlDriver, SafeSqlDriver], health_type: str = "all") -> ResponseType:
496477
"""Analyze database health for specified components."""
497-
health_tool = DatabaseHealthTool(await get_sql_driver())
478+
health_tool = DatabaseHealthTool(sql_driver)
498479
result = await health_tool.health(health_type=health_type)
499480
return format_text_response(result)
500481

501482

502483
async def get_top_queries_tool(
484+
sql_driver: Union[SqlDriver, SafeSqlDriver],
503485
sort_by: str = "resources",
504486
limit: int = 10,
505487
) -> ResponseType:
506488
"""Reports the slowest or most resource-intensive queries."""
507489
try:
508-
sql_driver = await get_sql_driver()
509490
top_queries_tool = TopQueriesCalc(sql_driver=sql_driver)
510491

511492
if sort_by == "resources":
@@ -747,49 +728,58 @@ async def call_tool(
747728
name: str, arguments: dict
748729
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
749730
try:
750-
if name == "list_schemas":
751-
return await list_schemas_tool()
752-
elif name == "list_objects":
753-
return await list_objects_tool(
754-
schema_name=arguments.get("schema_name", ""),
755-
object_type=arguments.get("object_type", "table"),
756-
)
757-
elif name == "get_object_details":
758-
return await get_object_details_tool(
759-
schema_name=arguments.get("schema_name", ""),
760-
object_name=arguments.get("object_name", ""),
761-
object_type=arguments.get("object_type", "table"),
762-
)
763-
elif name == "explain_query":
764-
return await explain_query_tool(
765-
sql=arguments.get("sql", ""),
766-
analyze=arguments.get("analyze", False),
767-
hypothetical_indexes=arguments.get("hypothetical_indexes", []),
768-
)
769-
elif name == "execute_sql":
770-
return await execute_sql_tool(sql=arguments.get("sql", ""))
771-
elif name == "analyze_workload_indexes":
772-
return await analyze_workload_indexes_tool(
773-
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
774-
method=arguments.get("method", "dta"),
775-
)
776-
elif name == "analyze_query_indexes":
777-
return await analyze_query_indexes_tool(
778-
queries=arguments.get("queries", []),
779-
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
780-
method=arguments.get("method", "dta"),
781-
)
782-
elif name == "analyze_db_health":
783-
return await analyze_db_health_tool(
784-
health_type=arguments.get("health_type", "all"),
785-
)
786-
elif name == "get_top_queries":
787-
return await get_top_queries_tool(
788-
sort_by=arguments.get("sort_by", "resources"),
789-
limit=arguments.get("limit", 10),
790-
)
791-
else:
792-
return [types.TextContent(type="text", text=f"Unknown tool: {name}")]
731+
database_url = get_database_url()
732+
async with create_sql_driver(database_url) as sql_driver:
733+
if name == "list_schemas":
734+
return await list_schemas_tool(sql_driver)
735+
elif name == "list_objects":
736+
return await list_objects_tool(
737+
sql_driver,
738+
schema_name=arguments.get("schema_name", ""),
739+
object_type=arguments.get("object_type", "table"),
740+
)
741+
elif name == "get_object_details":
742+
return await get_object_details_tool(
743+
sql_driver,
744+
schema_name=arguments.get("schema_name", ""),
745+
object_name=arguments.get("object_name", ""),
746+
object_type=arguments.get("object_type", "table"),
747+
)
748+
elif name == "explain_query":
749+
return await explain_query_tool(
750+
sql_driver,
751+
sql=arguments.get("sql", ""),
752+
analyze=arguments.get("analyze", False),
753+
hypothetical_indexes=arguments.get("hypothetical_indexes", []),
754+
)
755+
elif name == "execute_sql":
756+
return await execute_sql_tool(sql_driver, sql=arguments.get("sql", ""))
757+
elif name == "analyze_workload_indexes":
758+
return await analyze_workload_indexes_tool(
759+
sql_driver,
760+
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
761+
method=arguments.get("method", "dta"),
762+
)
763+
elif name == "analyze_query_indexes":
764+
return await analyze_query_indexes_tool(
765+
sql_driver,
766+
queries=arguments.get("queries", []),
767+
max_index_size_mb=arguments.get("max_index_size_mb", 10000),
768+
method=arguments.get("method", "dta"),
769+
)
770+
elif name == "analyze_db_health":
771+
return await analyze_db_health_tool(
772+
sql_driver,
773+
health_type=arguments.get("health_type", "all"),
774+
)
775+
elif name == "get_top_queries":
776+
return await get_top_queries_tool(
777+
sql_driver,
778+
sort_by=arguments.get("sort_by", "resources"),
779+
limit=arguments.get("limit", 10),
780+
)
781+
else:
782+
return [types.TextContent(type="text", text=f"Unknown tool: {name}")]
793783
except Exception as e:
794784
logger.exception(f"Error executing tool {name}: {e}")
795785
return [types.TextContent(type="text", text=f"Error: {str(e)}")]
@@ -912,14 +902,6 @@ async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]:
912902
yield
913903
finally:
914904
logger.info("Application shutting down...")
915-
# Close all database connections
916-
for _, conn in db_connections.items():
917-
try:
918-
await conn.close()
919-
logger.info("Closed database connection")
920-
except Exception as e:
921-
logger.error(f"Error closing database connection: {e}")
922-
db_connections.clear()
923905

924906
# Create an ASGI application with routes for both transports
925907
starlette_app = Starlette(
@@ -956,15 +938,6 @@ async def shutdown(sig=None):
956938
if sig:
957939
logger.info(f"Received exit signal {sig.name}")
958940

959-
# Close all database connections
960-
for _, conn in db_connections.items():
961-
try:
962-
await conn.close()
963-
logger.info("Closed database connection")
964-
except Exception as e:
965-
logger.error(f"Error closing database connection: {e}")
966-
db_connections.clear()
967-
968941
sys.exit(128 + sig if sig is not None else 0)
969942

970943

0 commit comments

Comments
 (0)