Skip to content

Commit dc00929

Browse files
authored
feat/azure openai embedder (#265)
* add azure openai embedder * bump changelog * add int test * directly map config in embedder process * set default for api version
1 parent d58e37f commit dc00929

File tree

5 files changed

+124
-3
lines changed

5 files changed

+124
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
## 0.3.4-dev1
1+
## 0.3.4
22

3-
### Features
3+
### Enhancements
44

5+
* **Add Azure OpenAI embedder**
56
* **Add `collection_id` field to Couchbase `downloader_config`**
67

78
## 0.3.3
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import json
2+
import os
3+
from dataclasses import dataclass
4+
from pathlib import Path
5+
6+
from test.integration.embedders.utils import validate_embedding_output, validate_raw_embedder
7+
from test.integration.utils import requires_env
8+
from unstructured_ingest.embed.azure_openai import (
9+
AzureOpenAIEmbeddingConfig,
10+
AzureOpenAIEmbeddingEncoder,
11+
)
12+
from unstructured_ingest.v2.processes.embedder import Embedder, EmbedderConfig
13+
14+
# Names of the environment variables that carry the Azure OpenAI credentials.
API_KEY = "AZURE_OPENAI_API_KEY"
ENDPOINT = "AZURE_OPENAI_ENDPOINT"


@dataclass(frozen=True)
class AzureData:
    """Immutable pair of Azure OpenAI credentials read from the environment."""

    api_key: str
    endpoint: str


def get_azure_data() -> AzureData:
    """Read the Azure OpenAI API key and endpoint from the environment.

    Returns:
        An ``AzureData`` built from the variables named by ``API_KEY`` and
        ``ENDPOINT``.

    Raises:
        AssertionError: If either variable is unset or empty. The message
            names the offending variable.
    """
    api_key = os.getenv(API_KEY)
    assert api_key, f"environment variable {API_KEY} must be set and non-empty"
    endpoint = os.getenv(ENDPOINT)
    assert endpoint, f"environment variable {ENDPOINT} must be set and non-empty"
    return AzureData(api_key=api_key, endpoint=endpoint)
30+
31+
32+
@requires_env(API_KEY, ENDPOINT)
def test_azure_openai_embedder(embedder_file: Path):
    """Run the pipeline ``Embedder`` with the ``azure-openai`` provider.

    Builds an ``EmbedderConfig`` from the credentials in the environment, runs
    the embedder over ``embedder_file``, and hands the original elements and
    the embedder output to ``validate_embedding_output``.

    Args:
        embedder_file: Path to a JSON file of elements to embed.
    """
    azure_data = get_azure_data()
    embedder_config = EmbedderConfig(
        embedding_provider="azure-openai",
        embedding_api_key=azure_data.api_key,
        embedding_azure_endpoint=azure_data.endpoint,
    )
    embedder = Embedder(config=embedder_config)
    results = embedder.run(elements_filepath=embedder_file)
    assert results
    # JSON is UTF-8 by specification; do not depend on the platform locale
    # when decoding the fixture.
    with embedder_file.open("r", encoding="utf-8") as f:
        original_elements = json.load(f)
    validate_embedding_output(original_elements=original_elements, output_elements=results)
46+
47+
48+
@requires_env(API_KEY, ENDPOINT)
def test_raw_azure_openai_embedder(embedder_file: Path):
    """Exercise ``AzureOpenAIEmbeddingEncoder`` directly, outside the pipeline.

    The encoder is configured with the credentials from the environment and
    passed to ``validate_raw_embedder`` with an expected dimension of 1536.

    Args:
        embedder_file: Path to the elements file given to the validator.
    """
    credentials = get_azure_data()
    encoder_config = AzureOpenAIEmbeddingConfig(
        api_key=credentials.api_key,
        azure_endpoint=credentials.endpoint,
    )
    validate_raw_embedder(
        embedder=AzureOpenAIEmbeddingEncoder(config=encoder_config),
        embedder_file=embedder_file,
        expected_dimensions=(1536,),
    )

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.4-dev1" # pragma: no cover
1+
__version__ = "0.3.4" # pragma: no cover
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from dataclasses import dataclass
2+
from typing import TYPE_CHECKING
3+
4+
from pydantic import Field
5+
6+
from unstructured_ingest.embed.openai import OpenAIEmbeddingConfig, OpenAIEmbeddingEncoder
7+
from unstructured_ingest.utils.dep_check import requires_dependencies
8+
9+
if TYPE_CHECKING:
10+
from openai import AzureOpenAI
11+
12+
13+
class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
    """``OpenAIEmbeddingConfig`` extended with Azure-specific connection fields.

    ``get_client`` returns an ``openai.AzureOpenAI`` client built from
    ``api_key``, ``api_version`` and ``azure_endpoint``.
    """

    # Falls back to "2024-06-01" when no version is supplied.
    api_version: str = Field(default="2024-06-01", description="Azure API version")
    # Required: no default is declared.
    azure_endpoint: str
    # Declared with the alias "model_name".
    embedder_model_name: str = Field(alias="model_name", default="text-embedding-ada-002")

    @requires_dependencies(["openai"], extras="openai")
    def get_client(self) -> "AzureOpenAI":
        """Return an ``AzureOpenAI`` client configured from this config."""
        # Imported here rather than at module level, where ``openai`` is only
        # imported under TYPE_CHECKING.
        from openai import AzureOpenAI

        client_kwargs = {
            "api_key": self.api_key.get_secret_value(),
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
        }
        return AzureOpenAI(**client_kwargs)
27+
28+
29+
@dataclass
class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
    """``OpenAIEmbeddingEncoder`` whose ``config`` is an ``AzureOpenAIEmbeddingConfig``."""

    # Narrows the config type to the Azure variant.
    config: AzureOpenAIEmbeddingConfig

unstructured_ingest/v2/processes/embedder.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class EmbedderConfig(BaseModel):
1616
embedding_provider: Optional[
1717
Literal[
1818
"openai",
19+
"azure-openai",
1920
"huggingface",
2021
"aws-bedrock",
2122
"vertexai",
@@ -43,6 +44,14 @@ class EmbedderConfig(BaseModel):
4344
embedding_aws_region: Optional[str] = Field(
4445
default="us-west-2", description="AWS region used for AWS-based embedders, such as bedrock"
4546
)
47+
# Base URL of the Azure OpenAI resource. Azure OpenAI endpoints use the
# `<resource>.openai.azure.com` host form, which the example below reflects.
embedding_azure_endpoint: Optional[str] = Field(
    default=None,
    description="Your Azure endpoint, including the resource, "
    "e.g. `https://example-resource.openai.azure.com/`",
)
# Optional Azure API version; when left as None, get_azure_openai_embedder
# does not pass an api_version to the embedding config.
embedding_azure_api_version: Optional[str] = Field(
    default=None, description="Azure API version"
)
4655

4756
def get_huggingface_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
4857
from unstructured_ingest.embed.huggingface import (
@@ -59,6 +68,25 @@ def get_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
5968

6069
return OpenAIEmbeddingEncoder(config=OpenAIEmbeddingConfig.model_validate(embedding_kwargs))
6170

71+
def get_azure_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
    """Build the Azure OpenAI embedding encoder from this config's own fields.

    The API key and Azure endpoint are always forwarded; the API version and
    model name are forwarded only when they are set to a truthy value.

    Args:
        embedding_kwargs: Not read by this method; the settings are taken
            directly from the ``embedding_*`` attributes instead.

    Returns:
        An ``AzureOpenAIEmbeddingEncoder`` wrapping a validated
        ``AzureOpenAIEmbeddingConfig``.
    """
    from unstructured_ingest.embed.azure_openai import (
        AzureOpenAIEmbeddingConfig,
        AzureOpenAIEmbeddingEncoder,
    )

    # Settings that are only passed along when the user provided them.
    optional_settings = {
        "api_version": self.embedding_azure_api_version,
        "model_name": self.embedding_model_name,
    }
    settings = {
        "api_key": self.embedding_api_key,
        "azure_endpoint": self.embedding_azure_endpoint,
        **{key: value for key, value in optional_settings.items() if value},
    }

    azure_config = AzureOpenAIEmbeddingConfig.model_validate(settings)
    return AzureOpenAIEmbeddingEncoder(config=azure_config)
89+
6290
def get_octoai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
6391
from unstructured_ingest.embed.octoai import OctoAiEmbeddingConfig, OctoAIEmbeddingEncoder
6492

@@ -146,6 +174,8 @@ def get_embedder(self) -> "BaseEmbeddingEncoder":
146174
return self.get_mixedbread_embedder(embedding_kwargs=kwargs)
147175
if self.embedding_provider == "togetherai":
148176
return self.get_togetherai_embedder(embedding_kwargs=kwargs)
177+
if self.embedding_provider == "azure-openai":
178+
return self.get_azure_openai_embedder(embedding_kwargs=kwargs)
149179

150180
raise ValueError(f"{self.embedding_provider} not a recognized encoder")
151181

0 commit comments

Comments
 (0)