Skip to content

Commit a7a53f6

Browse files
authored
feat/migrate astra db (#3294)
### Description Move astradb destination connector over to the new v2 ingest framework
1 parent 3f581e6 commit a7a53f6

File tree

6 files changed

+244
-6
lines changed

6 files changed

+244
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## 0.14.9-dev1
1+
## 0.14.9-dev2
22

33
### Enhancements
44

unstructured/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.14.9-dev1" # pragma: no cover
1+
__version__ = "0.14.9-dev2" # pragma: no cover

unstructured/ingest/v2/cli/cmds/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import click
44

5+
from .astra import astra_dest_cmd
56
from .chroma import chroma_dest_cmd
67
from .elasticsearch import elasticsearch_dest_cmd, elasticsearch_src_cmd
78
from .fsspec.azure import azure_dest_cmd, azure_src_cmd
@@ -36,6 +37,7 @@
3637
)
3738

3839
dest_cmds = [
40+
astra_dest_cmd,
3941
azure_dest_cmd,
4042
box_dest_cmd,
4143
chroma_dest_cmd,
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from dataclasses import dataclass
2+
3+
import click
4+
5+
from unstructured.ingest.v2.cli.base import DestCmd
6+
from unstructured.ingest.v2.cli.interfaces import CliConfig
7+
from unstructured.ingest.v2.cli.utils import Dict
8+
from unstructured.ingest.v2.processes.connectors.astra import CONNECTOR_TYPE
9+
10+
11+
@dataclass
12+
class AstraCliConnectionConfig(CliConfig):
13+
@staticmethod
14+
def get_cli_options() -> list[click.Option]:
15+
options = [
16+
click.Option(
17+
["--token"],
18+
required=True,
19+
type=str,
20+
help="Astra DB Token with access to the database.",
21+
envvar="ASTRA_DB_TOKEN",
22+
show_envvar=True,
23+
),
24+
click.Option(
25+
["--api-endpoint"],
26+
required=True,
27+
type=str,
28+
help="The API endpoint for the Astra DB.",
29+
envvar="ASTRA_DB_ENDPOINT",
30+
show_envvar=True,
31+
),
32+
]
33+
return options
34+
35+
36+
@dataclass
37+
class AstraCliUploaderConfig(CliConfig):
38+
@staticmethod
39+
def get_cli_options() -> list[click.Option]:
40+
options = [
41+
click.Option(
42+
["--collection-name"],
43+
required=False,
44+
type=str,
45+
help="The name of the Astra DB collection to write into. "
46+
"Note that the collection name must only include letters, "
47+
"numbers, and underscores.",
48+
),
49+
click.Option(
50+
["--embedding-dimension"],
51+
required=True,
52+
default=384,
53+
type=int,
54+
help="The dimensionality of the embeddings",
55+
),
56+
click.Option(
57+
["--namespace"],
58+
required=False,
59+
default=None,
60+
type=str,
61+
help="The Astra DB namespace to write into.",
62+
),
63+
click.Option(
64+
["--requested-indexing-policy"],
65+
required=False,
66+
default=None,
67+
type=Dict(),
68+
help="The indexing policy to use for the collection."
69+
'example: \'{"deny": ["metadata"]}\' ',
70+
),
71+
click.Option(
72+
["--batch-size"],
73+
default=20,
74+
type=int,
75+
help="Number of records per batch",
76+
),
77+
]
78+
return options
79+
80+
81+
astra_dest_cmd = DestCmd(
82+
cmd_name=CONNECTOR_TYPE,
83+
connection_config=AstraCliConnectionConfig,
84+
uploader_config=AstraCliUploaderConfig,
85+
)
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import json
2+
from dataclasses import dataclass, field
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING, Any, Optional
5+
6+
from unstructured import __name__ as integration_name
7+
from unstructured.__version__ import __version__ as integration_version
8+
from unstructured.ingest.enhanced_dataclass import enhanced_field
9+
from unstructured.ingest.utils.data_prep import chunk_generator
10+
from unstructured.ingest.v2.interfaces import (
11+
AccessConfig,
12+
ConnectionConfig,
13+
FileData,
14+
UploadContent,
15+
Uploader,
16+
UploaderConfig,
17+
UploadStager,
18+
UploadStagerConfig,
19+
)
20+
from unstructured.ingest.v2.logger import logger
21+
from unstructured.ingest.v2.processes.connector_registry import (
22+
DestinationRegistryEntry,
23+
add_destination_entry,
24+
)
25+
from unstructured.utils import requires_dependencies
26+
27+
if TYPE_CHECKING:
28+
from astrapy.db import AstraDBCollection
29+
30+
CONNECTOR_TYPE = "astra"
31+
32+
33+
@dataclass
34+
class AstraAccessConfig(AccessConfig):
35+
token: str
36+
api_endpoint: str
37+
38+
39+
@dataclass
40+
class AstraConnectionConfig(ConnectionConfig):
41+
connection_type: str = CONNECTOR_TYPE
42+
access_config: AstraAccessConfig = enhanced_field(sensitive=True)
43+
44+
45+
@dataclass
46+
class AstraUploadStagerConfig(UploadStagerConfig):
47+
pass
48+
49+
50+
@dataclass
51+
class AstraUploadStager(UploadStager):
52+
upload_stager_config: AstraUploadStagerConfig = field(
53+
default_factory=lambda: AstraUploadStagerConfig()
54+
)
55+
56+
def conform_dict(self, element_dict: dict) -> dict:
57+
return {
58+
"$vector": element_dict.pop("embeddings", None),
59+
"content": element_dict.pop("text", None),
60+
"metadata": element_dict,
61+
}
62+
63+
def run(
64+
self,
65+
elements_filepath: Path,
66+
file_data: FileData,
67+
output_dir: Path,
68+
output_filename: str,
69+
**kwargs: Any,
70+
) -> Path:
71+
with open(elements_filepath) as elements_file:
72+
elements_contents = json.load(elements_file)
73+
conformed_elements = []
74+
for element in elements_contents:
75+
conformed_elements.append(self.conform_dict(element_dict=element))
76+
output_path = Path(output_dir) / Path(f"{output_filename}.json")
77+
with open(output_path, "w") as output_file:
78+
json.dump(conformed_elements, output_file)
79+
return output_path
80+
81+
82+
@dataclass
83+
class AstraUploaderConfig(UploaderConfig):
84+
collection_name: str
85+
embedding_dimension: int
86+
namespace: Optional[str] = None
87+
requested_indexing_policy: Optional[dict[str, Any]] = None
88+
batch_size: int = 20
89+
90+
91+
@dataclass
92+
class AstraUploader(Uploader):
93+
connection_config: AstraConnectionConfig
94+
upload_config: AstraUploaderConfig
95+
96+
@requires_dependencies(["astrapy"], extras="astra")
97+
def get_collection(self) -> "AstraDBCollection":
98+
from astrapy.db import AstraDB
99+
100+
# Get the collection_name and embedding dimension
101+
collection_name = self.upload_config.collection_name
102+
embedding_dimension = self.upload_config.embedding_dimension
103+
requested_indexing_policy = self.upload_config.requested_indexing_policy
104+
105+
# If the user has requested an indexing policy, pass it to the AstraDB
106+
options = {"indexing": requested_indexing_policy} if requested_indexing_policy else None
107+
108+
# Build the Astra DB object.
109+
# caller_name/version for AstraDB tracking
110+
astra_db = AstraDB(
111+
api_endpoint=self.connection_config.access_config.api_endpoint,
112+
token=self.connection_config.access_config.token,
113+
namespace=self.upload_config.namespace,
114+
caller_name=integration_name,
115+
caller_version=integration_version,
116+
)
117+
118+
# Create and connect to the newly created collection
119+
astra_db_collection = astra_db.create_collection(
120+
collection_name=collection_name,
121+
dimension=embedding_dimension,
122+
options=options,
123+
)
124+
return astra_db_collection
125+
126+
def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
127+
elements_dict = []
128+
for content in contents:
129+
with open(content.path) as elements_file:
130+
elements = json.load(elements_file)
131+
elements_dict.extend(elements)
132+
133+
logger.info(
134+
f"writing {len(elements_dict)} objects to destination "
135+
f"collection {self.upload_config.collection_name}"
136+
)
137+
138+
astra_batch_size = self.upload_config.batch_size
139+
collection = self.get_collection()
140+
141+
for chunk in chunk_generator(elements_dict, astra_batch_size):
142+
collection.insert_many(chunk)
143+
144+
145+
add_destination_entry(
146+
destination_type=CONNECTOR_TYPE,
147+
entry=DestinationRegistryEntry(
148+
connection_config=AstraConnectionConfig,
149+
upload_stager_config=AstraUploadStagerConfig,
150+
upload_stager=AstraUploadStager,
151+
uploader_config=AstraUploaderConfig,
152+
uploader=AstraUploader,
153+
),
154+
)

unstructured/ingest/v2/processes/connectors/chroma.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@
3131
if TYPE_CHECKING:
3232
from chromadb import Client
3333

34-
35-
import typing as t
36-
3734
CONNECTOR_TYPE = "chroma"
3835

3936

@@ -165,7 +162,7 @@ def upsert_batch(self, collection, batch):
165162
raise ValueError(f"chroma error: {e}") from e
166163

167164
@staticmethod
168-
def prepare_chroma_list(chunk: t.Tuple[t.Dict[str, t.Any]]) -> t.Dict[str, t.List[t.Any]]:
165+
def prepare_chroma_list(chunk: tuple[dict[str, Any]]) -> dict[str, list[Any]]:
169166
"""Helper function to break a tuple of dicts into list of parallel lists for ChromaDb.
170167
({'id':1}, {'id':2}, {'id':3}) -> {'ids':[1,2,3]}"""
171168
chroma_dict = {}

0 commit comments

Comments
 (0)