Skip to content

Commit a0923db

Browse files
feat/redis destination connector (#244)
* Add redis connector and integration test
* Add package dependency to requirements files
* update changelog and version
* reformat files with isort and black
* update based on flake8 instruction
* reformat with black
* reformat using ruff check
* rename file to avoid conflict with redis-py package
* fix syntax error
* reformat code
* load requirement packages for redis connector
* fix package dependencies
* change the way to get env var
* add requires_env decorator
* add Azure redis credential to github workflow
* remove sync function; use context manager for client; remove unnecessary stager
* update argument typing
* move value validation to connection config
* update async logic with asyncio.gather
* clean up redis records after integration test
* use redis client as context manager
* use pydantic model validator to validate argument
* remove asyncio.run to avoid conflict with async context
* reflect new structure of uploader
* bump up version
* trigger ci cd pipeline
1 parent 51dfbe1 commit a0923db

File tree

8 files changed

+327
-0
lines changed

8 files changed

+327
-0
lines changed

.github/workflows/e2e.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ jobs:
156156
ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_ENDPOINT }}
157157
AZURE_SEARCH_ENDPOINT: ${{ secrets.AZURE_SEARCH_ENDPOINT }}
158158
AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }}
159+
AZURE_REDIS_INGEST_TEST_PASSWORD: ${{ secrets.AZURE_REDIS_INGEST_TEST_PASSWORD }}
159160
MONGODB_URI: ${{ secrets.MONGODB_URI }}
160161
MONGODB_DATABASE: ${{ secrets.MONGODB_DATABASE_NAME }}
161162
QDRANT_API_KEY: ${{ secrets.QDRANT_API_KEY }}
@@ -294,6 +295,7 @@ jobs:
294295
S3_INGEST_TEST_SECRET_KEY: ${{ secrets.S3_INGEST_TEST_SECRET_KEY }}
295296
AZURE_SEARCH_ENDPOINT: ${{ secrets.AZURE_SEARCH_ENDPOINT }}
296297
AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }}
298+
AZURE_REDIS_INGEST_TEST_PASSWORD: ${{ secrets.AZURE_REDIS_INGEST_TEST_PASSWORD }}
297299
BOX_APP_CONFIG: ${{ secrets.BOX_APP_CONFIG }}
298300
DROPBOX_APP_KEY: ${{ secrets.DROPBOX_APP_KEY }}
299301
DROPBOX_APP_SECRET: ${{ secrets.DROPBOX_APP_SECRET }}

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
## 0.3.12-dev2
22

3+
### Enhancements
4+
5+
* **Added Redis destination connector**
6+
7+
## 0.3.12-dev1
8+
9+
* **Bypass asyncio exception grouping to return more meaningful errors from OneDrive indexer**
10+
11+
## 0.3.12-dev0
12+
313
### Fixes
414

515
* **Fix Kafka destination connection problems**

requirements/connectors/redis.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-c ../common/constraints.txt
2+
3+
redis

requirements/connectors/redis.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# This file was autogenerated by uv via the following command:
2+
# uv pip compile ./requirements/connectors/redis.in --output-file ./requirements/connectors/redis.txt --no-strip-extras --python-version 3.9
3+
async-timeout==5.0.1
4+
# via redis
5+
redis==5.2.0
6+
# via -r ./requirements/connectors/redis.in

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def load_requirements(file: Union[str, Path]) -> List[str]:
117117
"postgres": load_requirements("requirements/connectors/postgres.in"),
118118
"qdrant": load_requirements("requirements/connectors/qdrant.in"),
119119
"reddit": load_requirements("requirements/connectors/reddit.in"),
120+
"redis": load_requirements("requirements/connectors/redis.in"),
120121
"s3": load_requirements("requirements/connectors/s3.in"),
121122
"sharepoint": load_requirements("requirements/connectors/sharepoint.in"),
122123
"salesforce": load_requirements("requirements/connectors/salesforce.in"),
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import asyncio
2+
import json
3+
import os
4+
from pathlib import Path
5+
from typing import Optional
6+
7+
import numpy as np
8+
import pytest
9+
from redis import exceptions as redis_exceptions
10+
from redis.asyncio import Redis, from_url
11+
12+
from test.integration.connectors.utils.constants import DESTINATION_TAG
13+
from test.integration.utils import requires_env
14+
from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
15+
from unstructured_ingest.v2.processes.connectors.redisdb import (
16+
CONNECTOR_TYPE as REDIS_CONNECTOR_TYPE,
17+
)
18+
from unstructured_ingest.v2.processes.connectors.redisdb import (
19+
RedisAccessConfig,
20+
RedisConnectionConfig,
21+
RedisUploader,
22+
RedisUploaderConfig,
23+
)
24+
25+
26+
async def delete_record(client: Redis, element_id: str) -> None:
27+
await client.delete(element_id)
28+
29+
30+
async def validate_upload(client: Redis, first_element: dict):
31+
element_id = first_element["element_id"]
32+
expected_text = first_element["text"]
33+
expected_embeddings = first_element["embeddings"]
34+
async with client.pipeline(transaction=True) as pipe:
35+
try:
36+
response = await pipe.json().get(element_id, "$").execute()
37+
response = response[0][0]
38+
except redis_exceptions.ResponseError:
39+
response = await pipe.get(element_id).execute()
40+
response = json.loads(response[0])
41+
42+
embedding_similarity = np.linalg.norm(
43+
np.array(response["embeddings"]) - np.array(expected_embeddings)
44+
)
45+
46+
assert response is not None
47+
assert response["element_id"] == element_id
48+
assert response["text"] == expected_text
49+
assert embedding_similarity < 1e-10
50+
51+
52+
async def redis_destination_test(
53+
upload_file: Path,
54+
tmp_path: Path,
55+
connection_kwargs: dict,
56+
uri: Optional[str] = None,
57+
password: Optional[str] = None,
58+
):
59+
uploader = RedisUploader(
60+
connection_config=RedisConnectionConfig(
61+
**connection_kwargs, access_config=RedisAccessConfig(uri=uri, password=password)
62+
),
63+
upload_config=RedisUploaderConfig(batch_size=10),
64+
)
65+
66+
file_data = FileData(
67+
source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
68+
connector_type=REDIS_CONNECTOR_TYPE,
69+
identifier="mock-file-data",
70+
)
71+
with upload_file.open() as upload_fp:
72+
elements = json.load(upload_fp)
73+
first_element = elements[0]
74+
75+
try:
76+
if uploader.is_async():
77+
await uploader.run_data_async(data=elements, file_data=file_data)
78+
79+
if uri:
80+
async with from_url(uri) as client:
81+
await validate_upload(client=client, first_element=first_element)
82+
else:
83+
async with Redis(**connection_kwargs, password=password) as client:
84+
await validate_upload(client=client, first_element=first_element)
85+
except Exception as e:
86+
raise e
87+
finally:
88+
if uri:
89+
async with from_url(uri) as client:
90+
tasks = [delete_record(client, element["element_id"]) for element in elements]
91+
await asyncio.gather(*tasks)
92+
else:
93+
async with Redis(**connection_kwargs, password=password) as client:
94+
tasks = [delete_record(client, element["element_id"]) for element in elements]
95+
await asyncio.gather(*tasks)
96+
97+
98+
@pytest.mark.asyncio
99+
@pytest.mark.tags(REDIS_CONNECTOR_TYPE, DESTINATION_TAG)
100+
@requires_env("AZURE_REDIS_INGEST_TEST_PASSWORD")
101+
async def test_redis_destination_azure_with_password(upload_file: Path, tmp_path: Path):
102+
connection_kwargs = {
103+
"host": "utic-dashboard-dev.redis.cache.windows.net",
104+
"port": 6380,
105+
"db": 0,
106+
"ssl": True,
107+
}
108+
redis_pw = os.environ["AZURE_REDIS_INGEST_TEST_PASSWORD"]
109+
await redis_destination_test(upload_file, tmp_path, connection_kwargs, password=redis_pw)
110+
111+
112+
@pytest.mark.asyncio
113+
@pytest.mark.tags(REDIS_CONNECTOR_TYPE, DESTINATION_TAG, "redis")
114+
@requires_env("AZURE_REDIS_INGEST_TEST_PASSWORD")
115+
async def test_redis_destination_azure_with_uri(upload_file: Path, tmp_path: Path):
116+
connection_kwargs = {}
117+
redis_pw = os.environ["AZURE_REDIS_INGEST_TEST_PASSWORD"]
118+
uri = f"rediss://:{redis_pw}@utic-dashboard-dev.redis.cache.windows.net:6380/0"
119+
await redis_destination_test(upload_file, tmp_path, connection_kwargs, uri=uri)

unstructured_ingest/v2/processes/connectors/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from .outlook import outlook_source_entry
4949
from .pinecone import CONNECTOR_TYPE as PINECONE_CONNECTOR_TYPE
5050
from .pinecone import pinecone_destination_entry
51+
from .redisdb import CONNECTOR_TYPE as REDIS_CONNECTOR_TYPE
52+
from .redisdb import redis_destination_entry
5153
from .salesforce import CONNECTOR_TYPE as SALESFORCE_CONNECTOR_TYPE
5254
from .salesforce import salesforce_source_entry
5355
from .sharepoint import CONNECTOR_TYPE as SHAREPOINT_CONNECTOR_TYPE
@@ -102,3 +104,5 @@
102104
add_source_entry(source_type=SLACK_CONNECTOR_TYPE, entry=slack_source_entry)
103105

104106
add_source_entry(source_type=CONFLUENCE_CONNECTOR_TYPE, entry=confluence_source_entry)
107+
108+
add_destination_entry(destination_type=REDIS_CONNECTOR_TYPE, entry=redis_destination_entry)
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import json
2+
from contextlib import asynccontextmanager, contextmanager
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Optional
5+
6+
from pydantic import Field, Secret, model_validator
7+
8+
from unstructured_ingest.error import DestinationConnectionError
9+
from unstructured_ingest.utils.data_prep import batch_generator
10+
from unstructured_ingest.utils.dep_check import requires_dependencies
11+
from unstructured_ingest.v2.interfaces import (
12+
AccessConfig,
13+
ConnectionConfig,
14+
FileData,
15+
Uploader,
16+
UploaderConfig,
17+
)
18+
from unstructured_ingest.v2.logger import logger
19+
from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
20+
21+
if TYPE_CHECKING:
22+
from redis.asyncio import Redis
23+
24+
import asyncio
25+
26+
CONNECTOR_TYPE = "redis"
27+
SERVER_API_VERSION = "1"
28+
29+
30+
class RedisAccessConfig(AccessConfig):
31+
uri: Optional[str] = Field(
32+
default=None, description="If not anonymous, use this uri, if specified."
33+
)
34+
password: Optional[str] = Field(
35+
default=None, description="If not anonymous, use this password, if specified."
36+
)
37+
38+
39+
class RedisConnectionConfig(ConnectionConfig):
40+
access_config: Secret[RedisAccessConfig] = Field(
41+
default=RedisAccessConfig(), validate_default=True
42+
)
43+
host: Optional[str] = Field(
44+
default=None, description="Hostname or IP address of a Redis instance to connect to."
45+
)
46+
database: int = Field(default=0, description="Database index to connect to.")
47+
port: int = Field(default=6379, description="port used to connect to database.")
48+
username: Optional[str] = Field(
49+
default=None, description="Username used to connect to database."
50+
)
51+
ssl: bool = Field(default=True, description="Whether the connection should use SSL encryption.")
52+
connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
53+
54+
@model_validator(mode="after")
55+
def validate_host_or_url(self) -> "RedisConnectionConfig":
56+
if not self.access_config.get_secret_value().uri and not self.host:
57+
raise ValueError("Please pass a hostname either directly or through uri")
58+
return self
59+
60+
@requires_dependencies(["redis"], extras="redis")
61+
@asynccontextmanager
62+
async def create_async_client(self) -> AsyncGenerator["Redis", None]:
63+
from redis.asyncio import Redis, from_url
64+
65+
access_config = self.access_config.get_secret_value()
66+
67+
options = {
68+
"host": self.host,
69+
"port": self.port,
70+
"db": self.database,
71+
"ssl": self.ssl,
72+
"username": self.username,
73+
}
74+
75+
if access_config.password:
76+
options["password"] = access_config.password
77+
78+
if access_config.uri:
79+
async with from_url(access_config.uri) as client:
80+
yield client
81+
else:
82+
async with Redis(**options) as client:
83+
yield client
84+
85+
@requires_dependencies(["redis"], extras="redis")
86+
@contextmanager
87+
def create_client(self) -> Generator["Redis", None, None]:
88+
from redis import Redis, from_url
89+
90+
access_config = self.access_config.get_secret_value()
91+
92+
options = {
93+
"host": self.host,
94+
"port": self.port,
95+
"db": self.database,
96+
"ssl": self.ssl,
97+
"username": self.username,
98+
}
99+
100+
if access_config.password:
101+
options["password"] = access_config.password
102+
103+
if access_config.uri:
104+
with from_url(access_config.uri) as client:
105+
yield client
106+
else:
107+
with Redis(**options) as client:
108+
yield client
109+
110+
111+
class RedisUploaderConfig(UploaderConfig):
112+
batch_size: int = Field(default=100, description="Number of records per batch")
113+
114+
115+
@dataclass
116+
class RedisUploader(Uploader):
117+
upload_config: RedisUploaderConfig
118+
connection_config: RedisConnectionConfig
119+
connector_type: str = CONNECTOR_TYPE
120+
121+
def is_async(self) -> bool:
122+
return True
123+
124+
def precheck(self) -> None:
125+
try:
126+
with self.connection_config.create_client() as client:
127+
client.ping()
128+
except Exception as e:
129+
logger.error(f"failed to validate connection: {e}", exc_info=True)
130+
raise DestinationConnectionError(f"failed to validate connection: {e}")
131+
132+
async def run_data_async(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
133+
first_element = data[0]
134+
redis_stack = await self._check_redis_stack(first_element)
135+
logger.info(
136+
f"writing {len(data)} objects to destination asynchronously, "
137+
f"db, {self.connection_config.database}, "
138+
f"at {self.connection_config.host}",
139+
)
140+
141+
batches = list(batch_generator(data, batch_size=self.upload_config.batch_size))
142+
await asyncio.gather(*[self._write_batch(batch, redis_stack) for batch in batches])
143+
144+
async def _write_batch(self, batch: list[dict], redis_stack: bool) -> None:
145+
async with self.connection_config.create_async_client() as async_client:
146+
async with async_client.pipeline(transaction=True) as pipe:
147+
for element in batch:
148+
element_id = element["element_id"]
149+
if redis_stack:
150+
pipe.json().set(element_id, "$", element)
151+
else:
152+
pipe.set(element_id, json.dumps(element))
153+
await pipe.execute()
154+
155+
@requires_dependencies(["redis"], extras="redis")
156+
async def _check_redis_stack(self, element: dict) -> bool:
157+
from redis import exceptions as redis_exceptions
158+
159+
redis_stack = True
160+
async with self.connection_config.create_async_client() as async_client:
161+
async with async_client.pipeline(transaction=True) as pipe:
162+
element_id = element["element_id"]
163+
try:
164+
# Redis with stack extension supports JSON type
165+
await pipe.json().set(element_id, "$", element).execute()
166+
except redis_exceptions.ResponseError as e:
167+
message = str(e)
168+
if "unknown command `JSON.SET`" in message:
169+
# if this error occurs, Redis server doesn't support JSON type,
170+
# so save as string type instead
171+
await pipe.set(element_id, json.dumps(element)).execute()
172+
redis_stack = False
173+
else:
174+
raise e
175+
return redis_stack
176+
177+
178+
redis_destination_entry = DestinationRegistryEntry(
179+
connection_config=RedisConnectionConfig,
180+
uploader=RedisUploader,
181+
uploader_config=RedisUploaderConfig,
182+
)

0 commit comments

Comments
 (0)