diff --git a/src/pyproject.toml b/src/pyproject.toml index 228f148..6278222 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -13,8 +13,8 @@ dependencies = [ "httptools", # Used by uvicorn for reload functionality "watchfiles", - "azure-ai-inference", "azure-identity", + "openai", "aiohttp", "python-dotenv", "pyyaml" diff --git a/src/quartapp/chat.py b/src/quartapp/chat.py index 6e3dd3c..bb2d62e 100644 --- a/src/quartapp/chat.py +++ b/src/quartapp/chat.py @@ -1,13 +1,8 @@ import json import os -from azure.ai.inference.aio import ChatCompletionsClient -from azure.ai.inference.models import SystemMessage -from azure.identity.aio import ( - AzureDeveloperCliCredential, - ChainedTokenCredential, - ManagedIdentityCredential, -) +from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential, get_bearer_token_provider +from openai import AsyncOpenAI from quart import ( Blueprint, Response, @@ -22,38 +17,34 @@ @bp.before_app_serving async def configure_openai(): - # Use ManagedIdentityCredential with the client_id for user-assigned managed identities - user_assigned_managed_identity_credential = ManagedIdentityCredential(client_id=os.getenv("AZURE_CLIENT_ID")) - - # Use AzureDeveloperCliCredential with the current tenant. - azure_dev_cli_credential = AzureDeveloperCliCredential(tenant_id=os.getenv("AZURE_TENANT_ID"), process_timeout=60) - - # Create a ChainedTokenCredential with ManagedIdentityCredential and AzureDeveloperCliCredential - # - ManagedIdentityCredential is used for deployment on Azure Container Apps - - # - AzureDeveloperCliCredential is used for local development - # The order of the credentials is important, as the first valid token is used - # For more information check out: - - # https://learn.microsoft.com/azure/developer/python/sdk/authentication/credential-chains?tabs=ctc#chainedtokencredential-overview - azure_credential = ChainedTokenCredential(user_assigned_managed_identity_credential, azure_dev_cli_credential) - current_app.logger.info("Using Azure OpenAI with credential") - - if not os.getenv("AZURE_INFERENCE_ENDPOINT"): - raise ValueError("AZURE_INFERENCE_ENDPOINT is required for Azure OpenAI") + if os.getenv("RUNNING_IN_PRODUCTION"): + client_id = os.environ["AZURE_CLIENT_ID"] + current_app.logger.info("Using Azure OpenAI with managed identity credential for client ID: %s", client_id) + bp.azure_credential = ManagedIdentityCredential(client_id=client_id) + else: + tenant_id = os.environ["AZURE_TENANT_ID"] + current_app.logger.info("Using Azure OpenAI with Azure Developer CLI credential for tenant ID: %s", tenant_id) + bp.azure_credential = AzureDeveloperCliCredential(tenant_id=tenant_id) + + # Get the token provider for Azure OpenAI based on the selected Azure credential + bp.openai_token_provider = get_bearer_token_provider( + bp.azure_credential, "https://cognitiveservices.azure.com/.default" + ) # Create the Asynchronous Azure OpenAI client - bp.ai_client = ChatCompletionsClient( - endpoint=os.environ["AZURE_INFERENCE_ENDPOINT"], - credential=azure_credential, - credential_scopes=["https://cognitiveservices.azure.com/.default"], - model="DeepSeek-R1", + bp.openai_client = AsyncOpenAI( + base_url=os.environ["AZURE_INFERENCE_ENDPOINT"], + api_key=await bp.openai_token_provider(), + default_query={"api-version": "2024-05-01-preview"}, ) + # Set the model name to the Azure OpenAI model deployment name + bp.openai_model = os.getenv("AZURE_DEEPSEEK_DEPLOYMENT") + @bp.after_app_serving async def shutdown_openai(): - await bp.ai_client.close() + await bp.openai_client.close() @bp.get("/") @@ -69,15 +60,20 @@ async def chat_handler(): async def response_stream(): # This sends all messages, so API request may exceed token limits all_messages = [ - SystemMessage(content="You are a helpful assistant."), + {"role": "system", "content": "You are a helpful assistant."}, ] + request_messages - client: ChatCompletionsClient = bp.ai_client - result = await client.complete(messages=all_messages, max_tokens=2048, stream=True) + bp.openai_client.api_key = await bp.openai_token_provider() + chat_coroutine = bp.openai_client.chat.completions.create( + # Azure Open AI takes the deployment name as the model name + model=bp.openai_model, + messages=all_messages, + stream=True, + ) try: is_thinking = False - async for update in result: + async for update in await chat_coroutine: if update.choices: content = update.choices[0].delta.content if content == "": @@ -103,4 +99,4 @@ async def response_stream(): current_app.logger.error(e) yield json.dumps({"error": str(e)}, ensure_ascii=False) + "\n" - return Response(response_stream(), mimetype="application/json") + return Response(response_stream()) diff --git a/src/requirements.txt b/src/requirements.txt index c358a96..3421755 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --output-file=requirements.txt pyproject.toml @@ -12,16 +12,17 @@ aiohttp==3.10.11 # via quartapp (pyproject.toml) aiosignal==1.3.1 # via aiohttp +annotated-types==0.7.0 + # via pydantic anyio==4.6.0 - # via watchfiles + # via + # httpx + # openai + # watchfiles attrs==24.2.0 # via aiohttp -azure-ai-inference==1.0.0b8 - # via quartapp (pyproject.toml) azure-core==1.31.0 - # via - # azure-ai-inference - # azure-identity + # via azure-identity azure-identity==1.19.0 # via quartapp (pyproject.toml) blinker==1.8.2 @@ -29,7 +30,10 @@ blinker==1.8.2 # flask # quart certifi==2024.8.30 - # via requests + # via + # httpcore + # httpx + # requests cffi==1.17.1 # via cryptography charset-normalizer==3.4.0 @@ -44,6 +48,8 @@ cryptography==44.0.1 # azure-identity # msal # pyjwt +distro==1.9.0 + # via openai flask==3.0.3 # via quart frozenlist==1.4.1 @@ -54,6 +60,7 @@ gunicorn==23.0.0 # via quartapp (pyproject.toml) h11==0.14.0 # via + # httpcore # hypercorn # uvicorn # wsproto @@ -61,8 +68,12 @@ h2==4.1.0 # via hypercorn hpack==4.0.0 # via h2 +httpcore==1.0.7 + # via httpx httptools==0.6.4 # via quartapp (pyproject.toml) +httpx==0.28.1 + # via openai hypercorn==0.17.3 # via quart hyperframe==6.0.1 @@ -70,10 +81,9 @@ hyperframe==6.0.1 idna==3.10 # via # anyio + # httpx # requests # yarl -isodate==0.7.2 - # via azure-ai-inference itsdangerous==2.2.0 # via # flask @@ -82,6 +92,8 @@ jinja2==3.1.5 # via # flask # quart +jiter==0.9.0 + # via openai markupsafe==3.0.1 # via # jinja2 @@ -97,6 +109,8 @@ multidict==6.1.0 # via # aiohttp # yarl +openai==1.66.2 + # via quartapp (pyproject.toml) packaging==24.1 # via gunicorn portalocker==2.10.1 @@ -107,8 +121,14 @@ propcache==0.2.0 # via yarl pycparser==2.22 # via cffi +pydantic==2.10.6 + # via openai +pydantic-core==2.27.2 + # via pydantic pyjwt[crypto]==2.9.0 - # via msal + # via + # msal + # pyjwt python-dotenv==1.0.1 # via quartapp (pyproject.toml) pyyaml==6.0.2 @@ -122,12 +142,18 @@ requests==2.32.3 six==1.16.0 # via azure-core sniffio==1.3.1 - # via anyio + # via + # anyio + # openai +tqdm==4.67.1 + # via openai typing-extensions==4.12.2 # via - # azure-ai-inference # azure-core # azure-identity + # openai + # pydantic + # pydantic-core urllib3==2.2.3 # via requests uvicorn==0.32.0 diff --git a/tests/conftest.py b/tests/conftest.py index c373604..c5cbac0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -import azure.ai.inference.models +import openai import pytest import pytest_asyncio @@ -13,54 +13,70 @@ class AsyncChatCompletionIterator: def __init__(self, answer: str): self.chunk_index = 0 self.chunks = [ - azure.ai.inference.models.StreamingChatCompletionsUpdate( - id="test-123", - created=1703462735, - model="DeepSeek-R1", - choices=[ - azure.ai.inference.models.StreamingChatChoiceUpdate( - delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate( - content=None, role="assistant" - ), - index=0, - finish_reason=None, - ) + openai.types.chat.ChatCompletionChunk( + object="chat.completion.chunk", + choices=[], + id="", + created=0, + model="", + prompt_filter_results=[ + { + "prompt_index": 0, + "content_filter_results": { + "hate": {"filtered": False, "severity": "safe"}, + "self_harm": {"filtered": False, "severity": "safe"}, + "sexual": {"filtered": False, "severity": "safe"}, + "violence": {"filtered": False, "severity": "safe"}, + }, + } ], - ), + ) ] answer_deltas = answer.split(" ") for answer_index, answer_delta in enumerate(answer_deltas): - # Completion chunks include whitespace, so we need to add it back in - if answer_index > 0: + # Text completion chunks include whitespace, so we need to add it back in + if answer_index > 0 and answer_delta != "": answer_delta = " " + answer_delta self.chunks.append( - azure.ai.inference.models.StreamingChatCompletionsUpdate( + openai.types.chat.ChatCompletionChunk( id="test-123", - created=1703462735, - model="DeepSeek-R1", + object="chat.completion.chunk", choices=[ - azure.ai.inference.models.StreamingChatChoiceUpdate( - delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate( - content=answer_delta, role=None + openai.types.chat.chat_completion_chunk.Choice( + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta( + role=None, content=answer_delta ), - index=0, finish_reason=None, + index=0, + logprobs=None, + # Only Azure includes content_filter_results + content_filter_results={ + "hate": {"filtered": False, "severity": "safe"}, + "self_harm": {"filtered": False, "severity": "safe"}, + "sexual": {"filtered": False, "severity": "safe"}, + "violence": {"filtered": False, "severity": "safe"}, + }, ) ], + created=1703462735, + model="DeepSeek-R1", ) ) self.chunks.append( - azure.ai.inference.models.StreamingChatCompletionsUpdate( + openai.types.chat.ChatCompletionChunk( id="test-123", - created=1703462735, - model="DeepSeek-R1", + object="chat.completion.chunk", choices=[ - azure.ai.inference.models.StreamingChatChoiceUpdate( - delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(content=None, role=None), + openai.types.chat.chat_completion_chunk.Choice( + delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(content=None, role=None), index=0, finish_reason="stop", + # Only Azure includes content_filter_results + content_filter_results={}, ) ], + created=1703462735, + model="DeepSeek-R1", ) ) @@ -75,28 +91,29 @@ async def __anext__(self): else: raise StopAsyncIteration - async def mock_complete(*args, **kwargs): + async def mock_acreate(*args, **kwargs): # Only mock a stream=True completion last_message = kwargs.get("messages")[-1]["content"] if last_message == "What is the capital of France?": - return AsyncChatCompletionIterator("The capital of France is Paris.") + return AsyncChatCompletionIterator(" hmm The capital of France is Paris.") elif last_message == "What is the capital of Germany?": - return AsyncChatCompletionIterator("The capital of Germany is Berlin.") + return AsyncChatCompletionIterator(" hmm The capital of Germany is Berlin.") else: raise ValueError(f"Unexpected message: {last_message}") - monkeypatch.setattr("azure.ai.inference.aio.ChatCompletionsClient.complete", mock_complete) + monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate) @pytest.fixture def mock_defaultazurecredential(monkeypatch): - monkeypatch.setattr("azure.identity.aio.DefaultAzureCredential", mock_cred.MockAzureCredential) + monkeypatch.setattr("azure.identity.aio.AzureDeveloperCliCredential", mock_cred.MockAzureCredential) monkeypatch.setattr("azure.identity.aio.ManagedIdentityCredential", mock_cred.MockAzureCredential) @pytest_asyncio.fixture async def client(monkeypatch, mock_openai_chatcompletion, mock_defaultazurecredential): monkeypatch.setenv("AZURE_INFERENCE_ENDPOINT", "test-deepseek-service.ai.azure.com") + monkeypatch.setenv("AZURE_TENANT_ID", "test-tenant-id") quart_app = quartapp.create_app(testing=True) diff --git a/tests/mock_cred.py b/tests/mock_cred.py index d8ea4a3..53a0e58 100644 --- a/tests/mock_cred.py +++ b/tests/mock_cred.py @@ -1,5 +1,10 @@ +import azure.core.credentials import azure.core.credentials_async class MockAzureCredential(azure.core.credentials_async.AsyncTokenCredential): - pass + async def get_token(self, *scopes, **kwargs): + return azure.core.credentials.AccessToken( + token="mock_token", + expires_on=1703462735, + ) diff --git a/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines b/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines index b4f3f44..50730aa 100644 --- a/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines +++ b/tests/snapshots/test_app/test_chat_stream_text/result.jsonlines @@ -1,4 +1,5 @@ -{"delta": {"content": "The", "reasoning_content": null, "role": "assistant"}} +{"delta": {"content": null, "reasoning_content": " hmm", "role": "assistant"}} +{"delta": {"content": " The", "reasoning_content": null, "role": "assistant"}} {"delta": {"content": " capital", "reasoning_content": null, "role": "assistant"}} {"delta": {"content": " of", "reasoning_content": null, "role": "assistant"}} {"delta": {"content": " France", "reasoning_content": null, "role": "assistant"}} diff --git a/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines b/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines index ba8dd58..42d9a6c 100644 --- a/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines +++ b/tests/snapshots/test_app/test_chat_stream_text_history/result.jsonlines @@ -1,4 +1,5 @@ -{"delta": {"content": "The", "reasoning_content": null, "role": "assistant"}} +{"delta": {"content": null, "reasoning_content": " hmm", "role": "assistant"}} +{"delta": {"content": " The", "reasoning_content": null, "role": "assistant"}} {"delta": {"content": " capital", "reasoning_content": null, "role": "assistant"}} {"delta": {"content": " of", "reasoning_content": null, "role": "assistant"}} {"delta": {"content": " Germany", "reasoning_content": null, "role": "assistant"}} diff --git a/tests/test_app.py b/tests/test_app.py index a0e7259..abe6070 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -49,11 +49,12 @@ async def test_chat_stream_text_history(client, snapshot): async def test_openai_managedidentity(monkeypatch): monkeypatch.setenv("AZURE_CLIENT_ID", "test-client-id") monkeypatch.setenv("AZURE_INFERENCE_ENDPOINT", "test-deepseek-service.ai.azure.com") + monkeypatch.setenv("RUNNING_IN_PRODUCTION", "true") monkeypatch.setattr("azure.identity.aio.ManagedIdentityCredential", mock_cred.MockAzureCredential) quart_app = quartapp.create_app(testing=True) async with quart_app.test_app(): - assert not isinstance(quart_app.blueprints["chat"].ai_client._config.credential, AzureKeyCredential) - assert isinstance(quart_app.blueprints["chat"].ai_client._config.credential, AsyncTokenCredential) + assert not isinstance(quart_app.blueprints["chat"].azure_credential, AzureKeyCredential) + assert isinstance(quart_app.blueprints["chat"].azure_credential, AsyncTokenCredential)