1+ import asyncio
12import csv
23import hashlib
4+ import os
35import re
46from dataclasses import dataclass , field
57from pathlib import Path
810
911from pydantic import BaseModel , Field , Secret
1012
11- from unstructured_ingest import __name__ as integration_name
1213from unstructured_ingest .__version__ import __version__ as integration_version
1314from unstructured_ingest .data_types .file_data import (
1415 BatchFileData ,
@@ -83,10 +84,8 @@ def get_client(self) -> "AstraDBClient":
8384
8485 # Create a client object to interact with the Astra DB
8586 # caller_name/version for Astra DB tracking
86- return AstraDBClient (
87- caller_name = integration_name ,
88- caller_version = integration_version ,
89- )
87+ user_agent = os .getenv ("UNSTRUCTURED_USER_AGENT" , "unstructuredio_oss" )
88+ return AstraDBClient (callers = [(user_agent , integration_version )])
9089
9190
9291def get_astra_db (
@@ -141,7 +140,7 @@ async def get_async_astra_collection(
141140 )
142141
143142 # Get async collection from AsyncDatabase
144- async_astra_db_collection = await async_astra_db .get_collection (name = collection_name )
143+ async_astra_db_collection = async_astra_db .get_collection (name = collection_name )
145144 return async_astra_db_collection
146145
147146
@@ -360,13 +359,22 @@ class AstraDBUploader(Uploader):
360359 upload_config : AstraDBUploaderConfig
361360 connector_type : str = CONNECTOR_TYPE
362361
def is_async(self) -> bool:
    """Report whether this uploader is asynchronous.

    Returns:
        Always ``True``; the upload work of this class lives in ``run_async``.
    """
    uses_async_entry_point = True
    return uses_async_entry_point
364+
def init(self, **kwargs: Any) -> None:
    """Initialise the uploader by delegating to ``create_destination``.

    Args:
        **kwargs: Forwarded unchanged to ``self.create_destination``.

    Returns:
        None. Whatever ``create_destination`` returns is discarded.
    """
    # Only the side effect of create_destination matters here.
    _ = self.create_destination(**kwargs)
    return None
365367
368+ @requires_dependencies (["astrapy" ], extras = "astradb" )
366369 def precheck (self ) -> None :
367370 try :
368371 if self .upload_config .collection_name :
369- self .get_collection (collection_name = self .upload_config .collection_name ).options ()
372+ collection = get_astra_collection (
373+ connection_config = self .connection_config ,
374+ collection_name = self .upload_config .collection_name ,
375+ keyspace = self .upload_config .keyspace ,
376+ )
377+ collection .options ()
370378 else :
371379 # check for db connection only if collection name is not provided
372380 get_astra_db (
@@ -377,17 +385,7 @@ def precheck(self) -> None:
377385 logger .error (f"Failed to validate connection { e } " , exc_info = True )
378386 raise DestinationConnectionError (f"failed to validate connection: { e } " )
379387
380- @requires_dependencies (["astrapy" ], extras = "astradb" )
381- def get_collection (self , collection_name : Optional [str ] = None ) -> "AstraDBCollection" :
382- return get_astra_collection (
383- connection_config = self .connection_config ,
384- collection_name = collection_name or self .upload_config .collection_name ,
385- keyspace = self .upload_config .keyspace ,
386- )
387-
388388 def _collection_exists (self , collection_name : str ):
389- from astrapy .exceptions import CollectionNotFoundException
390-
391389 collection = get_astra_collection (
392390 connection_config = self .connection_config ,
393391 collection_name = collection_name ,
@@ -397,8 +395,10 @@ def _collection_exists(self, collection_name: str):
397395 try :
398396 collection .options ()
399397 return True
400- except CollectionNotFoundException :
401- return False
398+ except RuntimeError as e :
399+ if "not found" in str (e ):
400+ return False
401+ raise DestinationConnectionError (f"failed to check if astra collection exists : { e } " )
402402 except Exception as e :
403403 logger .error (f"failed to check if astra collection exists : { e } " )
404404 raise DestinationConnectionError (f"failed to check if astra collection exists : { e } " )
@@ -422,51 +422,65 @@ def create_destination(
422422 self .upload_config .collection_name = collection_name
423423
424424 if not self ._collection_exists (collection_name ):
425+ from astrapy .info import CollectionDefinition
426+
425427 astra_db = get_astra_db (
426428 connection_config = self .connection_config , keyspace = self .upload_config .keyspace
427429 )
428430 logger .info (
429431 f"creating default astra collection '{ collection_name } ' with dimension "
430432 f"{ vector_length } and metric { similarity_metric } "
431433 )
432- astra_db .create_collection (
433- collection_name ,
434- dimension = vector_length ,
435- metric = similarity_metric ,
434+ definition = (
435+ CollectionDefinition .builder ()
436+ .set_vector_dimension (dimension = vector_length )
437+ .set_vector_metric (similarity_metric )
438+ .build ()
436439 )
440+ (astra_db .create_collection (collection_name , definition = definition ),)
437441 return True
438442 logger .debug (f"collection with name '{ collection_name } ' already exists, skipping creation" )
439443 return False
440444
async def delete_by_record_id(self, collection: "AstraDBAsyncCollection", file_data: FileData):
    """Delete the documents in ``collection`` that belong to ``file_data``.

    A document matches when the field named by
    ``self.upload_config.record_id_key`` equals ``file_data.identifier``.

    Args:
        collection: Async collection handle; ``delete_many`` is awaited on it.
        file_data: Supplies the ``identifier`` used as the match value.
    """
    record_id_key = self.upload_config.record_id_key
    record_id = file_data.identifier
    logger.debug(
        f"deleting records from collection {collection.name} "
        f"with {record_id_key} "
        f"set to {record_id}"
    )
    # Equality filter on the record-id field selects every matching document.
    result = await collection.delete_many(filter={record_id_key: {"$eq": record_id}})
    logger.debug(f"deleted {result.deleted_count} records from collection {collection.name}")
452456
async def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
    """Replace the records for ``file_data`` in the configured collection with ``data``.

    Steps, in order:
      1. Obtain the async collection named by ``upload_config.collection_name``.
      2. Delete existing documents for ``file_data`` via ``delete_by_record_id``.
      3. Split ``data`` into batches of ``upload_config.batch_size`` and await
         all ``insert_many`` calls together through ``asyncio.gather``.

    Args:
        data: Documents to insert.
        file_data: Identifies which previously written records to remove first.
        **kwargs: Accepted but not used by this method.
    """
    config = self.upload_config
    logger.info(
        f"writing {len(data)} objects to destination "
        f"collection {config.collection_name}"
    )

    batch_size = config.batch_size
    target = await get_async_astra_collection(
        connection_config=self.connection_config,
        collection_name=config.collection_name,
        keyspace=config.keyspace,
    )

    # Old records must be gone before the new ones are written.
    await self.delete_by_record_id(collection=target, file_data=file_data)

    # One insert_many coroutine per batch; gather awaits them all at once.
    pending_inserts = [target.insert_many(batch) for batch in batch_generator(data, batch_size)]
    await asyncio.gather(*pending_inserts)
466477
async def run_async(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
    """Load the JSON payload at ``path`` and upload it through ``run_data``.

    Args:
        path: File passed to ``get_json_data`` to obtain the documents.
        file_data: Passed through to ``run_data``.
        **kwargs: Accepted but not forwarded to ``run_data``.
    """
    records = get_json_data(path=path)
    await self.run_data(file_data=file_data, data=records)
481+
def run(self, **kwargs: Any) -> Any:
    """Reject synchronous use of this uploader.

    Args:
        **kwargs: Ignored.

    Raises:
        NotImplementedError: Always; callers are directed to ``run_async``.
    """
    message = "Use astradb run_async instead"
    raise NotImplementedError(message)
470484
471485
472486astra_db_source_entry = SourceRegistryEntry (
0 commit comments