diff --git a/pyproject.toml b/pyproject.toml index 4b8f9616f..4b3bed5c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ vastdb = ["requirements/connectors/vastdb.txt"] vectara = ["requirements/connectors/vectara.txt"] weaviate = ["requirements/connectors/weaviate.txt"] wikipedia = ["requirements/connectors/wikipedia.txt"] +yugabytedb = ["requirements/connectors/yugabytedb.txt"] zendesk = ["requirements/connectors/zendesk.txt"] # Embedders diff --git a/requirements/connectors/yugabytedb.txt b/requirements/connectors/yugabytedb.txt new file mode 100644 index 000000000..2abc1718c --- /dev/null +++ b/requirements/connectors/yugabytedb.txt @@ -0,0 +1,2 @@ +pandas +psycopg2-yugabytedb \ No newline at end of file diff --git a/unstructured_ingest/processes/connectors/sql/__init__.py b/unstructured_ingest/processes/connectors/sql/__init__.py index 140300ea8..edafd4768 100644 --- a/unstructured_ingest/processes/connectors/sql/__init__.py +++ b/unstructured_ingest/processes/connectors/sql/__init__.py @@ -17,12 +17,15 @@ from .sqlite import sqlite_destination_entry, sqlite_source_entry from .vastdb import CONNECTOR_TYPE as VASTDB_CONNECTOR_TYPE from .vastdb import vastdb_destination_entry, vastdb_source_entry +from .yugabytedb import CONNECTOR_TYPE as YUGABYTE_DB_CONNECTOR_TYPE +from .yugabytedb import yugabytedb_destination_entry, yugabytedb_source_entry add_source_entry(source_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_source_entry) add_source_entry(source_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_source_entry) add_source_entry(source_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_source_entry) add_source_entry(source_type=SINGLESTORE_CONNECTOR_TYPE, entry=singlestore_source_entry) add_source_entry(source_type=VASTDB_CONNECTOR_TYPE, entry=vastdb_source_entry) +add_source_entry(source_type=YUGABYTE_DB_CONNECTOR_TYPE, entry=yugabytedb_source_entry) add_destination_entry(destination_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_destination_entry) add_destination_entry(destination_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_destination_entry) @@ -35,3 +38,4 @@ entry=databricks_delta_tables_destination_entry, ) add_destination_entry(destination_type=VASTDB_CONNECTOR_TYPE, entry=vastdb_destination_entry) +add_destination_entry(destination_type=YUGABYTE_DB_CONNECTOR_TYPE, entry=yugabytedb_destination_entry) \ No newline at end of file diff --git a/unstructured_ingest/processes/connectors/sql/yugabytedb.py b/unstructured_ingest/processes/connectors/sql/yugabytedb.py new file mode 100644 index 000000000..d8e3ac12f --- /dev/null +++ b/unstructured_ingest/processes/connectors/sql/yugabytedb.py @@ -0,0 +1,175 @@ +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Generator, Optional + +from pydantic import Field, Secret + +from unstructured_ingest.data_types.file_data import FileData +from unstructured_ingest.logger import logger +from unstructured_ingest.processes.connector_registry import ( + DestinationRegistryEntry, + SourceRegistryEntry, +) +from unstructured_ingest.processes.connectors.sql.sql import ( + SQLAccessConfig, + SqlBatchFileData, + SQLConnectionConfig, + SQLDownloader, + SQLDownloaderConfig, + SQLIndexer, + SQLIndexerConfig, + SQLUploader, + SQLUploaderConfig, + SQLUploadStager, + SQLUploadStagerConfig, +) +from unstructured_ingest.utils.dep_check import requires_dependencies + +if TYPE_CHECKING: + from psycopg2.extensions import connection as YugabyteDbConnection + from psycopg2.extensions import cursor as YugabyteDbCursor + +CONNECTOR_TYPE = "yugabytedb" + + +class YugabyteDbAccessConfig(SQLAccessConfig): + password: Optional[str] = Field(default=None, description="DB password") + + +class YugabyteDbConnectionConfig(SQLConnectionConfig): + access_config: Secret[YugabyteDbAccessConfig] = Field( + default=YugabyteDbAccessConfig(), validate_default=True + ) + database: Optional[str] = Field( + default=None, + description="Database name.", + ) + username: Optional[str] = Field(default=None, description="DB username") + host: Optional[str] = Field(default=None, description="DB host") + port: Optional[int] = Field(default=5432, description="DB host connection port") + load_balance: Optional[str] = Field(default="False", description="Load balancing strategy") + topology_keys: Optional[str] = Field(default="", description="Topology keys") + yb_servers_refresh_interval: Optional[int] = Field(default=300, + description="YB servers refresh interval") + connector_type: str = Field(default=CONNECTOR_TYPE, init=False) + + @contextmanager + @requires_dependencies(["psycopg2"], extras="yugabytedb") + def get_connection(self) -> Generator["YugabyteDbConnection", None, None]: + from psycopg2 import connect + + access_config = self.access_config.get_secret_value() + connection = connect( + user=self.username, + password=access_config.password, + dbname=self.database, + host=self.host, + port=self.port, + load_balance=self.load_balance, + topology_keys=self.topology_keys, + yb_servers_refresh_interval=self.yb_servers_refresh_interval, + ) + try: + yield connection + finally: + connection.commit() + connection.close() + + @contextmanager + def get_cursor(self) -> Generator["YugabyteDbCursor", None, None]: + with self.get_connection() as connection: + cursor = connection.cursor() + try: + yield cursor + finally: + cursor.close() + + +class YugabyteDbIndexerConfig(SQLIndexerConfig): + pass + + +@dataclass +class YugabyteDbIndexer(SQLIndexer): + connection_config: YugabyteDbConnectionConfig + index_config: YugabyteDbIndexerConfig + connector_type: str = CONNECTOR_TYPE + + +class YugabyteDbDownloaderConfig(SQLDownloaderConfig): + pass + + +@dataclass +class YugabyteDbDownloader(SQLDownloader): + connection_config: YugabyteDbConnectionConfig + download_config: YugabyteDbDownloaderConfig + connector_type: str = CONNECTOR_TYPE + + @requires_dependencies(["psycopg2"], extras="yugabytedb") + def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]: + from psycopg2 import sql + + table_name = file_data.additional_metadata.table_name + id_column = file_data.additional_metadata.id_column + ids = tuple([item.identifier for item in file_data.batch_items]) + + with self.connection_config.get_cursor() as cursor: + fields = ( + sql.SQL(",").join(sql.Identifier(field) for field in self.download_config.fields) + if self.download_config.fields + else sql.SQL("*") + ) + + query = sql.SQL("SELECT {fields} FROM {table_name} WHERE {id_column} IN %s").format( + fields=fields, + table_name=sql.Identifier(table_name), + id_column=sql.Identifier(id_column), + ) + logger.debug(f"running query: {cursor.mogrify(query, (ids,))}") + cursor.execute(query, (ids,)) + rows = cursor.fetchall() + columns = [col[0] for col in cursor.description] + return rows, columns + + +class YugabyteDbUploadStagerConfig(SQLUploadStagerConfig): + pass + + +class YugabyteDbUploadStager(SQLUploadStager): + upload_stager_config: YugabyteDbUploadStagerConfig + + +class YugabyteDbUploaderConfig(SQLUploaderConfig): + pass + + +@dataclass +class YugabyteDbUploader(SQLUploader): + upload_config: YugabyteDbUploaderConfig = field(default_factory=YugabyteDbUploaderConfig) + connection_config: YugabyteDbConnectionConfig + connector_type: str = CONNECTOR_TYPE + values_delimiter: str = "%s" + + @requires_dependencies(["pandas"], extras="yugabytedb") + def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None: + super().run(path=path, file_data=file_data, **kwargs) + + +yugabytedb_source_entry = SourceRegistryEntry( + connection_config=YugabyteDbConnectionConfig, + indexer_config=YugabyteDbIndexerConfig, + indexer=YugabyteDbIndexer, + downloader_config=YugabyteDbDownloaderConfig, + downloader=YugabyteDbDownloader, +) + +yugabytedb_destination_entry = DestinationRegistryEntry( + connection_config=YugabyteDbConnectionConfig, + uploader=YugabyteDbUploader, + uploader_config=YugabyteDbUploaderConfig, + upload_stager=YugabyteDbUploadStager, + upload_stager_config=YugabyteDbUploadStagerConfig, +)