Skip to content

Commit f862081

Browse files
authored
feat/add precheck support for embedders (#473)
* add precheck support for embedders
* update tests
* fix bedrock
* fix syntax
1 parent ede5e30 commit f862081

22 files changed

+214
-96
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## 1.0.14
2+
3+
### Enhancements
4+
5+
* **Add precheck support for embedders that support listing models**
6+
17
## 1.0.13
28

39
### Fixes

test/integration/connectors/test_google_drive.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ def google_drive_empty_folder(google_drive_connection_config):
5353
from googleapiclient.discovery import build
5454

5555
access_config = google_drive_connection_config.access_config.get_secret_value()
56-
creds = service_account.Credentials.from_service_account_info(
57-
access_config.service_account_key
58-
)
56+
creds = service_account.Credentials.from_service_account_info(access_config.service_account_key)
5957
service = build("drive", "v3", credentials=creds)
6058

6159
# Create an empty folder.
@@ -120,23 +118,16 @@ def test_google_drive_precheck_invalid_parameter(google_drive_connection_config)
120118
access_config=google_drive_connection_config.access_config,
121119
)
122120
index_config = GoogleDriveIndexerConfig(recursive=True)
123-
indexer = GoogleDriveIndexer(
124-
connection_config=connection_config, index_config=index_config
125-
)
121+
indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
126122
with pytest.raises(SourceConnectionError) as excinfo:
127123
indexer.precheck()
128-
assert (
129-
"invalid" in str(excinfo.value).lower()
130-
or "not found" in str(excinfo.value).lower()
131-
)
124+
assert "invalid" in str(excinfo.value).lower() or "not found" in str(excinfo.value).lower()
132125

133126

134127
# Precheck fails due to lack of permission (simulate via monkeypatching).
135128
@pytest.mark.tags("google-drive", "precheck")
136129
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
137-
def test_google_drive_precheck_no_permission(
138-
google_drive_connection_config, monkeypatch
139-
):
130+
def test_google_drive_precheck_no_permission(google_drive_connection_config, monkeypatch):
140131
index_config = GoogleDriveIndexerConfig(recursive=True)
141132
indexer = GoogleDriveIndexer(
142133
connection_config=google_drive_connection_config,
@@ -153,10 +144,7 @@ def fake_get_root_info(files_client, object_id):
153144
monkeypatch.setattr(indexer, "get_root_info", fake_get_root_info)
154145
with pytest.raises(SourceConnectionError) as excinfo:
155146
indexer.precheck()
156-
assert (
157-
"forbidden" in str(excinfo.value).lower()
158-
or "permission" in str(excinfo.value).lower()
159-
)
147+
assert "forbidden" in str(excinfo.value).lower() or "permission" in str(excinfo.value).lower()
160148

161149

162150
# Precheck fails when the folder is empty.
@@ -206,15 +194,10 @@ def test_google_drive_precheck_invalid_drive_id(google_drive_connection_config):
206194
access_config=google_drive_connection_config.access_config,
207195
)
208196
index_config = GoogleDriveIndexerConfig(recursive=True)
209-
indexer = GoogleDriveIndexer(
210-
connection_config=connection_config, index_config=index_config
211-
)
197+
indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
212198
with pytest.raises(SourceConnectionError) as excinfo:
213199
indexer.precheck()
214-
assert (
215-
"invalid" in str(excinfo.value).lower()
216-
or "not found" in str(excinfo.value).lower()
217-
)
200+
assert "invalid" in str(excinfo.value).lower() or "not found" in str(excinfo.value).lower()
218201

219202

220203
MIME_TYPES_TO_TEST = [
@@ -244,19 +227,15 @@ async def test_google_drive_export_by_type(expected_mime, temp_dir):
244227
index_config = GoogleDriveIndexerConfig(recursive=True)
245228
download_config = GoogleDriveDownloaderConfig(download_dir=temp_dir)
246229

247-
indexer = GoogleDriveIndexer(
248-
connection_config=connection_config, index_config=index_config
249-
)
230+
indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
250231
downloader = GoogleDriveDownloader(
251232
connection_config=connection_config, download_config=download_config
252233
)
253234

254235
file_datas = list(indexer.run())
255236

256237
# Filter only the target MIME type
257-
target_files = [
258-
f for f in file_datas if f.additional_metadata.get("mimeType") == expected_mime
259-
]
238+
target_files = [f for f in file_datas if f.additional_metadata.get("mimeType") == expected_mime]
260239
assert target_files, f"No files found with MIME type: {expected_mime}"
261240

262241
for file_data in target_files:

test/integration/embedders/test_azure_openai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_azure_openai_embedder(embedder_file: Path):
3838
embedding_azure_endpoint=azure_data.endpoint,
3939
)
4040
embedder = Embedder(config=embedder_config)
41+
embedder.precheck()
4142
results = embedder.run(elements_filepath=embedder_file)
4243
assert results
4344
with embedder_file.open("r") as f:
@@ -54,4 +55,5 @@ def test_raw_azure_openai_embedder(embedder_file: Path):
5455
azure_endpoint=azure_data.endpoint,
5556
)
5657
)
58+
embedder.precheck()
5759
validate_raw_embedder(embedder=embedder, embedder_file=embedder_file, expected_dimension=1536)

test/integration/embedders/test_bedrock.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_bedrock_embedder(embedder_file: Path):
3636
embedding_aws_secret_access_key=aws_credentials["aws_secret_access_key"],
3737
)
3838
embedder = Embedder(config=embedder_config)
39+
embedder.precheck()
3940
results = embedder.run(elements_filepath=embedder_file)
4041
assert results
4142
with embedder_file.open("r") as f:
@@ -52,6 +53,7 @@ def test_raw_bedrock_embedder(embedder_file: Path):
5253
aws_secret_access_key=aws_credentials["aws_secret_access_key"],
5354
)
5455
)
56+
embedder.precheck()
5557
validate_raw_embedder(
5658
embedder=embedder,
5759
embedder_file=embedder_file,
@@ -82,7 +84,7 @@ def test_raw_bedrock_embedder_invalid_model(embedder_file: Path):
8284
)
8385
)
8486
with pytest.raises(UserError):
85-
embedder.get_exemplary_embedding()
87+
embedder.precheck()
8688

8789

8890
@requires_env("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
@@ -95,6 +97,7 @@ async def test_raw_async_bedrock_embedder(embedder_file: Path):
9597
aws_secret_access_key=aws_credentials["aws_secret_access_key"],
9698
)
9799
)
100+
embedder.precheck()
98101
await validate_raw_embedder_async(
99102
embedder=embedder,
100103
embedder_file=embedder_file,

test/integration/embedders/test_huggingface.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
def test_huggingface_embedder(embedder_file: Path):
1313
embedder_config = EmbedderConfig(embedding_provider="huggingface")
1414
embedder = Embedder(config=embedder_config)
15+
embedder.precheck()
1516
results = embedder.run(elements_filepath=embedder_file)
1617
assert results
1718
with embedder_file.open("r") as f:
@@ -21,4 +22,5 @@ def test_huggingface_embedder(embedder_file: Path):
2122

2223
def test_raw_hugginface_embedder(embedder_file: Path):
2324
embedder = HuggingFaceEmbeddingEncoder(config=HuggingFaceEmbeddingConfig())
25+
embedder.precheck()
2426
validate_raw_embedder(embedder=embedder, embedder_file=embedder_file, expected_dimension=384)

test/integration/embedders/test_mixedbread.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_mixedbread_embedder(embedder_file: Path):
3131
api_key = get_api_key()
3232
embedder_config = EmbedderConfig(embedding_provider="mixedbread-ai", embedding_api_key=api_key)
3333
embedder = Embedder(config=embedder_config)
34+
embedder.precheck()
3435
results = embedder.run(elements_filepath=embedder_file)
3536
assert results
3637
with embedder_file.open("r") as f:
@@ -46,6 +47,7 @@ def test_raw_mixedbread_embedder(embedder_file: Path):
4647
api_key=api_key,
4748
)
4849
)
50+
embedder.precheck()
4951
validate_raw_embedder(
5052
embedder=embedder,
5153
embedder_file=embedder_file,
@@ -63,6 +65,7 @@ async def test_raw_async_mixedbread_embedder(embedder_file: Path):
6365
api_key=api_key,
6466
)
6567
)
68+
embedder.precheck()
6669
await validate_raw_embedder_async(
6770
embedder=embedder,
6871
embedder_file=embedder_file,

test/integration/embedders/test_octoai.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
OctoAiEmbeddingConfig,
1616
OctoAIEmbeddingEncoder,
1717
)
18-
from unstructured_ingest.errors_v2 import UserAuthError
18+
from unstructured_ingest.errors_v2 import UserAuthError, UserError
1919
from unstructured_ingest.processes.embedder import Embedder, EmbedderConfig
2020

2121
API_KEY = "OCTOAI_API_KEY"
@@ -32,6 +32,7 @@ def test_octoai_embedder(embedder_file: Path):
3232
api_key = get_api_key()
3333
embedder_config = EmbedderConfig(embedding_provider="octoai", embedding_api_key=api_key)
3434
embedder = Embedder(config=embedder_config)
35+
embedder.precheck()
3536
results = embedder.run(elements_filepath=embedder_file)
3637
assert results
3738
with embedder_file.open("r") as f:
@@ -47,6 +48,7 @@ def test_raw_octoai_embedder(embedder_file: Path):
4748
api_key=api_key,
4849
)
4950
)
51+
embedder.precheck()
5052
validate_raw_embedder(embedder=embedder, embedder_file=embedder_file, expected_dimension=1024)
5153

5254

@@ -58,7 +60,7 @@ def test_raw_octoai_embedder_invalid_credentials():
5860
)
5961
)
6062
with pytest.raises(UserAuthError):
61-
embedder.get_exemplary_embedding()
63+
embedder.precheck()
6264

6365

6466
@requires_env(API_KEY)
@@ -70,6 +72,17 @@ async def test_raw_async_octoai_embedder(embedder_file: Path):
7072
api_key=api_key,
7173
)
7274
)
75+
embedder.precheck()
7376
await validate_raw_embedder_async(
7477
embedder=embedder, embedder_file=embedder_file, expected_dimension=1024
7578
)
79+
80+
81+
@requires_env(API_KEY)
82+
def test_octoai_wrong_model():
83+
api_key = get_api_key()
84+
embedder = OctoAIEmbeddingEncoder(
85+
config=OctoAiEmbeddingConfig(api_key=api_key, model_name="fake_model_name")
86+
)
87+
with pytest.raises(UserError):
88+
embedder.precheck()

test/integration/embedders/test_openai.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
OpenAIEmbeddingConfig,
1616
OpenAIEmbeddingEncoder,
1717
)
18-
from unstructured_ingest.errors_v2 import UserAuthError
18+
from unstructured_ingest.errors_v2 import UserAuthError, UserError
1919
from unstructured_ingest.processes.embedder import Embedder, EmbedderConfig
2020

2121
API_KEY = "OPENAI_API_KEY"
@@ -32,6 +32,7 @@ def test_openai_embedder(embedder_file: Path):
3232
api_key = get_api_key()
3333
embedder_config = EmbedderConfig(embedding_provider="openai", embedding_api_key=api_key)
3434
embedder = Embedder(config=embedder_config)
35+
embedder.precheck()
3536
results = embedder.run(elements_filepath=embedder_file)
3637
assert results
3738
with embedder_file.open("r") as f:
@@ -47,6 +48,7 @@ def test_raw_openai_embedder(embedder_file: Path):
4748
api_key=api_key,
4849
)
4950
)
51+
embedder.precheck()
5052
validate_raw_embedder(embedder=embedder, embedder_file=embedder_file, expected_dimension=1536)
5153

5254

@@ -69,6 +71,17 @@ async def test_raw_async_openai_embedder(embedder_file: Path):
6971
api_key=api_key,
7072
)
7173
)
74+
embedder.precheck()
7275
await validate_raw_embedder_async(
7376
embedder=embedder, embedder_file=embedder_file, expected_dimension=1536
7477
)
78+
79+
80+
@requires_env(API_KEY)
81+
def test_openai_wrong_model():
82+
api_key = get_api_key()
83+
embedder = OpenAIEmbeddingEncoder(
84+
config=OpenAIEmbeddingConfig(api_key=api_key, model_name="fake_model_name")
85+
)
86+
with pytest.raises(UserError):
87+
embedder.precheck()

test/integration/embedders/test_togetherai.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_togetherai_embedder(embedder_file: Path):
3232
api_key = get_api_key()
3333
embedder_config = EmbedderConfig(embedding_provider="togetherai", embedding_api_key=api_key)
3434
embedder = Embedder(config=embedder_config)
35+
embedder.precheck()
3536
results = embedder.run(elements_filepath=embedder_file)
3637
assert results
3738
with embedder_file.open("r") as f:
@@ -43,6 +44,7 @@ def test_togetherai_embedder(embedder_file: Path):
4344
def test_raw_togetherai_embedder(embedder_file: Path):
4445
api_key = get_api_key()
4546
embedder = TogetherAIEmbeddingEncoder(config=TogetherAIEmbeddingConfig(api_key=api_key))
47+
embedder.precheck()
4648
validate_raw_embedder(
4749
embedder=embedder,
4850
embedder_file=embedder_file,
@@ -63,6 +65,7 @@ def test_raw_togetherai_embedder_invalid_credentials():
6365
async def test_raw_async_togetherai_embedder(embedder_file: Path):
6466
api_key = get_api_key()
6567
embedder = AsyncTogetherAIEmbeddingEncoder(config=TogetherAIEmbeddingConfig(api_key=api_key))
68+
embedder.precheck()
6669
await validate_raw_embedder_async(
6770
embedder=embedder,
6871
embedder_file=embedder_file,

test/integration/embedders/test_vertexai.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_vertexai_embedder(embedder_file: Path):
3131
api_key = get_api_key()
3232
embedder_config = EmbedderConfig(embedding_provider="vertexai", embedding_api_key=api_key)
3333
embedder = Embedder(config=embedder_config)
34+
embedder.precheck()
3435
results = embedder.run(elements_filepath=embedder_file)
3536
assert results
3637
with embedder_file.open("r") as f:
@@ -46,6 +47,7 @@ def test_raw_vertexai_embedder(embedder_file: Path):
4647
api_key=api_key,
4748
)
4849
)
50+
embedder.precheck()
4951
validate_raw_embedder(embedder=embedder, embedder_file=embedder_file, expected_dimension=768)
5052

5153

@@ -58,6 +60,7 @@ async def test_raw_async_vertexai_embedder(embedder_file: Path):
5860
api_key=api_key,
5961
)
6062
)
63+
embedder.precheck()
6164
await validate_raw_embedder_async(
6265
embedder=embedder, embedder_file=embedder_file, expected_dimension=768
6366
)

0 commit comments

Comments
 (0)