Skip to content

Commit ec2b124

Browse files
Michael Christensenim-michaelc
authored andcommitted
Set session conneciton to lazily instantiate with async function calls
1 parent baa8278 commit ec2b124

File tree

4 files changed

+105
-75
lines changed

4 files changed

+105
-75
lines changed

src/amazon-keyspaces-mcp-server/awslabs/amazon_keyspaces_mcp_server/client.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from cassandra.auth import PlainTextAuthProvider
2727
from cassandra.cluster import Cluster, Session
28-
from cassandra.io.asyncioreactor import AsyncioConnection
2928

3029
from .consts import (
3130
CERT_DIRECTORY,
@@ -70,19 +69,23 @@ def __init__(self, database_config: DatabaseConfig):
7069
"""Initialize the client with the given configuration."""
7170
self.database_config = database_config
7271
self.is_keyspaces = database_config.use_keyspaces
73-
74-
# Initialize session for the configured database type (Keyspaces or Cassandra)
75-
try:
76-
if self.is_keyspaces:
77-
self.session = self._create_keyspaces_session()
78-
logger.info('Connected to Amazon Keyspaces')
79-
else:
80-
self.session = self._create_cassandra_session()
81-
logger.info('Connected to Cassandra cluster')
82-
except Exception as e:
83-
target = 'Amazon Keyspaces' if self.is_keyspaces else 'Cassandra cluster'
84-
logger.error('Failed to connect to %s: %s', target, str(e))
85-
raise RuntimeError(f'Failed to connect to {target}: {str(e)}') from e
72+
self._session: Optional[Session] = None
73+
74+
async def get_session(self) -> Session:
75+
"""Get or create database session lazily."""
76+
if self._session is None:
77+
try:
78+
if self.is_keyspaces:
79+
self._session = self._create_keyspaces_session()
80+
logger.info('Connected to Amazon Keyspaces')
81+
else:
82+
self._session = self._create_cassandra_session()
83+
logger.info('Connected to Cassandra cluster')
84+
except Exception as e:
85+
target = 'Amazon Keyspaces' if self.is_keyspaces else 'Cassandra cluster'
86+
logger.error('Failed to connect to %s: %s', target, str(e))
87+
raise RuntimeError(f'Failed to connect to {target}: {str(e)}') from e
88+
return self._session
8689

8790
def _create_cassandra_session(self) -> Session:
8891
"""Create a session for Apache Cassandra."""
@@ -102,8 +105,6 @@ def _create_cassandra_session(self) -> Session:
102105
connect_timeout=int(CONNECTION_TIMEOUT),
103106
)
104107

105-
cluster.connection_class = AsyncioConnection
106-
107108
return cluster.connect()
108109

109110
def _create_keyspaces_session(self) -> Session:
@@ -144,8 +145,6 @@ def _create_keyspaces_session(self) -> Session:
144145
connect_timeout=int(CONNECTION_TIMEOUT),
145146
)
146147

147-
cluster.connection_class = AsyncioConnection
148-
149148
return cluster.connect()
150149

151150
def _create_ssl_context_for_keyspaces(self) -> ssl.SSLContext:
@@ -171,13 +170,14 @@ def is_using_keyspaces(self) -> bool:
171170
"""Check if the client is using Amazon Keyspaces."""
172171
return self.is_keyspaces
173172

174-
def list_keyspaces(self) -> List[KeyspaceInfo]:
173+
async def list_keyspaces(self) -> List[KeyspaceInfo]:
175174
"""List all keyspaces in the database."""
176175
keyspaces = []
177176

178177
try:
179178
query = 'SELECT keyspace_name, replication FROM system_schema.keyspaces'
180-
rows = self.session.execute(query)
179+
session = await self.get_session()
180+
rows = session.execute(query)
181181

182182
for row in rows:
183183
name = row.keyspace_name
@@ -199,13 +199,14 @@ def list_keyspaces(self) -> List[KeyspaceInfo]:
199199
logger.error('Error listing keyspaces: %s', str(e))
200200
raise RuntimeError(f'Failed to list keyspaces: {str(e)}') from e
201201

202-
def list_tables(self, keyspace_name: str) -> List[TableInfo]:
202+
async def list_tables(self, keyspace_name: str) -> List[TableInfo]:
203203
"""List all tables in a keyspace."""
204204
tables = []
205205

206206
try:
207207
query = 'SELECT table_name FROM system_schema.tables WHERE keyspace_name = %s'
208-
rows = self.session.execute(query, [keyspace_name])
208+
session = await self.get_session()
209+
rows = session.execute(query, [keyspace_name])
209210

210211
for row in rows:
211212
name = row.table_name
@@ -218,11 +219,13 @@ def list_tables(self, keyspace_name: str) -> List[TableInfo]:
218219
f'Failed to list tables for keyspace {keyspace_name}: {str(e)}'
219220
) from e
220221

221-
def describe_keyspace(self, keyspace_name: str) -> Dict[str, Any]:
222+
async def describe_keyspace(self, keyspace_name: str) -> Dict[str, Any]:
222223
"""Get detailed information about a keyspace."""
223224
try:
224225
query = 'SELECT * FROM system_schema.keyspaces WHERE keyspace_name = %s'
225-
row = self.session.execute(query, [keyspace_name]).one()
226+
session = await self.get_session()
227+
228+
row = session.execute(query, [keyspace_name]).one()
226229

227230
if not row:
228231
raise RuntimeError(f'Keyspace not found: {keyspace_name}')
@@ -245,14 +248,16 @@ def describe_keyspace(self, keyspace_name: str) -> Dict[str, Any]:
245248
logger.error('Error describing keyspace %s: %s', keyspace_name, str(e))
246249
raise RuntimeError(f'Failed to describe keyspace {keyspace_name}: {str(e)}') from e
247250

248-
def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
251+
async def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
249252
"""Get detailed information about a table."""
250253
try:
251254
query = (
252255
'SELECT * FROM system_schema.tables WHERE '
253256
'keyspace_name = %s AND table_name = %s'
254257
)
255-
table_row = self.session.execute(query, [keyspace_name, table_name]).one()
258+
session = await self.get_session()
259+
260+
table_row = session.execute(query, [keyspace_name, table_name]).one()
256261

257262
if not table_row:
258263
raise RuntimeError(f'Table not found: {keyspace_name}.{table_name}')
@@ -266,7 +271,9 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
266271
query = (
267272
'SELECT * FROM system_schema.columns WHERE keyspace_name = %s AND table_name = %s'
268273
)
269-
column_rows = self.session.execute(query, [keyspace_name, table_name])
274+
session = await self.get_session()
275+
276+
column_rows = session.execute(query, [keyspace_name, table_name])
270277

271278
columns = []
272279
for column_row in column_rows:
@@ -284,7 +291,9 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
284291
query = (
285292
'SELECT * FROM system_schema.indexes WHERE keyspace_name = %s AND table_name = %s'
286293
)
287-
index_rows = self.session.execute(query, [keyspace_name, table_name])
294+
session = await self.get_session()
295+
296+
index_rows = session.execute(query, [keyspace_name, table_name])
288297

289298
indexes = []
290299
for index_row in index_rows:
@@ -307,7 +316,9 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
307316
'SELECT custom_properties FROM system_schema_mcs.tables '
308317
'WHERE keyspace_name = %s AND table_name = %s'
309318
)
310-
capacity_row = self.session.execute(
319+
session = await self.get_session()
320+
321+
capacity_row = session.execute(
311322
query, [keyspace_name, table_name]
312323
).one()
313324

@@ -339,7 +350,7 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
339350
f'Failed to describe table {keyspace_name}.{table_name}: {str(e)}'
340351
) from e
341352

342-
def execute_read_only_query(
353+
async def execute_read_only_query(
343354
self, query: str, params: Optional[List[Any]] = None
344355
) -> Dict[str, Any]:
345356
"""Execute a read-only SELECT query against the database."""
@@ -360,9 +371,13 @@ def execute_read_only_query(
360371

361372
# Execute the query
362373
if params:
363-
rs = self.session.execute(query, params)
374+
session = await self.get_session()
375+
376+
rs = session.execute(query, params)
364377
else:
365-
rs = self.session.execute(query)
378+
session = await self.get_session()
379+
380+
rs = session.execute(query)
366381

367382
# Process the results
368383
rows = []

src/amazon-keyspaces-mcp-server/awslabs/amazon_keyspaces_mcp_server/server.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from fastmcp import Context, FastMCP
2020
from loguru import logger
21-
from pydantic import Field
2221

2322
from .client import UnifiedCassandraClient
2423
from .config import AppConfig
@@ -42,13 +41,15 @@
4241
build_query_result_context,
4342
build_table_details_context,
4443
)
44+
from .models import KeyspaceInput, QueryInput, TableInput
4545
from .services import DataService, QueryAnalysisService, SchemaService
4646

4747

4848
# Remove all default handlers then add our own
4949
logger.remove()
5050
logger.add(sys.stderr, level='INFO')
5151

52+
5253
mcp = FastMCP(
5354
name=SERVER_NAME,
5455
version=SERVER_VERSION,
@@ -79,10 +80,10 @@
7980
_PROXY = None
8081

8182

82-
def get_proxy():
83+
async def get_proxy():
8384
"""Returns a singleton instance of the main Keyspaces MCP server implementation.
8485
85-
The singleton is initialized lazily.
86+
The singleton is initialized lazily when first accessed (ensuring event loop is running).
8687
"""
8788
global _PROXY # pylint: disable=global-statement
8889
if _PROXY is None:
@@ -110,72 +111,74 @@ async def list_keyspaces(
110111
ctx: Optional[Context] = None,
111112
) -> str:
112113
"""Lists all keyspaces in the Cassandra/Keyspaces database."""
113-
return await get_proxy()._handle_list_keyspaces(ctx) # pylint: disable=protected-access
114+
proxy = await get_proxy()
115+
return await proxy._handle_list_keyspaces(ctx) # pylint: disable=protected-access
114116

115117

116118
@mcp.tool(
117119
name='listTables',
118120
description='Lists all tables in a specified keyspace - args: keyspace',
119121
)
120122
async def list_tables(
121-
keyspace: str = Field(..., description='The keyspace to list tables from.'),
123+
input: KeyspaceInput,
122124
ctx: Optional[Context] = None,
123125
) -> str:
124126
"""Lists all tables in a specified keyspace."""
125-
return await get_proxy()._handle_list_tables(keyspace, ctx) # pylint: disable=protected-access
127+
proxy = await get_proxy()
128+
return await proxy._handle_list_tables(input.keyspace, ctx) # pylint: disable=protected-access
126129

127130

128131
@mcp.tool(
129132
name='describeKeyspace',
130133
description='Gets detailed information about a keyspace - args: keyspace',
131134
)
132135
async def describe_keyspace(
133-
keyspace: str = Field(..., description='The keyspace to retrieve metadata for.'),
136+
input: KeyspaceInput,
134137
ctx: Optional[Context] = None,
135138
) -> str:
136139
"""Gets detailed information about a keyspace."""
137-
return await get_proxy()._handle_describe_keyspace(keyspace, ctx) # pylint: disable=protected-access
140+
proxy = await get_proxy()
141+
return await proxy._handle_describe_keyspace(input.keyspace, ctx) # pylint: disable=protected-access
138142

139143

140144
@mcp.tool(
141145
name='describeTable',
142146
description='Gets detailed information about a table - args: keyspace, table',
143147
)
144148
async def describe_table(
145-
keyspace: str = Field(..., description='The keyspace containing the table'),
146-
table: str = Field(..., description='The name of the table to describe'),
149+
input: TableInput,
147150
ctx: Optional[Context] = None,
148151
) -> str:
149152
"""Gets detailed information about a table."""
150-
return await get_proxy()._handle_describe_table(keyspace, table, ctx) # pylint: disable=protected-access
153+
proxy = await get_proxy()
154+
return await proxy._handle_describe_table(input.keyspace, input.table, ctx) # pylint: disable=protected-access
151155

152156

153157
@mcp.tool(
154158
name='executeQuery',
155159
description='Executes a read-only SELECT query against the database - args: keyspace, query',
156160
)
157161
async def execute_query(
158-
keyspace: str = Field(..., description='The keyspace to execute the query against'),
159-
query: str = Field(..., description='The CQL SELECT query to execute'),
162+
input: QueryInput,
160163
ctx: Optional[Context] = None,
161164
) -> str:
162165
"""Executes a read-only (SELECT) query against the database."""
163-
return await get_proxy()._handle_execute_query(keyspace, query, ctx) # pylint: disable=protected-access
166+
proxy = await get_proxy()
167+
return await proxy._handle_execute_query(input.keyspace, input.query, ctx) # pylint: disable=protected-access
164168

165169

166170
@mcp.tool(
167171
name='analyzeQueryPerformance',
168172
description='Analyzes the performance characteristics of a CQL query - args: keyspace, query',
169173
)
170174
async def analyze_query_performance(
171-
keyspace: str = Field(..., description='The keyspace to analyze the query against'),
172-
query: str = Field(..., description='The CQL query to analyze for performance'),
175+
input: QueryInput,
173176
ctx: Optional[Context] = None,
174177
) -> str:
175178
"""Analyzes the performance characteristics of a CQL query."""
176-
proxy = get_proxy()
179+
proxy = await get_proxy()
177180
return await proxy._handle_analyze_query_performance( # pylint: disable=protected-access
178-
keyspace, query, ctx
181+
input.keyspace, input.query, ctx
179182
)
180183

181184

@@ -196,7 +199,7 @@ def __init__(
196199
async def _handle_list_keyspaces(self, ctx: Optional[Any] = None) -> str:
197200
"""Handle the listKeyspaces tool."""
198201
try:
199-
keyspaces = self.schema_service.list_keyspaces()
202+
keyspaces = await self.schema_service.list_keyspaces()
200203

201204
# Format keyspace names as a markdown list for better display
202205
keyspace_names = [k.name for k in keyspaces]
@@ -225,7 +228,7 @@ async def _handle_list_tables(self, keyspace: str, ctx: Optional[Context] = None
225228
if not keyspace:
226229
raise ValidationError('Keyspace name is required')
227230

228-
tables = self.schema_service.list_tables(keyspace)
231+
tables = await self.schema_service.list_tables(keyspace)
229232

230233
# Format table names as a markdown list for better display
231234
table_names = [t.name for t in tables]
@@ -256,7 +259,7 @@ async def _handle_describe_keyspace(
256259
if not keyspace:
257260
raise ValidationError('Keyspace name is required')
258261

259-
keyspace_details = self.schema_service.describe_keyspace(keyspace)
262+
keyspace_details = await self.schema_service.describe_keyspace(keyspace)
260263

261264
# Format keyspace details as markdown
262265
formatted_text = f'## Keyspace: `{keyspace}`\n\n'
@@ -303,7 +306,7 @@ async def _handle_describe_table(
303306
if not table:
304307
raise ValidationError('Table name is required')
305308

306-
table_details = self.schema_service.describe_table(keyspace, table)
309+
table_details = await self.schema_service.describe_table(keyspace, table)
307310

308311
# Format table details as markdown
309312
formatted_text = f'## Table: `{keyspace}.{table}`\n\n'
@@ -382,7 +385,7 @@ async def _handle_execute_query(
382385
raise QuerySecurityError('Query contains potentially unsafe operations')
383386

384387
# Execute the query using the DataService
385-
query_results = self.data_service.execute_read_only_query(keyspace, query)
388+
query_results = await self.data_service.execute_read_only_query(keyspace, query)
386389

387390
# Format the results for display
388391
formatted_text = '## Query Results\n\n'
@@ -447,7 +450,7 @@ async def _handle_analyze_query_performance(
447450
if not query:
448451
raise ValidationError('Query is required')
449452

450-
analysis_result = self.query_analysis_service.analyze_query(keyspace, query)
453+
analysis_result = await self.query_analysis_service.analyze_query(keyspace, query)
451454

452455
# Build a user-friendly response
453456
formatted_text = '## Query Analysis Results\n\n'
@@ -478,6 +481,17 @@ async def _handle_analyze_query_performance(
478481

479482
def main():
480483
"""Run the MCP server."""
484+
import asyncio
485+
486+
# Validate connection before starting server
487+
try:
488+
proxy = asyncio.run(get_proxy())
489+
asyncio.run(proxy.schema_service.cassandra_client.get_session())
490+
logger.success('Successfully validated database connection')
491+
except Exception as e:
492+
logger.error(f'Failed to connect to database: {e}')
493+
sys.exit(1)
494+
481495
mcp.run()
482496

483497

0 commit comments

Comments
 (0)