Skip to content

Commit 2825774

Browse files
authored
feat: opinionated writes for AstraDB destination (#391)
* make collection optional; wip create dest * name formatter * code * test and clean * bump version * better none handling; fixes * move import into functions * move collection verification to callers; rewrite exists method * tidy
1 parent 31fae37 commit 2825774

File tree

4 files changed

+123
-20
lines changed

4 files changed

+123
-20
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
## 0.5.9-dev0
1+
## 0.5.9-dev1
2+
3+
### Features
4+
5+
* Add auto create collection support for AstraDB destination
26

37
### Fixes
48

test/integration/connectors/test_astradb.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,7 @@ def test_precheck_succeeds_indexer(connection_config: AstraDBConnectionConfig):
5656
connection_config=connection_config,
5757
index_config=AstraDBIndexerConfig(collection_name=EXISTENT_COLLECTION_NAME),
5858
)
59-
uploader = AstraDBUploader(
60-
connection_config=connection_config,
61-
upload_config=AstraDBUploaderConfig(collection_name=EXISTENT_COLLECTION_NAME),
62-
)
6359
indexer.precheck()
64-
uploader.precheck()
6560

6661

6762
@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
@@ -73,6 +68,12 @@ def test_precheck_succeeds_uploader(connection_config: AstraDBConnectionConfig):
7368
)
7469
uploader.precheck()
7570

71+
uploader2 = AstraDBUploader(
72+
connection_config=connection_config,
73+
upload_config=AstraDBUploaderConfig(),
74+
)
75+
uploader2.precheck()
76+
7677

7778
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, VECTOR_DB_TAG)
7879
@requires_env("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT")
@@ -216,6 +217,32 @@ async def test_astra_search_destination(
216217
)
217218

218219

220+
@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
221+
@requires_env("ASTRA_DB_API_ENDPOINT", "ASTRA_DB_APPLICATION_TOKEN")
222+
def test_astra_create_destination():
223+
env_data = get_env_data()
224+
connection_config = AstraDBConnectionConfig(
225+
access_config=AstraDBAccessConfig(api_endpoint=env_data.api_endpoint, token=env_data.token),
226+
)
227+
uploader = AstraDBUploader(
228+
connection_config=connection_config,
229+
upload_config=AstraDBUploaderConfig(),
230+
)
231+
collection_name = "system_created-123"
232+
formatted_collection_name = "system_created_123"
233+
created = uploader.create_destination(destination_name=collection_name, vector_length=3072)
234+
assert created
235+
assert uploader.upload_config.collection_name == formatted_collection_name
236+
237+
created = uploader.create_destination(destination_name=collection_name, vector_length=3072)
238+
assert not created
239+
240+
# cleanup
241+
client = AstraDBClient()
242+
db = client.get_database(api_endpoint=env_data.api_endpoint, token=env_data.token)
243+
db.drop_collection(formatted_collection_name)
244+
245+
219246
@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
220247
@pytest.mark.parametrize("upload_file_str", ["upload_file_ndjson", "upload_file"])
221248
def test_astra_stager(

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.9-dev0" # pragma: no cover
1+
__version__ = "0.5.9-dev1" # pragma: no cover

unstructured_ingest/v2/processes/connectors/astradb.py

Lines changed: 85 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import csv
22
import hashlib
3+
import re
34
from dataclasses import dataclass, field
45
from pathlib import Path
56
from time import time
@@ -48,6 +49,7 @@
4849
from astrapy import AsyncCollection as AstraDBAsyncCollection
4950
from astrapy import Collection as AstraDBCollection
5051
from astrapy import DataAPIClient as AstraDBClient
52+
from astrapy import Database as AstraDB
5153

5254

5355
CONNECTOR_TYPE = "astradb"
@@ -85,11 +87,10 @@ def get_client(self) -> "AstraDBClient":
8587
)
8688

8789

88-
def get_astra_collection(
90+
def get_astra_db(
8991
connection_config: AstraDBConnectionConfig,
90-
collection_name: str,
9192
keyspace: str,
92-
) -> "AstraDBCollection":
93+
) -> "AstraDB":
9394
# Build the Astra DB object.
9495
access_configs = connection_config.access_config.get_secret_value()
9596

@@ -103,9 +104,20 @@ def get_astra_collection(
103104
token=access_configs.token,
104105
keyspace=keyspace,
105106
)
107+
return astra_db
108+
106109

107-
# Connect to the collection
110+
def get_astra_collection(
111+
connection_config: AstraDBConnectionConfig,
112+
collection_name: str,
113+
keyspace: str,
114+
) -> "AstraDBCollection":
115+
116+
astra_db = get_astra_db(connection_config=connection_config, keyspace=keyspace)
117+
118+
# astradb will return a collection object in all cases (even if it doesn't exist)
108119
astra_db_collection = astra_db.get_collection(name=collection_name)
120+
109121
return astra_db_collection
110122

111123

@@ -151,10 +163,11 @@ class AstraDBDownloaderConfig(DownloaderConfig):
151163

152164

153165
class AstraDBUploaderConfig(UploaderConfig):
154-
collection_name: str = Field(
166+
collection_name: Optional[str] = Field(
155167
description="The name of the Astra DB collection. "
156168
"Note that the collection name must only include letters, "
157-
"numbers, and underscores."
169+
"numbers, and underscores.",
170+
default=None,
158171
)
159172
keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
160173
requested_indexing_policy: Optional[dict[str, Any]] = Field(
@@ -337,25 +350,84 @@ class AstraDBUploader(Uploader):
337350
upload_config: AstraDBUploaderConfig
338351
connector_type: str = CONNECTOR_TYPE
339352

353+
def init(self, **kwargs: Any) -> None:
354+
self.create_destination(**kwargs)
355+
340356
def precheck(self) -> None:
341357
try:
342-
get_astra_collection(
343-
connection_config=self.connection_config,
344-
collection_name=self.upload_config.collection_name,
345-
keyspace=self.upload_config.keyspace,
346-
).options()
358+
if self.upload_config.collection_name:
359+
self.get_collection(collection_name=self.upload_config.collection_name).options()
360+
else:
361+
# check for db connection only if collection name is not provided
362+
get_astra_db(
363+
connection_config=self.connection_config,
364+
keyspace=self.upload_config.keyspace,
365+
)
347366
except Exception as e:
348367
logger.error(f"Failed to validate connection {e}", exc_info=True)
349368
raise DestinationConnectionError(f"failed to validate connection: {e}")
350369

351370
@requires_dependencies(["astrapy"], extras="astradb")
352-
def get_collection(self) -> "AstraDBCollection":
371+
def get_collection(self, collection_name: Optional[str] = None) -> "AstraDBCollection":
353372
return get_astra_collection(
354373
connection_config=self.connection_config,
355-
collection_name=self.upload_config.collection_name,
374+
collection_name=collection_name or self.upload_config.collection_name,
356375
keyspace=self.upload_config.keyspace,
357376
)
358377

378+
def _collection_exists(self, collection_name: str):
379+
from astrapy.exceptions import CollectionNotFoundException
380+
381+
collection = get_astra_collection(
382+
connection_config=self.connection_config,
383+
collection_name=collection_name,
384+
keyspace=self.upload_config.keyspace,
385+
)
386+
387+
try:
388+
collection.options()
389+
return True
390+
except CollectionNotFoundException:
391+
return False
392+
except Exception as e:
393+
logger.error(f"failed to check if astra collection exists : {e}")
394+
raise DestinationConnectionError(f"failed to check if astra collection exists : {e}")
395+
396+
def format_destination_name(self, destination_name: str) -> str:
397+
# AstraDB collection naming requirements:
398+
# must be below 50 characters
399+
# must be lowercase alphanumeric and underscores only
400+
formatted = re.sub(r"[^a-z0-9]", "_", destination_name.lower())
401+
return formatted
402+
403+
def create_destination(
404+
self,
405+
vector_length: int,
406+
destination_name: str = "unstructuredautocreated",
407+
similarity_metric: Optional[str] = "cosine",
408+
**kwargs: Any,
409+
) -> bool:
410+
destination_name = self.format_destination_name(destination_name)
411+
collection_name = self.upload_config.collection_name or destination_name
412+
self.upload_config.collection_name = collection_name
413+
414+
if not self._collection_exists(collection_name):
415+
astra_db = get_astra_db(
416+
connection_config=self.connection_config, keyspace=self.upload_config.keyspace
417+
)
418+
logger.info(
419+
f"creating default astra collection '{collection_name}' with dimension "
420+
f"{vector_length} and metric {similarity_metric}"
421+
)
422+
astra_db.create_collection(
423+
collection_name,
424+
dimension=vector_length,
425+
metric=similarity_metric,
426+
)
427+
return True
428+
logger.debug(f"collection with name '{collection_name}' already exists, skipping creation")
429+
return False
430+
359431
def delete_by_record_id(self, collection: "AstraDBCollection", file_data: FileData):
360432
logger.debug(
361433
f"deleting records from collection {collection.name} "

0 commit comments

Comments
 (0)