@@ -58,6 +58,7 @@ class AccessMode(str, Enum):
5858
5959
6060# Global variables
61+ db_connections : dict [str , DbConnPool ] = {} # Cache connections by database URL
6162current_access_mode = AccessMode .UNRESTRICTED
6263shutdown_in_progress = False
6364
@@ -107,25 +108,39 @@ def get_database_url() -> str:
107108 raise RuntimeError ("Database URL not found in request context" )
108109
109110
110- @contextlib .asynccontextmanager
111- async def create_sql_driver (database_url : str ) -> AsyncIterator [Union [SqlDriver , SafeSqlDriver ]]:
112- """Create a fresh database connection for one tool execution and close it when done."""
111+ def get_database_url_or_empty () -> str :
112+ """Get the database URL from context or return empty string."""
113+ try :
114+ return database_url_context .get ()
115+ except LookupError :
116+ return ""
117+
118+
119+ async def get_db_connection (database_url : str ) -> DbConnPool :
120+ """Get or create a database connection pool for the given URL."""
121+ if database_url not in db_connections :
122+ conn = DbConnPool ()
123+ await conn .pool_connect (database_url )
124+ db_connections [database_url ] = conn
125+ logger .info ("Created new connection pool for database" )
126+ return db_connections [database_url ]
127+
128+
129+ async def get_sql_driver () -> Union [SqlDriver , SafeSqlDriver ]:
130+ """Get the appropriate SQL driver based on the current access mode and request context."""
131+ database_url = get_database_url ()
113132 if not database_url :
114133 raise ValueError ("No database URL provided in request. Please include database URL in x-auth-data header." )
115134
116- conn = DbConnPool ()
117- await conn .pool_connect (database_url )
118- try :
119- base_driver = SqlDriver (conn = conn )
120- if current_access_mode == AccessMode .RESTRICTED :
121- logger .debug ("Using SafeSqlDriver with restrictions (RESTRICTED mode)" )
122- yield SafeSqlDriver (sql_driver = base_driver , timeout = 30 )
123- else :
124- logger .debug ("Using unrestricted SqlDriver (UNRESTRICTED mode)" )
125- yield base_driver
126- finally :
127- await conn .close ()
128- logger .debug ("Database connection closed after tool execution" )
135+ db_connection = await get_db_connection (database_url )
136+ base_driver = SqlDriver (conn = db_connection )
137+
138+ if current_access_mode == AccessMode .RESTRICTED :
139+ logger .debug ("Using SafeSqlDriver with restrictions (RESTRICTED mode)" )
140+ return SafeSqlDriver (sql_driver = base_driver , timeout = 30 ) # 30 second timeout
141+ else :
142+ logger .debug ("Using unrestricted SqlDriver (UNRESTRICTED mode)" )
143+ return base_driver
129144
130145
131146def format_text_response (text : Any ) -> ResponseType :
@@ -139,9 +154,10 @@ def format_error_response(error: str) -> ResponseType:
139154
140155
141156# Tool implementations
142- async def list_schemas_tool (sql_driver : Union [ SqlDriver , SafeSqlDriver ] ) -> ResponseType :
157+ async def list_schemas_tool () -> ResponseType :
143158 """List all schemas in the database."""
144159 try :
160+ sql_driver = await get_sql_driver ()
145161 rows = await sql_driver .execute_query (
146162 """
147163 SELECT
@@ -164,12 +180,13 @@ async def list_schemas_tool(sql_driver: Union[SqlDriver, SafeSqlDriver]) -> Resp
164180
165181
166182async def list_objects_tool (
167- sql_driver : Union [SqlDriver , SafeSqlDriver ],
168183 schema_name : str ,
169184 object_type : str = "table" ,
170185) -> ResponseType :
171186 """List objects of a given type in a schema."""
172187 try :
188+ sql_driver = await get_sql_driver ()
189+
173190 if object_type in ("table" , "view" ):
174191 table_type = "BASE TABLE" if object_type == "table" else "VIEW"
175192 rows = await SafeSqlDriver .execute_param_query (
@@ -230,13 +247,14 @@ async def list_objects_tool(
230247
231248
232249async def get_object_details_tool (
233- sql_driver : Union [SqlDriver , SafeSqlDriver ],
234250 schema_name : str ,
235251 object_name : str ,
236252 object_type : str = "table" ,
237253) -> ResponseType :
238254 """Get detailed information about a database object."""
239255 try :
256+ sql_driver = await get_sql_driver ()
257+
240258 if object_type in ("table" , "view" ):
241259 # Get columns
242260 col_rows = await SafeSqlDriver .execute_param_query (
@@ -361,7 +379,6 @@ async def get_object_details_tool(
361379
362380
363381async def explain_query_tool (
364- sql_driver : Union [SqlDriver , SafeSqlDriver ],
365382 sql : str ,
366383 analyze : bool = False ,
367384 hypothetical_indexes : list [dict [str , Any ]] | None = None ,
@@ -370,6 +387,7 @@ async def explain_query_tool(
370387 if hypothetical_indexes is None :
371388 hypothetical_indexes = []
372389 try :
390+ sql_driver = await get_sql_driver ()
373391 explain_tool = ExplainPlanTool (sql_driver = sql_driver )
374392 result : ExplainPlanArtifact | ErrorResult | None = None
375393
@@ -417,9 +435,10 @@ async def explain_query_tool(
417435 return format_error_response (str (e ))
418436
419437
420- async def execute_sql_tool (sql_driver : Union [ SqlDriver , SafeSqlDriver ], sql : str ) -> ResponseType :
438+ async def execute_sql_tool (sql : str ) -> ResponseType :
421439 """Executes a SQL query against the database."""
422440 try :
441+ sql_driver = await get_sql_driver ()
423442 rows = await sql_driver .execute_query (sql ) # type: ignore
424443 if rows is None :
425444 return format_text_response ("No results" )
@@ -430,12 +449,12 @@ async def execute_sql_tool(sql_driver: Union[SqlDriver, SafeSqlDriver], sql: str
430449
431450
432451async def analyze_workload_indexes_tool (
433- sql_driver : Union [SqlDriver , SafeSqlDriver ],
434452 max_index_size_mb : int = 10000 ,
435453 method : Literal ["dta" , "llm" ] = "dta" ,
436454) -> ResponseType :
437455 """Analyze frequently executed queries in the database and recommend optimal indexes."""
438456 try :
457+ sql_driver = await get_sql_driver ()
439458 if method == "dta" :
440459 index_tuning = DatabaseTuningAdvisor (sql_driver )
441460 else :
@@ -449,7 +468,6 @@ async def analyze_workload_indexes_tool(
449468
450469
451470async def analyze_query_indexes_tool (
452- sql_driver : Union [SqlDriver , SafeSqlDriver ],
453471 queries : list [str ],
454472 max_index_size_mb : int = 10000 ,
455473 method : Literal ["dta" , "llm" ] = "dta" ,
@@ -461,6 +479,7 @@ async def analyze_query_indexes_tool(
461479 return format_error_response (f"Please provide a list of up to { MAX_NUM_INDEX_TUNING_QUERIES } queries to analyze." )
462480
463481 try :
482+ sql_driver = await get_sql_driver ()
464483 if method == "dta" :
465484 index_tuning = DatabaseTuningAdvisor (sql_driver )
466485 else :
@@ -473,20 +492,20 @@ async def analyze_query_indexes_tool(
473492 return format_error_response (str (e ))
474493
475494
476- async def analyze_db_health_tool (sql_driver : Union [ SqlDriver , SafeSqlDriver ], health_type : str = "all" ) -> ResponseType :
495+ async def analyze_db_health_tool (health_type : str = "all" ) -> ResponseType :
477496 """Analyze database health for specified components."""
478- health_tool = DatabaseHealthTool (sql_driver )
497+ health_tool = DatabaseHealthTool (await get_sql_driver () )
479498 result = await health_tool .health (health_type = health_type )
480499 return format_text_response (result )
481500
482501
483502async def get_top_queries_tool (
484- sql_driver : Union [SqlDriver , SafeSqlDriver ],
485503 sort_by : str = "resources" ,
486504 limit : int = 10 ,
487505) -> ResponseType :
488506 """Reports the slowest or most resource-intensive queries."""
489507 try :
508+ sql_driver = await get_sql_driver ()
490509 top_queries_tool = TopQueriesCalc (sql_driver = sql_driver )
491510
492511 if sort_by == "resources" :
@@ -728,58 +747,49 @@ async def call_tool(
728747 name : str , arguments : dict
729748 ) -> list [types .TextContent | types .ImageContent | types .EmbeddedResource ]:
730749 try :
731- database_url = get_database_url ()
732- async with create_sql_driver (database_url ) as sql_driver :
733- if name == "list_schemas" :
734- return await list_schemas_tool (sql_driver )
735- elif name == "list_objects" :
736- return await list_objects_tool (
737- sql_driver ,
738- schema_name = arguments .get ("schema_name" , "" ),
739- object_type = arguments .get ("object_type" , "table" ),
740- )
741- elif name == "get_object_details" :
742- return await get_object_details_tool (
743- sql_driver ,
744- schema_name = arguments .get ("schema_name" , "" ),
745- object_name = arguments .get ("object_name" , "" ),
746- object_type = arguments .get ("object_type" , "table" ),
747- )
748- elif name == "explain_query" :
749- return await explain_query_tool (
750- sql_driver ,
751- sql = arguments .get ("sql" , "" ),
752- analyze = arguments .get ("analyze" , False ),
753- hypothetical_indexes = arguments .get ("hypothetical_indexes" , []),
754- )
755- elif name == "execute_sql" :
756- return await execute_sql_tool (sql_driver , sql = arguments .get ("sql" , "" ))
757- elif name == "analyze_workload_indexes" :
758- return await analyze_workload_indexes_tool (
759- sql_driver ,
760- max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
761- method = arguments .get ("method" , "dta" ),
762- )
763- elif name == "analyze_query_indexes" :
764- return await analyze_query_indexes_tool (
765- sql_driver ,
766- queries = arguments .get ("queries" , []),
767- max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
768- method = arguments .get ("method" , "dta" ),
769- )
770- elif name == "analyze_db_health" :
771- return await analyze_db_health_tool (
772- sql_driver ,
773- health_type = arguments .get ("health_type" , "all" ),
774- )
775- elif name == "get_top_queries" :
776- return await get_top_queries_tool (
777- sql_driver ,
778- sort_by = arguments .get ("sort_by" , "resources" ),
779- limit = arguments .get ("limit" , 10 ),
780- )
781- else :
782- return [types .TextContent (type = "text" , text = f"Unknown tool: { name } " )]
750+ if name == "list_schemas" :
751+ return await list_schemas_tool ()
752+ elif name == "list_objects" :
753+ return await list_objects_tool (
754+ schema_name = arguments .get ("schema_name" , "" ),
755+ object_type = arguments .get ("object_type" , "table" ),
756+ )
757+ elif name == "get_object_details" :
758+ return await get_object_details_tool (
759+ schema_name = arguments .get ("schema_name" , "" ),
760+ object_name = arguments .get ("object_name" , "" ),
761+ object_type = arguments .get ("object_type" , "table" ),
762+ )
763+ elif name == "explain_query" :
764+ return await explain_query_tool (
765+ sql = arguments .get ("sql" , "" ),
766+ analyze = arguments .get ("analyze" , False ),
767+ hypothetical_indexes = arguments .get ("hypothetical_indexes" , []),
768+ )
769+ elif name == "execute_sql" :
770+ return await execute_sql_tool (sql = arguments .get ("sql" , "" ))
771+ elif name == "analyze_workload_indexes" :
772+ return await analyze_workload_indexes_tool (
773+ max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
774+ method = arguments .get ("method" , "dta" ),
775+ )
776+ elif name == "analyze_query_indexes" :
777+ return await analyze_query_indexes_tool (
778+ queries = arguments .get ("queries" , []),
779+ max_index_size_mb = arguments .get ("max_index_size_mb" , 10000 ),
780+ method = arguments .get ("method" , "dta" ),
781+ )
782+ elif name == "analyze_db_health" :
783+ return await analyze_db_health_tool (
784+ health_type = arguments .get ("health_type" , "all" ),
785+ )
786+ elif name == "get_top_queries" :
787+ return await get_top_queries_tool (
788+ sort_by = arguments .get ("sort_by" , "resources" ),
789+ limit = arguments .get ("limit" , 10 ),
790+ )
791+ else :
792+ return [types .TextContent (type = "text" , text = f"Unknown tool: { name } " )]
783793 except Exception as e :
784794 logger .exception (f"Error executing tool { name } : { e } " )
785795 return [types .TextContent (type = "text" , text = f"Error: { str (e )} " )]
@@ -902,6 +912,14 @@ async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]:
902912 yield
903913 finally :
904914 logger .info ("Application shutting down..." )
915+ # Close all database connections
916+ for _ , conn in db_connections .items ():
917+ try :
918+ await conn .close ()
919+ logger .info ("Closed database connection" )
920+ except Exception as e :
921+ logger .error (f"Error closing database connection: { e } " )
922+ db_connections .clear ()
905923
906924 # Create an ASGI application with routes for both transports
907925 starlette_app = Starlette (
@@ -938,6 +956,15 @@ async def shutdown(sig=None):
938956 if sig :
939957 logger .info (f"Received exit signal { sig .name } " )
940958
959+ # Close all database connections
960+ for _ , conn in db_connections .items ():
961+ try :
962+ await conn .close ()
963+ logger .info ("Closed database connection" )
964+ except Exception as e :
965+ logger .error (f"Error closing database connection: { e } " )
966+ db_connections .clear ()
967+
941968 sys .exit (128 + sig if sig is not None else 0 )
942969
943970
0 commit comments