Skip to content

Commit ae7bd87

Browse files
migrate to new mixedbread python sdk (#492)
* migrate to new mixedbread python sdk
* fix lint
* update changelog
* update mixedbread unit test
* update dependency in pyproject
* fix mixedbread unit test
1 parent 4ec3834 commit ae7bd87

File tree

7 files changed

+558
-548
lines changed

7 files changed

+558
-548
lines changed

CHANGELOG.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
## 1.0.22-dev1
2+
3+
* **Migrate to new Mixedbread Python SDK**
4+
15
## 1.0.22
26

37
* **Fix Notion connector missing database properties fields**
48

59
## 1.0.21
610

711
* **Fix Jira connector cloud option not working issue**
8-
9-
## 1.0.20
10-
1112
* **Fix Weaviate connector issue with names being wrongly transformed to match collections naming conventions**
1213

1314
## 1.0.19
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
-c ../common/constraints.txt
22

3-
mixedbread-ai
3+
mixedbread

test/integration/embedders/test_mixedbread.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def get_api_key() -> str:
2727

2828

2929
@requires_env(API_KEY)
30-
@pytest.mark.skip(reason="need to migrate to new mixedbread sdk.")
3130
def test_mixedbread_embedder(embedder_file: Path):
3231
api_key = get_api_key()
3332
embedder_config = EmbedderConfig(embedding_provider="mixedbread-ai", embedding_api_key=api_key)
@@ -41,7 +40,6 @@ def test_mixedbread_embedder(embedder_file: Path):
4140

4241

4342
@requires_env(API_KEY)
44-
@pytest.mark.skip(reason="need to migrate to new mixedbread sdk.")
4543
def test_raw_mixedbread_embedder(embedder_file: Path):
4644
api_key = get_api_key()
4745
embedder = MixedbreadAIEmbeddingEncoder(
@@ -60,7 +58,6 @@ def test_raw_mixedbread_embedder(embedder_file: Path):
6058

6159
@requires_env(API_KEY)
6260
@pytest.mark.asyncio
63-
@pytest.mark.skip(reason="need to migrate to new mixedbread sdk.")
6461
async def test_raw_async_mixedbread_embedder(embedder_file: Path):
6562
api_key = get_api_key()
6663
embedder = AsyncMixedbreadAIEmbeddingEncoder(

test/unit/embed/test_mixedbreadai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,23 @@
77
def test_embed_documents_does_not_break_element_to_dict(mocker):
88
mock_client = mocker.MagicMock()
99

10-
def mock_embeddings(
10+
def mock_embed(
1111
model,
12+
input,
1213
normalized,
1314
encoding_format,
14-
truncation_strategy,
15-
request_options,
16-
input,
15+
extra_headers,
16+
timeout,
1717
):
1818
mock_response = mocker.MagicMock()
1919
mock_response.data = [mocker.MagicMock(embedding=[i, i + 1]) for i in range(len(input))]
2020
return mock_response
2121

22-
mock_client.embeddings.side_effect = mock_embeddings
22+
mock_client.embed.side_effect = mock_embed
2323

2424
# Mock get_client to return our mock_client
2525
mocker.patch.object(MixedbreadAIEmbeddingConfig, "get_client", return_value=mock_client)
26-
mocker.patch.object(MixedbreadAIEmbeddingEncoder, "get_request_options", return_value={})
26+
mocker.patch.object(MixedbreadAIEmbeddingEncoder, "get_client", return_value=mock_client)
2727

2828
encoder = MixedbreadAIEmbeddingEncoder(
2929
config=MixedbreadAIEmbeddingConfig(

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.22" # pragma: no cover
1+
__version__ = "1.0.22-dev1" # pragma: no cover

unstructured_ingest/embed/mixedbreadai.py

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020

2121
if TYPE_CHECKING:
22-
from mixedbread_ai.client import AsyncMixedbreadAI, MixedbreadAI
23-
from mixedbread_ai.core import RequestOptions
22+
from mixedbread import AsyncMixedbread, Mixedbread
2423

2524

2625
class MixedbreadAIEmbeddingConfig(EmbeddingConfig):
@@ -44,31 +43,33 @@ class MixedbreadAIEmbeddingConfig(EmbeddingConfig):
4443
)
4544

4645
@requires_dependencies(
47-
["mixedbread_ai"],
48-
extras="mixedbreadai",
46+
["mixedbread"],
47+
extras="embed-mixedbreadai",
4948
)
50-
def get_client(self) -> "MixedbreadAI":
49+
def get_client(self) -> "Mixedbread":
5150
"""
5251
Create the Mixedbread AI client.
5352
5453
Returns:
55-
MixedbreadAI: Initialized client.
54+
Mixedbread: Initialized client.
5655
"""
57-
from mixedbread_ai.client import MixedbreadAI
56+
from mixedbread import Mixedbread
5857

59-
return MixedbreadAI(
58+
return Mixedbread(
6059
api_key=self.api_key.get_secret_value(),
60+
max_retries=MAX_RETRIES,
6161
)
6262

6363
@requires_dependencies(
64-
["mixedbread_ai"],
65-
extras="mixedbreadai",
64+
["mixedbread"],
65+
extras="embed-mixedbreadai",
6666
)
67-
def get_async_client(self) -> "AsyncMixedbreadAI":
68-
from mixedbread_ai.client import AsyncMixedbreadAI
67+
def get_async_client(self) -> "AsyncMixedbread":
68+
from mixedbread import AsyncMixedbread
6969

70-
return AsyncMixedbreadAI(
70+
return AsyncMixedbread(
7171
api_key=self.api_key.get_secret_value(),
72+
max_retries=MAX_RETRIES,
7273
)
7374

7475

@@ -88,29 +89,20 @@ def get_exemplary_embedding(self) -> list[float]:
8889
return self.embed_query(query="Q")
8990

9091
@requires_dependencies(
91-
["mixedbread_ai"],
92+
["mixedbread"],
9293
extras="embed-mixedbreadai",
9394
)
94-
def get_request_options(self) -> "RequestOptions":
95-
from mixedbread_ai.core import RequestOptions
96-
97-
return RequestOptions(
98-
max_retries=MAX_RETRIES,
99-
timeout_in_seconds=TIMEOUT,
100-
additional_headers={"User-Agent": USER_AGENT},
101-
)
102-
103-
def get_client(self) -> "MixedbreadAI":
95+
def get_client(self) -> "Mixedbread":
10496
return self.config.get_client()
10597

106-
def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
107-
response = client.embeddings(
98+
def embed_batch(self, client: "Mixedbread", batch: list[str]) -> list[list[float]]:
99+
response = client.embed(
108100
model=self.config.embedder_model_name,
101+
input=batch,
109102
normalized=True,
110103
encoding_format=ENCODING_FORMAT,
111-
truncation_strategy=TRUNCATION_STRATEGY,
112-
request_options=self.get_request_options(),
113-
input=batch,
104+
extra_headers={"User-Agent": USER_AGENT},
105+
timeout=TIMEOUT,
114106
)
115107
return [datum.embedding for datum in response.data]
116108

@@ -124,28 +116,19 @@ async def get_exemplary_embedding(self) -> list[float]:
124116
return await self.embed_query(query="Q")
125117

126118
@requires_dependencies(
127-
["mixedbread_ai"],
119+
["mixedbread"],
128120
extras="embed-mixedbreadai",
129121
)
130-
def get_request_options(self) -> "RequestOptions":
131-
from mixedbread_ai.core import RequestOptions
132-
133-
return RequestOptions(
134-
max_retries=MAX_RETRIES,
135-
timeout_in_seconds=TIMEOUT,
136-
additional_headers={"User-Agent": USER_AGENT},
137-
)
138-
139-
def get_client(self) -> "AsyncMixedbreadAI":
122+
def get_client(self) -> "AsyncMixedbread":
140123
return self.config.get_async_client()
141124

142-
async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
143-
response = await client.embeddings(
125+
async def embed_batch(self, client: "AsyncMixedbread", batch: list[str]) -> list[list[float]]:
126+
response = await client.embed(
144127
model=self.config.embedder_model_name,
128+
input=batch,
145129
normalized=True,
146130
encoding_format=ENCODING_FORMAT,
147-
truncation_strategy=TRUNCATION_STRATEGY,
148-
request_options=self.get_request_options(),
149-
input=batch,
131+
extra_headers={"User-Agent": USER_AGENT},
132+
timeout=TIMEOUT,
150133
)
151134
return [datum.embedding for datum in response.data]

0 commit comments

Comments
 (0)