Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions app/backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ async def setup_clients():
)
AZURE_OPENAI_EMB_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT") if OPENAI_HOST.startswith("azure") else None
AZURE_OPENAI_CUSTOM_URL = os.getenv("AZURE_OPENAI_CUSTOM_URL")
# https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") or "2024-06-01"
AZURE_VISION_ENDPOINT = os.getenv("AZURE_VISION_ENDPOINT", "")
# Used only with non-Azure OpenAI deployments
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
Expand Down Expand Up @@ -547,6 +549,7 @@ async def setup_clients():
openai_custom_url=AZURE_OPENAI_CUSTOM_URL,
openai_deployment=AZURE_OPENAI_EMB_DEPLOYMENT,
openai_dimensions=OPENAI_EMB_DIMENSIONS,
openai_api_version=AZURE_OPENAI_API_VERSION,
openai_key=clean_key_if_exists(OPENAI_API_KEY),
openai_org=OPENAI_ORGANIZATION,
disable_vectors=os.getenv("USE_VECTORS", "").lower() == "false",
Expand All @@ -573,7 +576,6 @@ async def setup_clients():
current_app.config[CONFIG_CREDENTIAL] = azure_credential

if OPENAI_HOST.startswith("azure"):
api_version = os.getenv("AZURE_OPENAI_API_VERSION") or "2024-03-01-preview"
if OPENAI_HOST == "azure_custom":
current_app.logger.info("OPENAI_HOST is azure_custom, setting up Azure OpenAI custom client")
if not AZURE_OPENAI_CUSTOM_URL:
Expand All @@ -586,12 +588,14 @@ async def setup_clients():
endpoint = f"https://{AZURE_OPENAI_SERVICE}.openai.azure.com"
if api_key := os.getenv("AZURE_OPENAI_API_KEY_OVERRIDE"):
current_app.logger.info("AZURE_OPENAI_API_KEY_OVERRIDE found, using as api_key for Azure OpenAI client")
openai_client = AsyncAzureOpenAI(api_version=api_version, azure_endpoint=endpoint, api_key=api_key)
openai_client = AsyncAzureOpenAI(
api_version=AZURE_OPENAI_API_VERSION, azure_endpoint=endpoint, api_key=api_key
)
else:
current_app.logger.info("Using Azure credential (passwordless authentication) for Azure OpenAI client")
token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
openai_client = AsyncAzureOpenAI(
api_version=api_version,
api_version=AZURE_OPENAI_API_VERSION,
azure_endpoint=endpoint,
azure_ad_token_provider=token_provider,
)
Expand Down
4 changes: 4 additions & 0 deletions app/backend/prepdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def setup_embeddings_service(
openai_custom_url: Union[str, None],
openai_deployment: Union[str, None],
openai_dimensions: int,
openai_api_version: str,
openai_key: Union[str, None],
openai_org: Union[str, None],
disable_vectors: bool = False,
Expand All @@ -134,6 +135,7 @@ def setup_embeddings_service(
open_ai_deployment=openai_deployment,
open_ai_model_name=openai_model_name,
open_ai_dimensions=openai_dimensions,
open_ai_api_version=openai_api_version,
credential=azure_open_ai_credential,
disable_batch=disable_batch_vectors,
)
Expand Down Expand Up @@ -366,6 +368,8 @@ async def main(strategy: Strategy, setup_index: bool = True):
openai_service=os.getenv("AZURE_OPENAI_SERVICE"),
openai_custom_url=os.getenv("AZURE_OPENAI_CUSTOM_URL"),
openai_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"),
# https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION") or "2024-06-01",
openai_dimensions=openai_dimensions,
openai_key=clean_key_if_exists(openai_key),
openai_org=os.getenv("OPENAI_ORGANIZATION"),
Expand Down
4 changes: 3 additions & 1 deletion app/backend/prepdocslib/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(
open_ai_deployment: Union[str, None],
open_ai_model_name: str,
open_ai_dimensions: int,
open_ai_api_version: str,
credential: Union[AsyncTokenCredential, AzureKeyCredential],
open_ai_custom_url: Union[str, None] = None,
disable_batch: bool = False,
Expand All @@ -176,6 +177,7 @@ def __init__(
else:
raise ValueError("Either open_ai_service or open_ai_custom_url must be provided")
self.open_ai_deployment = open_ai_deployment
self.open_ai_api_version = open_ai_api_version
self.credential = credential

async def create_client(self) -> AsyncOpenAI:
Expand All @@ -196,7 +198,7 @@ class AuthArgs(TypedDict, total=False):
return AsyncAzureOpenAI(
azure_endpoint=self.open_ai_endpoint,
azure_deployment=self.open_ai_deployment,
api_version="2023-05-15",
api_version=self.open_ai_api_version,
**auth_args,
)

Expand Down
6 changes: 6 additions & 0 deletions tests/test_prepdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def mock_create_client(*args, **kwargs):
open_ai_deployment="x",
open_ai_model_name=MOCK_EMBEDDING_MODEL_NAME,
open_ai_dimensions=MOCK_EMBEDDING_DIMENSIONS,
open_ai_api_version="test-api-version",
credential=MockAzureCredential(),
disable_batch=False,
)
Expand All @@ -79,6 +80,7 @@ async def mock_create_client(*args, **kwargs):
open_ai_deployment="x",
open_ai_model_name=MOCK_EMBEDDING_MODEL_NAME,
open_ai_dimensions=MOCK_EMBEDDING_DIMENSIONS,
open_ai_api_version="test-api-version",
credential=MockAzureCredential(),
disable_batch=True,
)
Expand Down Expand Up @@ -149,6 +151,7 @@ async def test_compute_embedding_ratelimiterror_batch(monkeypatch, caplog):
open_ai_deployment="x",
open_ai_model_name=MOCK_EMBEDDING_MODEL_NAME,
open_ai_dimensions=MOCK_EMBEDDING_DIMENSIONS,
open_ai_api_version="test-api-version",
credential=MockAzureCredential(),
disable_batch=False,
)
Expand All @@ -167,6 +170,7 @@ async def test_compute_embedding_ratelimiterror_single(monkeypatch, caplog):
open_ai_deployment="x",
open_ai_model_name=MOCK_EMBEDDING_MODEL_NAME,
open_ai_dimensions=MOCK_EMBEDDING_DIMENSIONS,
open_ai_api_version="test-api-version",
credential=MockAzureCredential(),
disable_batch=True,
)
Expand All @@ -193,6 +197,7 @@ async def test_compute_embedding_autherror(monkeypatch, capsys):
open_ai_deployment="x",
open_ai_model_name=MOCK_EMBEDDING_MODEL_NAME,
open_ai_dimensions=MOCK_EMBEDDING_DIMENSIONS,
open_ai_api_version="test-api-version",
credential=MockAzureCredential(),
disable_batch=False,
)
Expand All @@ -205,6 +210,7 @@ async def test_compute_embedding_autherror(monkeypatch, capsys):
open_ai_deployment="x",
open_ai_model_name=MOCK_EMBEDDING_MODEL_NAME,
open_ai_dimensions=MOCK_EMBEDDING_DIMENSIONS,
open_ai_api_version="test-api-version",
credential=MockAzureCredential(),
disable_batch=True,
)
Expand Down
1 change: 1 addition & 0 deletions tests/test_searchmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ async def mock_upload_documents(self, documents):
open_ai_deployment="x",
open_ai_model_name=MOCK_EMBEDDING_MODEL_NAME,
open_ai_dimensions=MOCK_EMBEDDING_DIMENSIONS,
open_ai_api_version="test-api-version",
credential=AzureKeyCredential("test"),
disable_batch=True,
)
Expand Down