@@ -58,7 +58,6 @@ class AccessMode(str, Enum):
5858
5959
6060# Global variables
61- db_connections : dict [str , DbConnPool ] = {} # Cache connections by database URL
6261current_access_mode = AccessMode .UNRESTRICTED
6362shutdown_in_progress = False
6463
@@ -108,39 +107,25 @@ def get_database_url() -> str:
108107 raise RuntimeError ("Database URL not found in request context" )
109108
110109
111- def get_database_url_or_empty () -> str :
112- """Get the database URL from context or return empty string."""
113- try :
114- return database_url_context .get ()
115- except LookupError :
116- return ""
117-
118-
119- async def get_db_connection (database_url : str ) -> DbConnPool :
120- """Get or create a database connection pool for the given URL."""
121- if database_url not in db_connections :
122- conn = DbConnPool ()
123- await conn .pool_connect (database_url )
124- db_connections [database_url ] = conn
125- logger .info ("Created new connection pool for database" )
126- return db_connections [database_url ]
127-
128-
129- async def get_sql_driver () -> Union [SqlDriver , SafeSqlDriver ]:
130- """Get the appropriate SQL driver based on the current access mode and request context."""
131- database_url = get_database_url ()
110+ @contextlib .asynccontextmanager
111+ async def create_sql_driver (database_url : str ) -> AsyncIterator [Union [SqlDriver , SafeSqlDriver ]]:
112+ """Create a fresh database connection for one tool execution and close it when done."""
132113 if not database_url :
133114 raise ValueError ("No database URL provided in request. Please include database URL in x-auth-data header." )
134115
135- db_connection = await get_db_connection (database_url )
136- base_driver = SqlDriver (conn = db_connection )
137-
138- if current_access_mode == AccessMode .RESTRICTED :
139- logger .debug ("Using SafeSqlDriver with restrictions (RESTRICTED mode)" )
140- return SafeSqlDriver (sql_driver = base_driver , timeout = 30 ) # 30 second timeout
141- else :
142- logger .debug ("Using unrestricted SqlDriver (UNRESTRICTED mode)" )
143- return base_driver
116+ conn = DbConnPool ()
117+ await conn .pool_connect (database_url )
118+ try :
119+ base_driver = SqlDriver (conn = conn )
120+ if current_access_mode == AccessMode .RESTRICTED :
121+ logger .debug ("Using SafeSqlDriver with restrictions (RESTRICTED mode)" )
122+ yield SafeSqlDriver (sql_driver = base_driver , timeout = 30 )
123+ else :
124+ logger .debug ("Using unrestricted SqlDriver (UNRESTRICTED mode)" )
125+ yield base_driver
126+ finally :
127+ await conn .close ()
128+ logger .debug ("Database connection closed after tool execution" )
144129
145130
146131def format_text_response (text : Any ) -> ResponseType :
@@ -154,10 +139,9 @@ def format_error_response(error: str) -> ResponseType:
154139
155140
156141# Tool implementations
157- async def list_schemas_tool () -> ResponseType :
142+ async def list_schemas_tool (sql_driver : Union [ SqlDriver , SafeSqlDriver ] ) -> ResponseType :
158143 """List all schemas in the database."""
159144 try :
160- sql_driver = await get_sql_driver ()
161145 rows = await sql_driver .execute_query (
162146 """
163147 SELECT
@@ -180,13 +164,12 @@ async def list_schemas_tool() -> ResponseType:
180164
181165
182166async def list_objects_tool (
167+ sql_driver : Union [SqlDriver , SafeSqlDriver ],
183168 schema_name : str ,
184169 object_type : str = "table" ,
185170) -> ResponseType :
186171 """List objects of a given type in a schema."""
187172 try :
188- sql_driver = await get_sql_driver ()
189-
190173 if object_type in ("table" , "view" ):
191174 table_type = "BASE TABLE" if object_type == "table" else "VIEW"
192175 rows = await SafeSqlDriver .execute_param_query (
@@ -247,14 +230,13 @@ async def list_objects_tool(
247230
248231
249232async def get_object_details_tool (
233+ sql_driver : Union [SqlDriver , SafeSqlDriver ],
250234 schema_name : str ,
251235 object_name : str ,
252236 object_type : str = "table" ,
253237) -> ResponseType :
254238 """Get detailed information about a database object."""
255239 try :
256- sql_driver = await get_sql_driver ()
257-
258240 if object_type in ("table" , "view" ):
259241 # Get columns
260242 col_rows = await SafeSqlDriver .execute_param_query (
@@ -379,6 +361,7 @@ async def get_object_details_tool(
379361
380362
381363async def explain_query_tool (
364+ sql_driver : Union [SqlDriver , SafeSqlDriver ],
382365 sql : str ,
383366 analyze : bool = False ,
384367 hypothetical_indexes : list [dict [str , Any ]] | None = None ,
@@ -387,7 +370,6 @@ async def explain_query_tool(
387370 if hypothetical_indexes is None :
388371 hypothetical_indexes = []
389372 try :
390- sql_driver = await get_sql_driver ()
391373 explain_tool = ExplainPlanTool (sql_driver = sql_driver )
392374 result : ExplainPlanArtifact | ErrorResult | None = None
393375
@@ -435,10 +417,9 @@ async def explain_query_tool(
435417 return format_error_response (str (e ))
436418
437419
438- async def execute_sql_tool (sql : str ) -> ResponseType :
420+ async def execute_sql_tool (sql_driver : Union [ SqlDriver , SafeSqlDriver ], sql : str ) -> ResponseType :
439421 """Executes a SQL query against the database."""
440422 try :
441- sql_driver = await get_sql_driver ()
442423 rows = await sql_driver .execute_query (sql ) # type: ignore
443424 if rows is None :
444425 return format_text_response ("No results" )
@@ -449,12 +430,12 @@ async def execute_sql_tool(sql: str) -> ResponseType:
449430
450431
451432async def analyze_workload_indexes_tool (
433+ sql_driver : Union [SqlDriver , SafeSqlDriver ],
452434 max_index_size_mb : int = 10000 ,
453435 method : Literal ["dta" , "llm" ] = "dta" ,
454436) -> ResponseType :
455437 """Analyze frequently executed queries in the database and recommend optimal indexes."""
456438 try :
457- sql_driver = await get_sql_driver ()
458439 if method == "dta" :
459440 index_tuning = DatabaseTuningAdvisor (sql_driver )
460441 else :
@@ -468,6 +449,7 @@ async def analyze_workload_indexes_tool(
468449
469450
470451async def analyze_query_indexes_tool (
452+ sql_driver : Union [SqlDriver , SafeSqlDriver ],
471453 queries : list [str ],
472454 max_index_size_mb : int = 10000 ,
473455 method : Literal ["dta" , "llm" ] = "dta" ,
@@ -479,7 +461,6 @@ async def analyze_query_indexes_tool(
479461 return format_error_response (f"Please provide a list of up to { MAX_NUM_INDEX_TUNING_QUERIES } queries to analyze." )
480462
481463 try :
482- sql_driver = await get_sql_driver ()
483464 if method == "dta" :
484465 index_tuning = DatabaseTuningAdvisor (sql_driver )
485466 else :
@@ -492,20 +473,20 @@ async def analyze_query_indexes_tool(
492473 return format_error_response (str (e ))
493474
494475
495- async def analyze_db_health_tool (health_type : str = "all" ) -> ResponseType :
476+ async def analyze_db_health_tool (sql_driver : Union [ SqlDriver , SafeSqlDriver ], health_type : str = "all" ) -> ResponseType :
496477 """Analyze database health for specified components."""
497- health_tool = DatabaseHealthTool (await get_sql_driver () )
478+ health_tool = DatabaseHealthTool (sql_driver )
498479 result = await health_tool .health (health_type = health_type )
499480 return format_text_response (result )
500481
501482
502483async def get_top_queries_tool (
484+ sql_driver : Union [SqlDriver , SafeSqlDriver ],
503485 sort_by : str = "resources" ,
504486 limit : int = 10 ,
505487) -> ResponseType :
506488 """Reports the slowest or most resource-intensive queries."""
507489 try :
508- sql_driver = await get_sql_driver ()
509490 top_queries_tool = TopQueriesCalc (sql_driver = sql_driver )
510491
511492 if sort_by == "resources" :
@@ -747,49 +728,58 @@ async def call_tool(
747728 name : str , arguments : dict
748729 ) -> list [types .TextContent | types .ImageContent | types .EmbeddedResource ]:
749730 try :
750- if name == "list_schemas" :
751- return await list_schemas_tool ()
752- elif name == "list_objects" :
753- return await list_objects_tool (
754- schema_name = arguments .get ("schema_name" , "" ),
755- object_type = arguments .get ("object_type" , "table" ),
756- )
757- elif name == "get_object_details" :
758- return await get_object_details_tool (
759- schema_name = arguments .get ("schema_name" , "" ),
760- object_name = arguments .get ("object_name" , "" ),
761- object_type = arguments .get ("object_type" , "table" ),
762- )
763- elif name == "explain_query" :
764- return await explain_query_tool (
765- sql = arguments .get ("sql" , "" ),
766- analyze = arguments .get ("analyze" , False ),
767- hypothetical_indexes = arguments .get ("hypothetical_indexes" , []),
768- )
769- elif name == "execute_sql" :
770- return await execute_sql_tool (sql = arguments .get ("sql" , "" ))
771- elif name == "analyze_workload_indexes" :
772- return await analyze_workload_indexes_tool (
773- max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
774- method = arguments .get ("method" , "dta" ),
775- )
776- elif name == "analyze_query_indexes" :
777- return await analyze_query_indexes_tool (
778- queries = arguments .get ("queries" , []),
779- max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
780- method = arguments .get ("method" , "dta" ),
781- )
782- elif name == "analyze_db_health" :
783- return await analyze_db_health_tool (
784- health_type = arguments .get ("health_type" , "all" ),
785- )
786- elif name == "get_top_queries" :
787- return await get_top_queries_tool (
788- sort_by = arguments .get ("sort_by" , "resources" ),
789- limit = arguments .get ("limit" , 10 ),
790- )
791- else :
792- return [types .TextContent (type = "text" , text = f"Unknown tool: { name } " )]
731+ database_url = get_database_url ()
732+ async with create_sql_driver (database_url ) as sql_driver :
733+ if name == "list_schemas" :
734+ return await list_schemas_tool (sql_driver )
735+ elif name == "list_objects" :
736+ return await list_objects_tool (
737+ sql_driver ,
738+ schema_name = arguments .get ("schema_name" , "" ),
739+ object_type = arguments .get ("object_type" , "table" ),
740+ )
741+ elif name == "get_object_details" :
742+ return await get_object_details_tool (
743+ sql_driver ,
744+ schema_name = arguments .get ("schema_name" , "" ),
745+ object_name = arguments .get ("object_name" , "" ),
746+ object_type = arguments .get ("object_type" , "table" ),
747+ )
748+ elif name == "explain_query" :
749+ return await explain_query_tool (
750+ sql_driver ,
751+ sql = arguments .get ("sql" , "" ),
752+ analyze = arguments .get ("analyze" , False ),
753+ hypothetical_indexes = arguments .get ("hypothetical_indexes" , []),
754+ )
755+ elif name == "execute_sql" :
756+ return await execute_sql_tool (sql_driver , sql = arguments .get ("sql" , "" ))
757+ elif name == "analyze_workload_indexes" :
758+ return await analyze_workload_indexes_tool (
759+ sql_driver ,
760+ max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
761+ method = arguments .get ("method" , "dta" ),
762+ )
763+ elif name == "analyze_query_indexes" :
764+ return await analyze_query_indexes_tool (
765+ sql_driver ,
766+ queries = arguments .get ("queries" , []),
767+ max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
768+ method = arguments .get ("method" , "dta" ),
769+ )
770+ elif name == "analyze_db_health" :
771+ return await analyze_db_health_tool (
772+ sql_driver ,
773+ health_type = arguments .get ("health_type" , "all" ),
774+ )
775+ elif name == "get_top_queries" :
776+ return await get_top_queries_tool (
777+ sql_driver ,
778+ sort_by = arguments .get ("sort_by" , "resources" ),
779+ limit = arguments .get ("limit" , 10 ),
780+ )
781+ else :
782+ return [types .TextContent (type = "text" , text = f"Unknown tool: { name } " )]
793783 except Exception as e :
794784 logger .exception (f"Error executing tool { name } : { e } " )
795785 return [types .TextContent (type = "text" , text = f"Error: { str (e )} " )]
@@ -912,14 +902,6 @@ async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]:
912902 yield
913903 finally :
914904 logger .info ("Application shutting down..." )
915- # Close all database connections
916- for _ , conn in db_connections .items ():
917- try :
918- await conn .close ()
919- logger .info ("Closed database connection" )
920- except Exception as e :
921- logger .error (f"Error closing database connection: { e } " )
922- db_connections .clear ()
923905
924906 # Create an ASGI application with routes for both transports
925907 starlette_app = Starlette (
@@ -956,15 +938,6 @@ async def shutdown(sig=None):
956938 if sig :
957939 logger .info (f"Received exit signal { sig .name } " )
958940
959- # Close all database connections
960- for _ , conn in db_connections .items ():
961- try :
962- await conn .close ()
963- logger .info ("Closed database connection" )
964- except Exception as e :
965- logger .error (f"Error closing database connection: { e } " )
966- db_connections .clear ()
967-
968941 sys .exit (128 + sig if sig is not None else 0 )
969942
970943
0 commit comments