
Commit 0d37556

ahmetmeleq and rbiseck3 authored
feat: add destination creation support to pinecone (#361)
* update test and implementation changelog and version wip
* improvements
* default vector length
* fix
* allow pipeline to pass init data downstream
* add format destination name
* version
* decorator
* tidy

---------

Co-authored-by: Roman Isecke <[email protected]>
1 parent 5436ceb · commit 0d37556

File tree

9 files changed: +162 -44 lines changed


CHANGELOG.md
Lines changed: 2 additions & 1 deletion

@@ -1,7 +1,8 @@
-## 0.5.6-dev0
+## 0.5.6
 
 ### Enhancements
 
+* **Add support for setting up destination for Pinecone**
 * Add name formatting to Weaviate destination uploader
 
 ## 0.5.5

test/integration/connectors/test_pinecone.py
Lines changed: 34 additions & 0 deletions

@@ -351,3 +351,37 @@ def test_pinecone_stager(
         stager=stager,
         tmp_dir=tmp_path,
     )
+
+
+@requires_env(API_KEY)
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
+def test_pinecone_create_destination(pinecone_index):
+    uploader = PineconeUploader(
+        connection_config=PineconeConnectionConfig(
+            access_config=PineconeAccessConfig(api_key=get_api_key())
+        ),
+        upload_config=PineconeUploaderConfig(),
+    )
+
+    random_id = str(uuid4()).split("-")[0]
+
+    index_name = f"test-create-destination-{random_id}"
+
+    assert not uploader.index_exists(index_name=index_name)
+
+    try:
+        uploader.create_destination(destination_name=index_name, vector_length=1536)
+    except Exception as e:
+        error_body = getattr(e, "body", None)
+        raise pytest.fail(f"failed to create destination: {e} {error_body}")
+
+    assert uploader.index_exists(index_name=index_name), "destination was not created successfully"
+
+    try:
+        pc = uploader.connection_config.get_client()
+        logger.info(f"deleting index for test create destination: {index_name}")
+        pc.delete_index(name=index_name)
+    except Exception as e:
+        raise pytest.fail(f"failed to cleanup / delete the destination: {e}")
+
+    assert not uploader.index_exists(index_name=index_name), "cleanup failed"

unstructured_ingest/__version__.py
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "0.5.6-dev0"  # pragma: no cover
+__version__ = "0.5.6"  # pragma: no cover

unstructured_ingest/v2/interfaces/process.py
Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ class BaseProcess(ABC):
     def is_async(self) -> bool:
         return False
 
-    def init(self, *kwargs: Any) -> None:
+    def init(self, **kwargs: Any) -> None:
         pass
 
     def precheck(self) -> None:
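
The hunk above fixes a subtle signature bug: `*kwargs` packs *positional* arguments into a tuple that merely happens to be named `kwargs`, so keyword arguments passed by the pipeline would raise a TypeError. A minimal illustration of the difference (not part of the commit):

    def init_broken(*kwargs):
        # Despite the name, this packs positional args into a tuple.
        print(kwargs)

    def init_fixed(**kwargs):
        # This packs keyword args into a dict, as the pipeline expects.
        print(kwargs)

    init_fixed(vector_length=1536)   # -> {'vector_length': 1536}
    init_broken(vector_length=1536)  # -> TypeError: unexpected keyword argument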

unstructured_ingest/v2/interfaces/uploader.py
Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 from abc import ABC
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional, TypeVar
+from typing import Any, TypeVar
 
 from pydantic import BaseModel
 
@@ -61,6 +61,6 @@ async def run_data_async(self, data: list[dict], file_data: FileData, **kwargs:
 @dataclass
 class VectorDBUploader(Uploader, ABC):
     def create_destination(
-        self, destination_name: str = "elements", vector_length: Optional[int] = None, **kwargs: Any
+        self, vector_length: int, destination_name: str = "elements", **kwargs: Any
     ) -> bool:
         return False
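
With this change, `vector_length` is required rather than `Optional`: an uploader that provisions an index must be told the embedding dimension up front. A minimal sketch of the new contract, using a hypothetical stand-in class (not the library's real type):

    from typing import Any

    class DemoUploader:
        # Stand-in mirroring the VectorDBUploader.create_destination signature above.
        def create_destination(
            self, vector_length: int, destination_name: str = "elements", **kwargs: Any
        ) -> bool:
            print(f"would create {destination_name!r} with dimension {vector_length}")
            return True

    DemoUploader().create_destination(vector_length=1536)  # destination_name defaults to "elements"
    # DemoUploader().create_destination() would raise TypeError: vector_length is required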

unstructured_ingest/v2/pipeline/pipeline.py
Lines changed: 26 additions & 31 deletions

@@ -126,14 +126,32 @@ def log_statuses(self):
             for kk, vv in v.items():
                 logger.error(f"{k}: [{kk}] {vv}")
 
+    def _run_initialization(self):
+        failures = {}
+        init_kwargs = {}
+        for step in self._get_ordered_steps():
+            try:
+                step.process.init(**init_kwargs)
+                step.process.precheck()
+                # Make sure embedder dimensions available for downstream steps
+                if isinstance(step.process, Embedder):
+                    embed_dimensions = step.process.config.get_embedder().dimension
+                    init_kwargs["vector_length"] = embed_dimensions
+
+            except Exception as e:
+                failures[step.process.__class__.__name__] = f"[{type(e).__name__}] {e}"
+        if failures:
+            for k, v in failures.items():
+                logger.error(f"Step initialization failure: {k}: {v}")
+            raise PipelineError("Initialization failed")
+
     def run(self):
         otel_handler = OtelHandler(otel_endpoint=self.context.otel_endpoint, log_out=logger.info)
         try:
             with otel_handler.get_tracer().start_as_current_span(
                 "ingest process", record_exception=True
             ):
-                self._run_inits()
-                self._run_prechecks()
+                self._run_initialization()
                 self._run()
         finally:
             self.log_statuses()

@@ -154,43 +172,20 @@ def clean_results(self, results: list[Any | list[Any]] | None) -> list[Any] | None:
         final = [f for f in flat if f]
         return final or None
 
-    def _get_all_steps(self) -> list[PipelineStep]:
-        steps = [self.indexer_step, self.downloader_step, self.partitioner_step, self.uploader_step]
+    def _get_ordered_steps(self) -> list[PipelineStep]:
+        steps = [self.indexer_step, self.downloader_step]
+        if self.uncompress_step:
+            steps.append(self.uncompress_step)
+        steps.append(self.partitioner_step)
         if self.chunker_step:
             steps.append(self.chunker_step)
         if self.embedder_step:
             steps.append(self.embedder_step)
-        if self.uncompress_step:
-            steps.append(self.uncompress_step)
         if self.stager_step:
             steps.append(self.stager_step)
+        steps.append(self.uploader_step)
         return steps
 
-    def _run_inits(self):
-        failures = {}
-
-        for step in self._get_all_steps():
-            try:
-                step.process.init()
-            except Exception as e:
-                failures[step.process.__class__.__name__] = f"[{type(e).__name__}] {e}"
-        if failures:
-            for k, v in failures.items():
-                logger.error(f"Step init failure: {k}: {v}")
-            raise PipelineError("Init failed")
-
-    def _run_prechecks(self):
-        failures = {}
-        for step in self._get_all_steps():
-            try:
-                step.process.precheck()
-            except Exception as e:
-                failures[step.process.__class__.__name__] = f"[{type(e).__name__}] {e}"
-        if failures:
-            for k, v in failures.items():
-                logger.error(f"Step precheck failure: {k}: {v}")
-            raise PipelineError("Precheck failed")
-
     def apply_filter(self, records: list[dict]) -> list[dict]:
         if not self.filter_step:
             return records
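
The net effect of `_run_initialization` and `_get_ordered_steps` is that steps are now initialized strictly in pipeline order and can hand data forward: the embedder contributes its vector dimension, which later steps (such as a vector DB uploader) receive as `vector_length`. A self-contained sketch of that flow, using stand-in classes rather than the library's real types:

    class StubEmbedder:
        dimension = 1536  # stand-in for config.get_embedder().dimension

        def init(self, **kwargs):
            pass

    class StubUploader:
        def init(self, vector_length=None, **kwargs):
            # a VectorDBUploader would call create_destination(vector_length=...)
            print(f"init uploader with vector_length={vector_length}")

    init_kwargs = {}
    for step in [StubEmbedder(), StubUploader()]:
        step.init(**init_kwargs)
        if isinstance(step, StubEmbedder):
            init_kwargs["vector_length"] = step.dimension

    # Output: init uploader with vector_length=1536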

unstructured_ingest/v2/processes/connectors/pinecone.py
Lines changed: 93 additions & 5 deletions

@@ -1,6 +1,7 @@
 import json
+import re
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 from pydantic import Field, Secret
 
@@ -13,10 +14,10 @@
     AccessConfig,
     ConnectionConfig,
     FileData,
-    Uploader,
     UploaderConfig,
     UploadStager,
     UploadStagerConfig,
+    VectorDBUploader,
 )
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry

@@ -41,7 +42,7 @@ class PineconeAccessConfig(AccessConfig):
 
 
 class PineconeConnectionConfig(ConnectionConfig):
-    index_name: str = Field(description="Name of the index to connect to.")
+    index_name: Optional[str] = Field(description="Name of the index to connect to.", default=None)
     access_config: Secret[PineconeAccessConfig] = Field(
         default=PineconeAccessConfig(), validate_default=True
     )

@@ -160,18 +161,101 @@ def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
 
 
 @dataclass
-class PineconeUploader(Uploader):
+class PineconeUploader(VectorDBUploader):
     upload_config: PineconeUploaderConfig
     connection_config: PineconeConnectionConfig
     connector_type: str = CONNECTOR_TYPE
 
+    def init(self, **kwargs: Any) -> None:
+        self.create_destination(**kwargs)
+
+    def index_exists(self, index_name: Optional[str]) -> bool:
+        from pinecone.exceptions import NotFoundException
+
+        index_name = index_name or self.connection_config.index_name
+        pc = self.connection_config.get_client()
+        try:
+            pc.describe_index(index_name)
+            return True
+        except NotFoundException:
+            return False
+        except Exception as e:
+            logger.error(f"failed to check if pinecone index exists : {e}")
+            raise DestinationConnectionError(f"failed to check if pinecone index exists : {e}")
+
     def precheck(self):
         try:
-            self.connection_config.get_index()
+            # just a connection check here. not an actual index_exists check
+            self.index_exists("just-checking-our-connection")
+
+            if self.connection_config.index_name and not self.index_exists(
+                self.connection_config.index_name
+            ):
+                raise DestinationConnectionError(
+                    f"index {self.connection_config.index_name} does not exist"
+                )
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
 
+    def format_destination_name(self, destination_name: str) -> str:
+        # Pinecone naming requirements:
+        # can only contain lowercase letters, numbers, and hyphens
+        # must be 45 characters or less
+        formatted = re.sub(r"[^a-z0-9]", "-", destination_name.lower())
+        return formatted
+
+    def create_destination(
+        self,
+        vector_length: int,
+        destination_name: str = "elements",
+        destination_type: Literal["pod", "serverless"] = "serverless",
+        serverless_cloud: str = "aws",
+        serverless_region: str = "us-west-2",
+        pod_environment: str = "us-east1-gcp",
+        pod_type: str = "p1.x1",
+        pod_count: int = 1,
+        **kwargs: Any,
+    ) -> bool:
+        from pinecone import PodSpec, ServerlessSpec
+
+        index_name = destination_name or self.connection_config.index_name
+        index_name = self.format_destination_name(index_name)
+        self.connection_config.index_name = index_name
+
+        if not self.index_exists(index_name):
+
+            logger.info(f"creating pinecone index {index_name}")
+
+            pc = self.connection_config.get_client()
+
+            if destination_type == "serverless":
+                pc.create_index(
+                    name=destination_name,
+                    dimension=vector_length,
+                    spec=ServerlessSpec(cloud=serverless_cloud, region=serverless_region),
+                    **kwargs,
+                )
+
+                return True
+
+            elif destination_type == "pod":
+                pc.create_index(
+                    name=destination_name,
+                    dimension=vector_length,
+                    spec=PodSpec(environment=pod_environment, pod_type=pod_type, pods=pod_count),
+                    **kwargs,
+                )
+
+                return True
+
+            else:
+                raise ValueError(f"unexpected destination type: {destination_type}")
+
+        else:
+            logger.debug(f"index {index_name} already exists, skipping creation")
+            return False
+
     def pod_delete_by_record_id(self, file_data: FileData) -> None:
         logger.debug(
             f"deleting any content with metadata "

@@ -266,6 +350,10 @@ def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None
         )
         # Determine if serverless or pod based index
         pinecone_client = self.connection_config.get_client()
+
+        if not self.connection_config.index_name:
+            raise ValueError("No index name specified")
+
         index_description = pinecone_client.describe_index(name=self.connection_config.index_name)
         if "serverless" in index_description.get("spec"):
             self.serverless_delete_by_record_id(file_data=file_data)
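
Taken together, the uploader can now provision its own index before uploading. A minimal usage sketch, assuming a valid API key in a `PINECONE_API_KEY` environment variable and the default serverless spec shown above (illustrative, not an official snippet):

    import os

    from unstructured_ingest.v2.processes.connectors.pinecone import (
        PineconeAccessConfig,
        PineconeConnectionConfig,
        PineconeUploader,
        PineconeUploaderConfig,
    )

    uploader = PineconeUploader(
        connection_config=PineconeConnectionConfig(
            access_config=PineconeAccessConfig(api_key=os.environ["PINECONE_API_KEY"])
        ),
        upload_config=PineconeUploaderConfig(),
    )

    # The name is normalized by format_destination_name (lowercase letters, digits,
    # hyphens); vector_length would typically come from the embedder's dimension.
    created = uploader.create_destination(vector_length=1536, destination_name="my-index")
    print("created" if created else "already existed")

Note that `create_destination` returns False when the index already exists, and `init` simply forwards its kwargs to it, which is how the pipeline-supplied `vector_length` reaches Pinecone.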

unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py
Lines changed: 2 additions & 2 deletions

@@ -230,8 +230,8 @@ def precheck(self) -> None:
             logger.error(f"Failed to validate connection {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
 
-    def init(self, *kwargs: Any) -> None:
-        self.create_destination()
+    def init(self, **kwargs: Any) -> None:
+        self.create_destination(**kwargs)
 
     def format_destination_name(self, destination_name: str) -> str:
         # Weaviate naming requirements:

unstructured_ingest/v2/processes/embedder.py
Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ def get_embedder(self) -> "BaseEmbeddingEncoder":
 class Embedder(BaseProcess, ABC):
     config: EmbedderConfig
 
-    def init(self, *kwargs: Any) -> None:
+    def init(self, **kwargs: Any) -> None:
         self.config.get_embedder().initialize()
 
     def run(self, elements_filepath: Path, **kwargs: Any) -> list[dict]:
