diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e174d1db..ea32b244c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ * **Migrate Jira Source connector from V1 to V2** * **Add Jira Source connector integration and unit tests** +* **Add option for custom OpenAI base URL in EmbedderConfig** ## 0.5.9 diff --git a/unstructured_ingest/embed/openai.py b/unstructured_ingest/embed/openai.py index bd745396b..9d10371cc 100644 --- a/unstructured_ingest/embed/openai.py +++ b/unstructured_ingest/embed/openai.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from pydantic import Field, SecretStr @@ -26,6 +26,7 @@ class OpenAIEmbeddingConfig(EmbeddingConfig): api_key: SecretStr embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name") + base_url: Optional[str] = Field(default=None) def wrap_error(self, e: Exception) -> Exception: if is_internal_error(e=e): @@ -57,13 +58,13 @@ def wrap_error(self, e: Exception) -> Exception: def get_client(self) -> "OpenAI": from openai import OpenAI - return OpenAI(api_key=self.api_key.get_secret_value()) + return OpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url) @requires_dependencies(["openai"], extras="openai") def get_async_client(self) -> "AsyncOpenAI": from openai import AsyncOpenAI - return AsyncOpenAI(api_key=self.api_key.get_secret_value()) + return AsyncOpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url) @dataclass diff --git a/unstructured_ingest/v2/processes/embedder.py b/unstructured_ingest/v2/processes/embedder.py index fcfd2ed81..a7b38387a 100644 --- a/unstructured_ingest/v2/processes/embedder.py +++ b/unstructured_ingest/v2/processes/embedder.py @@ -52,6 +52,11 @@ class EmbedderConfig(BaseModel): embedding_azure_api_version: Optional[str] = Field( description="Azure API version", default=None ) + embedding_openai_endpoint: Optional[str] = Field( + default=None, 
description="Your custom OpenAI base url, " + "e.g. `https://custom-openai-deployment.com/`", + ) def get_huggingface_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder": from unstructured_ingest.embed.huggingface import ( @@ -66,7 +71,16 @@ def get_huggingface_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEnco def get_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder": from unstructured_ingest.embed.openai import OpenAIEmbeddingConfig, OpenAIEmbeddingEncoder - return OpenAIEmbeddingEncoder(config=OpenAIEmbeddingConfig.model_validate(embedding_kwargs)) + config_kwargs = { + "api_key": self.embedding_api_key, + "base_url": self.embedding_openai_endpoint, + } + if model_name := self.embedding_model_name: + config_kwargs["model_name"] = model_name + + return OpenAIEmbeddingEncoder( + config=OpenAIEmbeddingConfig.model_validate(config_kwargs) + ) def get_azure_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder": from unstructured_ingest.embed.azure_openai import (