|
5 | 5 | from chromadb.utils.embedding_functions.openai_embedding_function import ( |
6 | 6 | OpenAIEmbeddingFunction, |
7 | 7 | ) |
| 8 | +from chromadb.errors import InvalidArgumentError |
8 | 9 |
|
9 | 10 |
|
10 | 11 | def test_with_embedding_dimensions() -> None: |
@@ -36,3 +37,42 @@ def test_with_incorrect_api_key() -> None: |
36 | 37 | ef = OpenAIEmbeddingFunction(api_key="incorrect_api_key", dimensions=64) |
37 | 38 | with pytest.raises(Exception, match="Incorrect API key provided"): |
38 | 39 | ef(["hello world"]) |
| 40 | + |
| 41 | + |
| 42 | +def test_azure_requires_deployment_id() -> None: |
| 43 | + """Azure OpenAI should require deployment_id parameter.""" |
| 44 | + pytest.importorskip("openai", reason="openai not installed") |
| 45 | + with pytest.raises(InvalidArgumentError, match="deployment_id must be specified"): |
| 46 | + OpenAIEmbeddingFunction( |
| 47 | + api_key="test_key", |
| 48 | + api_type="azure", |
| 49 | + api_base="https://example.openai.azure.com", |
| 50 | + api_version="2023-05-15", |
| 51 | + # Missing deployment_id should raise |
| 52 | + ) |
| 53 | + |
| 54 | + |
| 55 | +def test_azure_requires_api_version() -> None: |
| 56 | + """Azure OpenAI should require api_version parameter.""" |
| 57 | + pytest.importorskip("openai", reason="openai not installed") |
| 58 | + with pytest.raises(InvalidArgumentError, match="api_version must be specified"): |
| 59 | + OpenAIEmbeddingFunction( |
| 60 | + api_key="test_key", |
| 61 | + api_type="azure", |
| 62 | + api_base="https://example.openai.azure.com", |
| 63 | + deployment_id="my-deployment", |
| 64 | + # Missing api_version should raise |
| 65 | + ) |
| 66 | + |
| 67 | + |
| 68 | +def test_azure_requires_api_base() -> None: |
| 69 | + """Azure OpenAI should require api_base parameter.""" |
| 70 | + pytest.importorskip("openai", reason="openai not installed") |
| 71 | + with pytest.raises(InvalidArgumentError, match="api_base must be specified"): |
| 72 | + OpenAIEmbeddingFunction( |
| 73 | + api_key="test_key", |
| 74 | + api_type="azure", |
| 75 | + api_version="2023-05-15", |
| 76 | + deployment_id="my-deployment", |
| 77 | + # Missing api_base should raise |
| 78 | + ) |
0 commit comments