|
| 1 | +import json |
| 2 | +from dataclasses import dataclass, field |
| 3 | +from pathlib import Path |
| 4 | +from typing import TYPE_CHECKING, Any, Optional |
| 5 | + |
| 6 | +from unstructured import __name__ as integration_name |
| 7 | +from unstructured.__version__ import __version__ as integration_version |
| 8 | +from unstructured.ingest.enhanced_dataclass import enhanced_field |
| 9 | +from unstructured.ingest.utils.data_prep import chunk_generator |
| 10 | +from unstructured.ingest.v2.interfaces import ( |
| 11 | + AccessConfig, |
| 12 | + ConnectionConfig, |
| 13 | + FileData, |
| 14 | + UploadContent, |
| 15 | + Uploader, |
| 16 | + UploaderConfig, |
| 17 | + UploadStager, |
| 18 | + UploadStagerConfig, |
| 19 | +) |
| 20 | +from unstructured.ingest.v2.logger import logger |
| 21 | +from unstructured.ingest.v2.processes.connector_registry import ( |
| 22 | + DestinationRegistryEntry, |
| 23 | + add_destination_entry, |
| 24 | +) |
| 25 | +from unstructured.utils import requires_dependencies |
| 26 | + |
| 27 | +if TYPE_CHECKING: |
| 28 | + from astrapy.db import AstraDBCollection |
| 29 | + |
# Identifier for this connector; used as the default connection_type and as
# the destination_type key when the connector is added to the registry.
CONNECTOR_TYPE = "astra"
| 31 | + |
| 32 | + |
@dataclass
class AstraAccessConfig(AccessConfig):
    """Credentials used to reach an Astra DB instance."""

    # Passed to the AstraDB client as ``token``.
    token: str
    # Passed to the AstraDB client as ``api_endpoint``.
    api_endpoint: str
| 37 | + |
| 38 | + |
@dataclass
class AstraConnectionConfig(ConnectionConfig):
    """Connection settings for the Astra destination connector."""

    # Defaults to this module's connector identifier.
    connection_type: str = CONNECTOR_TYPE
    # Declared with sensitive=True; presumably this keeps the credentials out
    # of serialized/logged output -- confirm against enhanced_field.
    access_config: AstraAccessConfig = enhanced_field(sensitive=True)
| 43 | + |
| 44 | + |
@dataclass
class AstraUploadStagerConfig(UploadStagerConfig):
    """Staging configuration for Astra; declares no fields of its own."""
| 48 | + |
| 49 | + |
@dataclass
class AstraUploadStager(UploadStager):
    """Reshape a JSON file of element dicts into the document layout uploaded to Astra."""

    upload_stager_config: AstraUploadStagerConfig = field(
        default_factory=lambda: AstraUploadStagerConfig()
    )

    def conform_dict(self, element_dict: dict) -> dict:
        """Convert one element dict into a document dict.

        Args:
            element_dict: A single element. It is not modified.

        Returns:
            A dict with three keys: ``"$vector"`` (the element's
            ``"embeddings"`` value, or ``None`` when absent), ``"content"``
            (the element's ``"text"`` value, or ``None`` when absent) and
            ``"metadata"`` (every remaining key/value of the element).
        """
        # Work on a shallow copy so the caller's dict keeps its "embeddings"
        # and "text" keys and is not aliased by the returned "metadata".
        metadata = dict(element_dict)
        return {
            "$vector": metadata.pop("embeddings", None),
            "content": metadata.pop("text", None),
            "metadata": metadata,
        }

    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        """Conform every element in a JSON file and write the result to a new file.

        Args:
            elements_filepath: JSON file containing a list of element dicts.
            file_data: Not used by this stager.
            output_dir: Directory in which the staged file is written.
            output_filename: Name of the staged file, without the ``.json``
                suffix (the suffix is appended here).
            **kwargs: Ignored.

        Returns:
            Path of the JSON file holding the conformed documents.
        """
        with open(elements_filepath) as elements_file:
            elements_contents = json.load(elements_file)

        conformed_elements = [
            self.conform_dict(element_dict=element) for element in elements_contents
        ]

        output_path = Path(output_dir) / f"{output_filename}.json"
        with open(output_path, "w") as output_file:
            json.dump(conformed_elements, output_file)
        return output_path
| 80 | + |
| 81 | + |
@dataclass
class AstraUploaderConfig(UploaderConfig):
    """Settings that control how documents are written to an Astra collection."""

    # Name of the collection passed to create_collection.
    collection_name: str
    # Passed to create_collection as ``dimension``.
    embedding_dimension: int
    # Forwarded to the AstraDB client as ``namespace``.
    namespace: Optional[str] = None
    # When truthy, sent to create_collection as the "indexing" option.
    requested_indexing_policy: Optional[dict[str, Any]] = None
    # Number of documents handed to each insert_many call.
    batch_size: int = 20
| 89 | + |
| 90 | + |
@dataclass
class AstraUploader(Uploader):
    """Insert staged JSON documents into an Astra DB collection."""

    connection_config: AstraConnectionConfig
    upload_config: AstraUploaderConfig

    @requires_dependencies(["astrapy"], extras="astra")
    def get_collection(self) -> "AstraDBCollection":
        """Build an AstraDB client and return the result of create_collection.

        Returns:
            The collection object returned by ``AstraDB.create_collection`` for
            the configured collection name and embedding dimension.
        """
        from astrapy.db import AstraDB

        upload_cfg = self.upload_config
        access = self.connection_config.access_config

        # Only send an "indexing" option when a policy was actually requested.
        indexing_policy = upload_cfg.requested_indexing_policy
        if indexing_policy:
            collection_options = {"indexing": indexing_policy}
        else:
            collection_options = None

        # caller_name / caller_version are passed for AstraDB tracking.
        client = AstraDB(
            api_endpoint=access.api_endpoint,
            token=access.token,
            namespace=upload_cfg.namespace,
            caller_name=integration_name,
            caller_version=integration_version,
        )

        return client.create_collection(
            collection_name=upload_cfg.collection_name,
            dimension=upload_cfg.embedding_dimension,
            options=collection_options,
        )

    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
        """Load every staged file and insert its documents in batches.

        Args:
            contents: Items whose ``path`` points at a JSON file containing a
                list of documents.
            **kwargs: Ignored.
        """
        # Gather the documents from all staged files before writing anything.
        documents = []
        for staged in contents:
            with open(staged.path) as staged_file:
                documents.extend(json.load(staged_file))

        logger.info(
            f"writing {len(documents)} objects to destination "
            f"collection {self.upload_config.collection_name}"
        )

        per_batch = self.upload_config.batch_size
        target = self.get_collection()

        for batch in chunk_generator(documents, per_batch):
            target.insert_many(batch)
| 143 | + |
| 144 | + |
| 145 | +add_destination_entry( |
| 146 | + destination_type=CONNECTOR_TYPE, |
| 147 | + entry=DestinationRegistryEntry( |
| 148 | + connection_config=AstraConnectionConfig, |
| 149 | + upload_stager_config=AstraUploadStagerConfig, |
| 150 | + upload_stager=AstraUploadStager, |
| 151 | + uploader_config=AstraUploaderConfig, |
| 152 | + uploader=AstraUploader, |
| 153 | + ), |
| 154 | +) |
0 commit comments