2525
2626from cassandra .auth import PlainTextAuthProvider
2727from cassandra .cluster import Cluster , Session
28- from cassandra .io .asyncioreactor import AsyncioConnection
2928
3029from .consts import (
3130 CERT_DIRECTORY ,
@@ -70,19 +69,23 @@ def __init__(self, database_config: DatabaseConfig):
7069 """Initialize the client with the given configuration."""
7170 self .database_config = database_config
7271 self .is_keyspaces = database_config .use_keyspaces
73-
74- # Initialize session for the configured database type (Keyspaces or Cassandra)
75- try :
76- if self .is_keyspaces :
77- self .session = self ._create_keyspaces_session ()
78- logger .info ('Connected to Amazon Keyspaces' )
79- else :
80- self .session = self ._create_cassandra_session ()
81- logger .info ('Connected to Cassandra cluster' )
82- except Exception as e :
83- target = 'Amazon Keyspaces' if self .is_keyspaces else 'Cassandra cluster'
84- logger .error ('Failed to connect to %s: %s' , target , str (e ))
85- raise RuntimeError (f'Failed to connect to { target } : { str (e )} ' ) from e
72+ self ._session : Optional [Session ] = None
73+
74+ async def get_session (self ) -> Session :
75+ """Get or create database session lazily."""
76+ if self ._session is None :
77+ try :
78+ if self .is_keyspaces :
79+ self ._session = self ._create_keyspaces_session ()
80+ logger .info ('Connected to Amazon Keyspaces' )
81+ else :
82+ self ._session = self ._create_cassandra_session ()
83+ logger .info ('Connected to Cassandra cluster' )
84+ except Exception as e :
85+ target = 'Amazon Keyspaces' if self .is_keyspaces else 'Cassandra cluster'
86+ logger .error ('Failed to connect to %s: %s' , target , str (e ))
87+ raise RuntimeError (f'Failed to connect to { target } : { str (e )} ' ) from e
88+ return self ._session
8689
8790 def _create_cassandra_session (self ) -> Session :
8891 """Create a session for Apache Cassandra."""
@@ -102,8 +105,6 @@ def _create_cassandra_session(self) -> Session:
102105 connect_timeout = int (CONNECTION_TIMEOUT ),
103106 )
104107
105- cluster .connection_class = AsyncioConnection
106-
107108 return cluster .connect ()
108109
109110 def _create_keyspaces_session (self ) -> Session :
@@ -144,8 +145,6 @@ def _create_keyspaces_session(self) -> Session:
144145 connect_timeout = int (CONNECTION_TIMEOUT ),
145146 )
146147
147- cluster .connection_class = AsyncioConnection
148-
149148 return cluster .connect ()
150149
151150 def _create_ssl_context_for_keyspaces (self ) -> ssl .SSLContext :
@@ -171,13 +170,14 @@ def is_using_keyspaces(self) -> bool:
171170 """Check if the client is using Amazon Keyspaces."""
172171 return self .is_keyspaces
173172
174- def list_keyspaces (self ) -> List [KeyspaceInfo ]:
173+ async def list_keyspaces (self ) -> List [KeyspaceInfo ]:
175174 """List all keyspaces in the database."""
176175 keyspaces = []
177176
178177 try :
179178 query = 'SELECT keyspace_name, replication FROM system_schema.keyspaces'
180- rows = self .session .execute (query )
179+ session = await self .get_session ()
180+ rows = session .execute (query )
181181
182182 for row in rows :
183183 name = row .keyspace_name
@@ -199,13 +199,14 @@ def list_keyspaces(self) -> List[KeyspaceInfo]:
199199 logger .error ('Error listing keyspaces: %s' , str (e ))
200200 raise RuntimeError (f'Failed to list keyspaces: { str (e )} ' ) from e
201201
202- def list_tables (self , keyspace_name : str ) -> List [TableInfo ]:
202+ async def list_tables (self , keyspace_name : str ) -> List [TableInfo ]:
203203 """List all tables in a keyspace."""
204204 tables = []
205205
206206 try :
207207 query = 'SELECT table_name FROM system_schema.tables WHERE keyspace_name = %s'
208- rows = self .session .execute (query , [keyspace_name ])
208+ session = await self .get_session ()
209+ rows = session .execute (query , [keyspace_name ])
209210
210211 for row in rows :
211212 name = row .table_name
@@ -218,11 +219,13 @@ def list_tables(self, keyspace_name: str) -> List[TableInfo]:
218219 f'Failed to list tables for keyspace { keyspace_name } : { str (e )} '
219220 ) from e
220221
221- def describe_keyspace (self , keyspace_name : str ) -> Dict [str , Any ]:
222+ async def describe_keyspace (self , keyspace_name : str ) -> Dict [str , Any ]:
222223 """Get detailed information about a keyspace."""
223224 try :
224225 query = 'SELECT * FROM system_schema.keyspaces WHERE keyspace_name = %s'
225- row = self .session .execute (query , [keyspace_name ]).one ()
226+ session = await self .get_session ()
227+
228+ row = session .execute (query , [keyspace_name ]).one ()
226229
227230 if not row :
228231 raise RuntimeError (f'Keyspace not found: { keyspace_name } ' )
@@ -245,14 +248,16 @@ def describe_keyspace(self, keyspace_name: str) -> Dict[str, Any]:
245248 logger .error ('Error describing keyspace %s: %s' , keyspace_name , str (e ))
246249 raise RuntimeError (f'Failed to describe keyspace { keyspace_name } : { str (e )} ' ) from e
247250
248- def describe_table (self , keyspace_name : str , table_name : str ) -> Dict [str , Any ]:
251+ async def describe_table (self , keyspace_name : str , table_name : str ) -> Dict [str , Any ]:
249252 """Get detailed information about a table."""
250253 try :
251254 query = (
252255 'SELECT * FROM system_schema.tables WHERE '
253256 'keyspace_name = %s AND table_name = %s'
254257 )
255- table_row = self .session .execute (query , [keyspace_name , table_name ]).one ()
258+ session = await self .get_session ()
259+
260+ table_row = session .execute (query , [keyspace_name , table_name ]).one ()
256261
257262 if not table_row :
258263 raise RuntimeError (f'Table not found: { keyspace_name } .{ table_name } ' )
@@ -266,7 +271,9 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
266271 query = (
267272 'SELECT * FROM system_schema.columns WHERE keyspace_name = %s AND table_name = %s'
268273 )
269- column_rows = self .session .execute (query , [keyspace_name , table_name ])
274+ session = await self .get_session ()
275+
276+ column_rows = session .execute (query , [keyspace_name , table_name ])
270277
271278 columns = []
272279 for column_row in column_rows :
@@ -284,7 +291,9 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
284291 query = (
285292 'SELECT * FROM system_schema.indexes WHERE keyspace_name = %s AND table_name = %s'
286293 )
287- index_rows = self .session .execute (query , [keyspace_name , table_name ])
294+ session = await self .get_session ()
295+
296+ index_rows = session .execute (query , [keyspace_name , table_name ])
288297
289298 indexes = []
290299 for index_row in index_rows :
@@ -307,7 +316,9 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
307316 'SELECT custom_properties FROM system_schema_mcs.tables '
308317 'WHERE keyspace_name = %s AND table_name = %s'
309318 )
310- capacity_row = self .session .execute (
319+ session = await self .get_session ()
320+
321+ capacity_row = session .execute (
311322 query , [keyspace_name , table_name ]
312323 ).one ()
313324
@@ -339,7 +350,7 @@ def describe_table(self, keyspace_name: str, table_name: str) -> Dict[str, Any]:
339350 f'Failed to describe table { keyspace_name } .{ table_name } : { str (e )} '
340351 ) from e
341352
342- def execute_read_only_query (
353+ async def execute_read_only_query (
343354 self , query : str , params : Optional [List [Any ]] = None
344355 ) -> Dict [str , Any ]:
345356 """Execute a read-only SELECT query against the database."""
@@ -360,9 +371,13 @@ def execute_read_only_query(
360371
361372 # Execute the query
362373 if params :
363- rs = self .session .execute (query , params )
374+ session = await self .get_session ()
375+
376+ rs = session .execute (query , params )
364377 else :
365- rs = self .session .execute (query )
378+ session = await self .get_session ()
379+
380+ rs = session .execute (query )
366381
367382 # Process the results
368383 rows = []
0 commit comments