diff --git a/src/pyproject.toml b/src/pyproject.toml
index 228f148..6278222 100644
--- a/src/pyproject.toml
+++ b/src/pyproject.toml
@@ -13,8 +13,8 @@ dependencies = [
"httptools",
# Used by uvicorn for reload functionality
"watchfiles",
- "azure-ai-inference",
"azure-identity",
+ "openai",
"aiohttp",
"python-dotenv",
"pyyaml"
diff --git a/src/quartapp/chat.py b/src/quartapp/chat.py
index 6e3dd3c..bb2d62e 100644
--- a/src/quartapp/chat.py
+++ b/src/quartapp/chat.py
@@ -1,13 +1,8 @@
import json
import os
-from azure.ai.inference.aio import ChatCompletionsClient
-from azure.ai.inference.models import SystemMessage
-from azure.identity.aio import (
- AzureDeveloperCliCredential,
- ChainedTokenCredential,
- ManagedIdentityCredential,
-)
+from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential, get_bearer_token_provider
+from openai import AsyncOpenAI
from quart import (
Blueprint,
Response,
@@ -22,38 +17,34 @@
@bp.before_app_serving
async def configure_openai():
- # Use ManagedIdentityCredential with the client_id for user-assigned managed identities
- user_assigned_managed_identity_credential = ManagedIdentityCredential(client_id=os.getenv("AZURE_CLIENT_ID"))
-
- # Use AzureDeveloperCliCredential with the current tenant.
- azure_dev_cli_credential = AzureDeveloperCliCredential(tenant_id=os.getenv("AZURE_TENANT_ID"), process_timeout=60)
-
- # Create a ChainedTokenCredential with ManagedIdentityCredential and AzureDeveloperCliCredential
- # - ManagedIdentityCredential is used for deployment on Azure Container Apps
-
- # - AzureDeveloperCliCredential is used for local development
- # The order of the credentials is important, as the first valid token is used
- # For more information check out:
-
- # https://learn.microsoft.com/azure/developer/python/sdk/authentication/credential-chains?tabs=ctc#chainedtokencredential-overview
- azure_credential = ChainedTokenCredential(user_assigned_managed_identity_credential, azure_dev_cli_credential)
- current_app.logger.info("Using Azure OpenAI with credential")
-
- if not os.getenv("AZURE_INFERENCE_ENDPOINT"):
- raise ValueError("AZURE_INFERENCE_ENDPOINT is required for Azure OpenAI")
+ if os.getenv("RUNNING_IN_PRODUCTION"):
+ client_id = os.environ["AZURE_CLIENT_ID"]
+ current_app.logger.info("Using Azure OpenAI with managed identity credential for client ID: %s", client_id)
+ bp.azure_credential = ManagedIdentityCredential(client_id=client_id)
+ else:
+ tenant_id = os.environ["AZURE_TENANT_ID"]
+ current_app.logger.info("Using Azure OpenAI with Azure Developer CLI credential for tenant ID: %s", tenant_id)
+ bp.azure_credential = AzureDeveloperCliCredential(tenant_id=tenant_id)
+
+ # Get the token provider for Azure OpenAI based on the selected Azure credential
+ bp.openai_token_provider = get_bearer_token_provider(
+ bp.azure_credential, "https://cognitiveservices.azure.com/.default"
+ )
# Create the Asynchronous Azure OpenAI client
- bp.ai_client = ChatCompletionsClient(
- endpoint=os.environ["AZURE_INFERENCE_ENDPOINT"],
- credential=azure_credential,
- credential_scopes=["https://cognitiveservices.azure.com/.default"],
- model="DeepSeek-R1",
+ bp.openai_client = AsyncOpenAI(
+ base_url=os.environ["AZURE_INFERENCE_ENDPOINT"],
+ api_key=await bp.openai_token_provider(),
+ default_query={"api-version": "2024-05-01-preview"},
)
+ # Set the model name to the Azure OpenAI model deployment name
+ bp.openai_model = os.getenv("AZURE_DEEPSEEK_DEPLOYMENT")
+
@bp.after_app_serving
async def shutdown_openai():
- await bp.ai_client.close()
+ await bp.openai_client.close()
@bp.get("/")
@@ -69,15 +60,20 @@ async def chat_handler():
async def response_stream():
# This sends all messages, so API request may exceed token limits
all_messages = [
- SystemMessage(content="You are a helpful assistant."),
+ {"role": "system", "content": "You are a helpful assistant."},
] + request_messages
- client: ChatCompletionsClient = bp.ai_client
- result = await client.complete(messages=all_messages, max_tokens=2048, stream=True)
+ bp.openai_client.api_key = await bp.openai_token_provider()
+ chat_coroutine = bp.openai_client.chat.completions.create(
+            # Azure OpenAI takes the deployment name as the model name
+ model=bp.openai_model,
+ messages=all_messages,
+ stream=True,
+ )
try:
is_thinking = False
- async for update in result:
+ async for update in await chat_coroutine:
if update.choices:
content = update.choices[0].delta.content
if content == "":
@@ -103,4 +99,4 @@ async def response_stream():
current_app.logger.error(e)
yield json.dumps({"error": str(e)}, ensure_ascii=False) + "\n"
- return Response(response_stream(), mimetype="application/json")
+ return Response(response_stream())
diff --git a/src/requirements.txt b/src/requirements.txt
index c358a96..3421755 100644
--- a/src/requirements.txt
+++ b/src/requirements.txt
@@ -1,5 +1,5 @@
#
-# This file is autogenerated by pip-compile with Python 3.11
+# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --output-file=requirements.txt pyproject.toml
@@ -12,16 +12,17 @@ aiohttp==3.10.11
# via quartapp (pyproject.toml)
aiosignal==1.3.1
# via aiohttp
+annotated-types==0.7.0
+ # via pydantic
anyio==4.6.0
- # via watchfiles
+ # via
+ # httpx
+ # openai
+ # watchfiles
attrs==24.2.0
# via aiohttp
-azure-ai-inference==1.0.0b8
- # via quartapp (pyproject.toml)
azure-core==1.31.0
- # via
- # azure-ai-inference
- # azure-identity
+ # via azure-identity
azure-identity==1.19.0
# via quartapp (pyproject.toml)
blinker==1.8.2
@@ -29,7 +30,10 @@ blinker==1.8.2
# flask
# quart
certifi==2024.8.30
- # via requests
+ # via
+ # httpcore
+ # httpx
+ # requests
cffi==1.17.1
# via cryptography
charset-normalizer==3.4.0
@@ -44,6 +48,8 @@ cryptography==44.0.1
# azure-identity
# msal
# pyjwt
+distro==1.9.0
+ # via openai
flask==3.0.3
# via quart
frozenlist==1.4.1
@@ -54,6 +60,7 @@ gunicorn==23.0.0
# via quartapp (pyproject.toml)
h11==0.14.0
# via
+ # httpcore
# hypercorn
# uvicorn
# wsproto
@@ -61,8 +68,12 @@ h2==4.1.0
# via hypercorn
hpack==4.0.0
# via h2
+httpcore==1.0.7
+ # via httpx
httptools==0.6.4
# via quartapp (pyproject.toml)
+httpx==0.28.1
+ # via openai
hypercorn==0.17.3
# via quart
hyperframe==6.0.1
@@ -70,10 +81,9 @@ hyperframe==6.0.1
idna==3.10
# via
# anyio
+ # httpx
# requests
# yarl
-isodate==0.7.2
- # via azure-ai-inference
itsdangerous==2.2.0
# via
# flask
@@ -82,6 +92,8 @@ jinja2==3.1.5
# via
# flask
# quart
+jiter==0.9.0
+ # via openai
markupsafe==3.0.1
# via
# jinja2
@@ -97,6 +109,8 @@ multidict==6.1.0
# via
# aiohttp
# yarl
+openai==1.66.2
+ # via quartapp (pyproject.toml)
packaging==24.1
# via gunicorn
portalocker==2.10.1
@@ -107,8 +121,14 @@ propcache==0.2.0
# via yarl
pycparser==2.22
# via cffi
+pydantic==2.10.6
+ # via openai
+pydantic-core==2.27.2
+ # via pydantic
pyjwt[crypto]==2.9.0
- # via msal
+ # via
+ # msal
+ # pyjwt
python-dotenv==1.0.1
# via quartapp (pyproject.toml)
pyyaml==6.0.2
@@ -122,12 +142,18 @@ requests==2.32.3
six==1.16.0
# via azure-core
sniffio==1.3.1
- # via anyio
+ # via
+ # anyio
+ # openai
+tqdm==4.67.1
+ # via openai
typing-extensions==4.12.2
# via
- # azure-ai-inference
# azure-core
# azure-identity
+ # openai
+ # pydantic
+ # pydantic-core
urllib3==2.2.3
# via requests
uvicorn==0.32.0
diff --git a/tests/conftest.py b/tests/conftest.py
index c373604..c5cbac0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,4 @@
-import azure.ai.inference.models
+import openai
import pytest
import pytest_asyncio
@@ -13,54 +13,70 @@ class AsyncChatCompletionIterator:
def __init__(self, answer: str):
self.chunk_index = 0
self.chunks = [
- azure.ai.inference.models.StreamingChatCompletionsUpdate(
- id="test-123",
- created=1703462735,
- model="DeepSeek-R1",
- choices=[
- azure.ai.inference.models.StreamingChatChoiceUpdate(
- delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(
- content=None, role="assistant"
- ),
- index=0,
- finish_reason=None,
- )
+ openai.types.chat.ChatCompletionChunk(
+ object="chat.completion.chunk",
+ choices=[],
+ id="",
+ created=0,
+ model="",
+ prompt_filter_results=[
+ {
+ "prompt_index": 0,
+ "content_filter_results": {
+ "hate": {"filtered": False, "severity": "safe"},
+ "self_harm": {"filtered": False, "severity": "safe"},
+ "sexual": {"filtered": False, "severity": "safe"},
+ "violence": {"filtered": False, "severity": "safe"},
+ },
+ }
],
- ),
+ )
]
answer_deltas = answer.split(" ")
for answer_index, answer_delta in enumerate(answer_deltas):
- # Completion chunks include whitespace, so we need to add it back in
- if answer_index > 0:
+ # Text completion chunks include whitespace, so we need to add it back in
+ if answer_index > 0 and answer_delta != "":
answer_delta = " " + answer_delta
self.chunks.append(
- azure.ai.inference.models.StreamingChatCompletionsUpdate(
+ openai.types.chat.ChatCompletionChunk(
id="test-123",
- created=1703462735,
- model="DeepSeek-R1",
+ object="chat.completion.chunk",
choices=[
- azure.ai.inference.models.StreamingChatChoiceUpdate(
- delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(
- content=answer_delta, role=None
+ openai.types.chat.chat_completion_chunk.Choice(
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
+ role=None, content=answer_delta
),
- index=0,
finish_reason=None,
+ index=0,
+ logprobs=None,
+ # Only Azure includes content_filter_results
+ content_filter_results={
+ "hate": {"filtered": False, "severity": "safe"},
+ "self_harm": {"filtered": False, "severity": "safe"},
+ "sexual": {"filtered": False, "severity": "safe"},
+ "violence": {"filtered": False, "severity": "safe"},
+ },
)
],
+ created=1703462735,
+ model="DeepSeek-R1",
)
)
self.chunks.append(
- azure.ai.inference.models.StreamingChatCompletionsUpdate(
+ openai.types.chat.ChatCompletionChunk(
id="test-123",
- created=1703462735,
- model="DeepSeek-R1",
+ object="chat.completion.chunk",
choices=[
- azure.ai.inference.models.StreamingChatChoiceUpdate(
- delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(content=None, role=None),
+ openai.types.chat.chat_completion_chunk.Choice(
+ delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(content=None, role=None),
index=0,
finish_reason="stop",
+ # Only Azure includes content_filter_results
+ content_filter_results={},
)
],
+ created=1703462735,
+ model="DeepSeek-R1",
)
)
@@ -75,28 +91,29 @@ async def __anext__(self):
else:
raise StopAsyncIteration
- async def mock_complete(*args, **kwargs):
+ async def mock_acreate(*args, **kwargs):
# Only mock a stream=True completion
last_message = kwargs.get("messages")[-1]["content"]
if last_message == "What is the capital of France?":
- return AsyncChatCompletionIterator("The capital of France is Paris.")
+ return AsyncChatCompletionIterator(" hmm The capital of France is Paris.")
elif last_message == "What is the capital of Germany?":
- return AsyncChatCompletionIterator("The capital of Germany is Berlin.")
+ return AsyncChatCompletionIterator(" hmm The capital of Germany is Berlin.")
else:
raise ValueError(f"Unexpected message: {last_message}")
- monkeypatch.setattr("azure.ai.inference.aio.ChatCompletionsClient.complete", mock_complete)
+ monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate)
@pytest.fixture
def mock_defaultazurecredential(monkeypatch):
- monkeypatch.setattr("azure.identity.aio.DefaultAzureCredential", mock_cred.MockAzureCredential)
+ monkeypatch.setattr("azure.identity.aio.AzureDeveloperCliCredential", mock_cred.MockAzureCredential)
monkeypatch.setattr("azure.identity.aio.ManagedIdentityCredential", mock_cred.MockAzureCredential)
@pytest_asyncio.fixture
async def client(monkeypatch, mock_openai_chatcompletion, mock_defaultazurecredential):
monkeypatch.setenv("AZURE_INFERENCE_ENDPOINT", "test-deepseek-service.ai.azure.com")
+ monkeypatch.setenv("AZURE_TENANT_ID", "test-tenant-id")
quart_app = quartapp.create_app(testing=True)
diff --git a/tests/mock_cred.py b/tests/mock_cred.py
index d8ea4a3..53a0e58 100644
--- a/tests/mock_cred.py
+++ b/tests/mock_cred.py
@@ -1,5 +1,10 @@
+import azure.core.credentials
import azure.core.credentials_async
class MockAzureCredential(azure.core.credentials_async.AsyncTokenCredential):
- pass
+ async def get_token(self, *scopes, **kwargs):
+ return azure.core.credentials.AccessToken(
+ token="mock_token",
+ expires_on=1703462735,
+ )
diff --git a/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines b/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines
index b4f3f44..50730aa 100644
--- a/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines
+++ b/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines
@@ -1,4 +1,5 @@
-{"delta": {"content": "The", "reasoning_content": null, "role": "assistant"}}
+{"delta": {"content": null, "reasoning_content": " hmm", "role": "assistant"}}
+{"delta": {"content": " The", "reasoning_content": null, "role": "assistant"}}
{"delta": {"content": " capital", "reasoning_content": null, "role": "assistant"}}
{"delta": {"content": " of", "reasoning_content": null, "role": "assistant"}}
{"delta": {"content": " France", "reasoning_content": null, "role": "assistant"}}
diff --git a/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines b/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines
index ba8dd58..42d9a6c 100644
--- a/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines
+++ b/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines
@@ -1,4 +1,5 @@
-{"delta": {"content": "The", "reasoning_content": null, "role": "assistant"}}
+{"delta": {"content": null, "reasoning_content": " hmm", "role": "assistant"}}
+{"delta": {"content": " The", "reasoning_content": null, "role": "assistant"}}
{"delta": {"content": " capital", "reasoning_content": null, "role": "assistant"}}
{"delta": {"content": " of", "reasoning_content": null, "role": "assistant"}}
{"delta": {"content": " Germany", "reasoning_content": null, "role": "assistant"}}
diff --git a/tests/test_app.py b/tests/test_app.py
index a0e7259..abe6070 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -49,11 +49,12 @@ async def test_chat_stream_text_history(client, snapshot):
async def test_openai_managedidentity(monkeypatch):
monkeypatch.setenv("AZURE_CLIENT_ID", "test-client-id")
monkeypatch.setenv("AZURE_INFERENCE_ENDPOINT", "test-deepseek-service.ai.azure.com")
+ monkeypatch.setenv("RUNNING_IN_PRODUCTION", "true")
monkeypatch.setattr("azure.identity.aio.ManagedIdentityCredential", mock_cred.MockAzureCredential)
quart_app = quartapp.create_app(testing=True)
async with quart_app.test_app():
- assert not isinstance(quart_app.blueprints["chat"].ai_client._config.credential, AzureKeyCredential)
- assert isinstance(quart_app.blueprints["chat"].ai_client._config.credential, AsyncTokenCredential)
+ assert not isinstance(quart_app.blueprints["chat"].azure_credential, AzureKeyCredential)
+ assert isinstance(quart_app.blueprints["chat"].azure_credential, AsyncTokenCredential)