11import csv
22import hashlib
3+ import re
34from dataclasses import dataclass , field
45from pathlib import Path
56from time import time
4849 from astrapy import AsyncCollection as AstraDBAsyncCollection
4950 from astrapy import Collection as AstraDBCollection
5051 from astrapy import DataAPIClient as AstraDBClient
52+ from astrapy import Database as AstraDB
5153
5254
5355CONNECTOR_TYPE = "astradb"
@@ -85,11 +87,10 @@ def get_client(self) -> "AstraDBClient":
8587 )
8688
8789
88- def get_astra_collection (
90+ def get_astra_db (
8991 connection_config : AstraDBConnectionConfig ,
90- collection_name : str ,
9192 keyspace : str ,
92- ) -> "AstraDBCollection " :
93+ ) -> "AstraDB " :
9394 # Build the Astra DB object.
9495 access_configs = connection_config .access_config .get_secret_value ()
9596
@@ -103,9 +104,20 @@ def get_astra_collection(
103104 token = access_configs .token ,
104105 keyspace = keyspace ,
105106 )
107+ return astra_db
108+
106109
107- # Connect to the collection
110+ def get_astra_collection (
111+ connection_config : AstraDBConnectionConfig ,
112+ collection_name : str ,
113+ keyspace : str ,
114+ ) -> "AstraDBCollection" :
115+
116+ astra_db = get_astra_db (connection_config = connection_config , keyspace = keyspace )
117+
118+ # astradb will return a collection object in all cases (even if it doesn't exist)
108119 astra_db_collection = astra_db .get_collection (name = collection_name )
120+
109121 return astra_db_collection
110122
111123
@@ -151,10 +163,11 @@ class AstraDBDownloaderConfig(DownloaderConfig):
151163
152164
153165class AstraDBUploaderConfig (UploaderConfig ):
154- collection_name : str = Field (
166+ collection_name : Optional [ str ] = Field (
155167 description = "The name of the Astra DB collection. "
156168 "Note that the collection name must only include letters, "
157- "numbers, and underscores."
169+ "numbers, and underscores." ,
170+ default = None ,
158171 )
159172 keyspace : Optional [str ] = Field (default = None , description = "The Astra DB connection keyspace." )
160173 requested_indexing_policy : Optional [dict [str , Any ]] = Field (
@@ -337,25 +350,84 @@ class AstraDBUploader(Uploader):
337350 upload_config : AstraDBUploaderConfig
338351 connector_type : str = CONNECTOR_TYPE
339352
353+ def init (self , ** kwargs : Any ) -> None :
354+ self .create_destination (** kwargs )
355+
340356 def precheck (self ) -> None :
341357 try :
342- get_astra_collection (
343- connection_config = self .connection_config ,
344- collection_name = self .upload_config .collection_name ,
345- keyspace = self .upload_config .keyspace ,
346- ).options ()
358+ if self .upload_config .collection_name :
359+ self .get_collection (collection_name = self .upload_config .collection_name ).options ()
360+ else :
361+ # check for db connection only if collection name is not provided
362+ get_astra_db (
363+ connection_config = self .connection_config ,
364+ keyspace = self .upload_config .keyspace ,
365+ )
347366 except Exception as e :
348367 logger .error (f"Failed to validate connection { e } " , exc_info = True )
349368 raise DestinationConnectionError (f"failed to validate connection: { e } " )
350369
351370 @requires_dependencies (["astrapy" ], extras = "astradb" )
352- def get_collection (self ) -> "AstraDBCollection" :
371+ def get_collection (self , collection_name : Optional [ str ] = None ) -> "AstraDBCollection" :
353372 return get_astra_collection (
354373 connection_config = self .connection_config ,
355- collection_name = self .upload_config .collection_name ,
374+ collection_name = collection_name or self .upload_config .collection_name ,
356375 keyspace = self .upload_config .keyspace ,
357376 )
358377
378+ def _collection_exists (self , collection_name : str ):
379+ from astrapy .exceptions import CollectionNotFoundException
380+
381+ collection = get_astra_collection (
382+ connection_config = self .connection_config ,
383+ collection_name = collection_name ,
384+ keyspace = self .upload_config .keyspace ,
385+ )
386+
387+ try :
388+ collection .options ()
389+ return True
390+ except CollectionNotFoundException :
391+ return False
392+ except Exception as e :
393+ logger .error (f"failed to check if astra collection exists : { e } " )
394+ raise DestinationConnectionError (f"failed to check if astra collection exists : { e } " )
395+
396+ def format_destination_name (self , destination_name : str ) -> str :
397+ # AstraDB collection naming requirements:
398+ # must be below 50 characters
399+ # must be lowercase alphanumeric and underscores only
400+ formatted = re .sub (r"[^a-z0-9]" , "_" , destination_name .lower ())
401+ return formatted
402+
403+ def create_destination (
404+ self ,
405+ vector_length : int ,
406+ destination_name : str = "unstructuredautocreated" ,
407+ similarity_metric : Optional [str ] = "cosine" ,
408+ ** kwargs : Any ,
409+ ) -> bool :
410+ destination_name = self .format_destination_name (destination_name )
411+ collection_name = self .upload_config .collection_name or destination_name
412+ self .upload_config .collection_name = collection_name
413+
414+ if not self ._collection_exists (collection_name ):
415+ astra_db = get_astra_db (
416+ connection_config = self .connection_config , keyspace = self .upload_config .keyspace
417+ )
418+ logger .info (
419+ f"creating default astra collection '{ collection_name } ' with dimension "
420+ f"{ vector_length } and metric { similarity_metric } "
421+ )
422+ astra_db .create_collection (
423+ collection_name ,
424+ dimension = vector_length ,
425+ metric = similarity_metric ,
426+ )
427+ return True
428+ logger .debug (f"collection with name '{ collection_name } ' already exists, skipping creation" )
429+ return False
430+
359431 def delete_by_record_id (self , collection : "AstraDBCollection" , file_data : FileData ):
360432 logger .debug (
361433 f"deleting records from collection { collection .name } "
0 commit comments