Skip to content

Commit b3add2a

Browse files
rbiseck3badGarnet
andauthored
feat/support passing in metadata to write to pinecone payload (#156)
* support passing in metadata to write to pinecone payload * add test * lint * lint * tidy * fix: force help text to be str * fix help text * pass in str of tuple into field constructor * skip octoai test due to service not availaible --------- Co-authored-by: Yao You <[email protected]>
1 parent 277c993 commit b3add2a

File tree

6 files changed

+88
-26
lines changed

6 files changed

+88
-26
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## 0.0.24
2+
3+
### Enhancements
4+
5+
* **Support dynamic metadata mapping in Pinecone uploader**
6+
17
## 0.0.23
28

39
### Fixes
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import pytest
2+
3+
from unstructured_ingest.v2.processes.connectors.pinecone import (
4+
PineconeUploadStager,
5+
PineconeUploadStagerConfig,
6+
)
7+
8+
9+
@pytest.fixture
10+
def test_element_dict():
11+
return {
12+
"embeddings": [0, 1],
13+
"text": "test dict",
14+
"metadata": {
15+
"text_as_html": "text as html",
16+
"foo": "foo",
17+
},
18+
}
19+
20+
21+
@pytest.mark.parametrize(
22+
("metadata_fields", "expected_to_exist", "not_expected_to_exist"),
23+
[
24+
(None, ["text_as_html"], ["foo"]),
25+
(("foo",), ["foo"], ["text_as_html"]),
26+
],
27+
)
28+
def test_conform_dict(
29+
monkeypatch, test_element_dict, metadata_fields, expected_to_exist, not_expected_to_exist
30+
):
31+
if metadata_fields is not None:
32+
stager = PineconeUploadStager(
33+
upload_stager_config=PineconeUploadStagerConfig(metadata_fields=metadata_fields)
34+
)
35+
else:
36+
stager = PineconeUploadStager()
37+
results = stager.conform_dict(test_element_dict.copy())
38+
results.pop("id")
39+
assert test_element_dict["embeddings"] == results.pop("values")
40+
41+
assert all(
42+
results["metadata"][key] == test_element_dict["metadata"][key] for key in expected_to_exist
43+
)
44+
assert all(key not in results["metadata"] for key in not_expected_to_exist)

test_e2e/test-src.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ all_tests=(
6363
'hubspot.sh'
6464
'local-embed.sh'
6565
'local-embed-bedrock.sh'
66-
'local-embed-octoai.sh'
66+
# NOTE (yao): octoai url is giving 404
67+
# 'local-embed-octoai.sh'
6768
'local-embed-vertexai.sh'
6869
'local-embed-voyageai.sh'
6970
'local-embed-mixedbreadai.sh'

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.23" # pragma: no cover
1+
__version__ = "0.0.24" # pragma: no cover

unstructured_ingest/v2/cli/utils/model_conversion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ def _get_type_from_field(field: FieldInfo) -> click.ParamType:
155155

156156
def get_option_from_field(option_name: str, field_info: FieldInfo) -> Option:
157157
param_decls = [option_name]
158-
help = field_info.description or ""
158+
help_text = field_info.description or ""
159159
if examples := field_info.examples:
160-
help += f" [Examples: {', '.join(examples)}]"
160+
help_text += f" [Examples: {', '.join(examples)}]"
161161
option_kwargs = {
162162
"type": _get_type_from_field(field_info),
163163
"default": get_default_value_from_field(field_info),
164164
"required": field_info.is_required(),
165-
"help": help,
165+
"help": str(help_text),
166166
"is_flag": is_boolean_flag(field_info),
167167
"show_default": field_info.default is not PydanticUndefined,
168168
}

unstructured_ingest/v2/processes/connectors/pinecone.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,6 @@ def get_index(self, **index_kwargs) -> "PineconeIndex":
5858
return index
5959

6060

61-
class PineconeUploadStagerConfig(UploadStagerConfig):
62-
pass
63-
64-
65-
class PineconeUploaderConfig(UploaderConfig):
66-
batch_size: Optional[int] = Field(
67-
default=None,
68-
description="Optional number of records per batch. Will otherwise limit by size.",
69-
)
70-
pool_threads: Optional[int] = Field(
71-
default=1, description="Optional limit on number of threads to use for upload"
72-
)
73-
74-
7561
ALLOWED_FIELDS = (
7662
"element_id",
7763
"text",
@@ -86,31 +72,56 @@ class PineconeUploaderConfig(UploaderConfig):
8672
"is_continuation",
8773
"link_urls",
8874
"link_texts",
75+
"text_as_html",
8976
)
9077

9178

79+
class PineconeUploadStagerConfig(UploadStagerConfig):
80+
metadata_fields: list[str] = Field(
81+
default=str(ALLOWED_FIELDS),
82+
description=(
83+
"which metadata from the source element to map to the payload metadata being sent to "
84+
"Pinecone."
85+
),
86+
)
87+
88+
89+
class PineconeUploaderConfig(UploaderConfig):
90+
batch_size: Optional[int] = Field(
91+
default=None,
92+
description="Optional number of records per batch. Will otherwise limit by size.",
93+
)
94+
pool_threads: Optional[int] = Field(
95+
default=1, description="Optional limit on number of threads to use for upload"
96+
)
97+
98+
9299
@dataclass
93100
class PineconeUploadStager(UploadStager):
94101
upload_stager_config: PineconeUploadStagerConfig = field(
95102
default_factory=lambda: PineconeUploadStagerConfig()
96103
)
97104

98-
@staticmethod
99-
def conform_dict(element_dict: dict) -> dict:
105+
def conform_dict(self, element_dict: dict) -> dict:
100106
embeddings = element_dict.pop("embeddings", None)
101107
metadata: dict[str, Any] = element_dict.pop("metadata", {})
102108
data_source = metadata.pop("data_source", {})
103109
coordinates = metadata.pop("coordinates", {})
104-
105-
element_dict.update(metadata)
106-
element_dict.update(data_source)
107-
element_dict.update(coordinates)
110+
pinecone_metadata = {}
111+
for possible_meta in [element_dict, metadata, data_source, coordinates]:
112+
pinecone_metadata.update(
113+
{
114+
k: v
115+
for k, v in possible_meta.items()
116+
if k in self.upload_stager_config.metadata_fields
117+
}
118+
)
108119

109120
return {
110121
"id": str(uuid.uuid4()),
111122
"values": embeddings,
112123
"metadata": flatten_dict(
113-
{k: v for k, v in element_dict.items() if k in ALLOWED_FIELDS},
124+
pinecone_metadata,
114125
separator="-",
115126
flatten_lists=True,
116127
remove_none=True,

0 commit comments

Comments
 (0)