Skip to content

Commit ce1b1a5

Browse files
add embed function for voyageai multimodal embedder (#393)
* add embed function for voyageai multimodal embedder * update changelog and version * fix mixedbread ai embed unit test
1 parent 0d37556 commit ce1b1a5

File tree

6 files changed

+39
-7
lines changed

6 files changed

+39
-7
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## 0.5.6-dev1
2+
3+
### Fixes
4+
5+
* **Fix voyageai embedder: add multimodal embedder function**
6+
17
## 0.5.6
28

39
### Enhancements

test/integration/embedders/test_mixedbread.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_raw_mixedbread_embedder(embedder_file: Path):
5050
embedder=embedder,
5151
embedder_file=embedder_file,
5252
expected_dimension=1024,
53-
expected_is_unit_vector=False,
53+
expected_is_unit_vector=True,
5454
)
5555

5656

@@ -67,5 +67,5 @@ async def test_raw_async_mixedbread_embedder(embedder_file: Path):
6767
embedder=embedder,
6868
embedder_file=embedder_file,
6969
expected_dimension=1024,
70-
expected_is_unit_vector=False,
70+
expected_is_unit_vector=True,
7171
)

test/integration/embedders/test_voyageai.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,19 @@ async def test_raw_async_voyageai_embedder(embedder_file: Path):
6161
await validate_raw_embedder_async(
6262
embedder=embedder, embedder_file=embedder_file, expected_dimension=1024
6363
)
64+
65+
66+
@requires_env(API_KEY)
67+
def test_voyageai_multimodal_embedder(embedder_file: Path):
68+
api_key = get_api_key()
69+
embedder_config = EmbedderConfig(
70+
embedding_provider="voyageai",
71+
embedding_api_key=api_key,
72+
embedding_model_name="voyage-multimodal-3",
73+
)
74+
embedder = Embedder(config=embedder_config)
75+
results = embedder.run(elements_filepath=embedder_file)
76+
assert results
77+
with embedder_file.open("r") as f:
78+
original_elements = json.load(f)
79+
validate_embedding_output(original_elements=original_elements, output_elements=results)

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.6" # pragma: no cover
1+
__version__ = "0.5.6-dev1" # pragma: no cover

unstructured_ingest/embed/interfaces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def get_exemplary_embedding(self) -> list[float]:
4949
def is_unit_vector(self) -> bool:
5050
"""Denotes if the embedding vector is a unit vector."""
5151
exemplary_embedding = self.get_exemplary_embedding()
52-
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
52+
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0, rtol=1e-03)
5353

5454
def get_client(self):
5555
raise NotImplementedError
@@ -103,7 +103,7 @@ async def get_exemplary_embedding(self) -> list[float]:
103103
async def is_unit_vector(self) -> bool:
104104
"""Denotes if the embedding vector is a unit vector."""
105105
exemplary_embedding = await self.get_exemplary_embedding()
106-
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
106+
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0, rtol=1e-03)
107107

108108
def get_client(self):
109109
raise NotImplementedError

unstructured_ingest/embed/voyageai.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ def get_client(self) -> "VoyageAIClient":
9696
return self.config.get_client()
9797

9898
def embed_batch(self, client: "VoyageAIClient", batch: list[str]) -> list[list[float]]:
99-
response = client.embed(texts=batch, model=self.config.embedder_model_name)
99+
if self.config.embedder_model_name == "voyage-multimodal-3":
100+
batch = [[text] for text in batch]
101+
response = client.multimodal_embed(inputs=batch, model=self.config.embedder_model_name)
102+
else:
103+
response = client.embed(texts=batch, model=self.config.embedder_model_name)
100104
return response.embeddings
101105

102106

@@ -113,5 +117,11 @@ def get_client(self) -> "AsyncVoyageAIClient":
113117
async def embed_batch(
114118
self, client: "AsyncVoyageAIClient", batch: list[str]
115119
) -> list[list[float]]:
116-
response = await client.embed(texts=batch, model=self.config.embedder_model_name)
120+
if self.config.embedder_model_name == "voyage-multimodal-3":
121+
batch = [[text] for text in batch]
122+
response = await client.multimodal_embed(
123+
inputs=batch, model=self.config.embedder_model_name
124+
)
125+
else:
126+
response = await client.embed(texts=batch, model=self.config.embedder_model_name)
117127
return response.embeddings

0 commit comments

Comments
 (0)