Skip to content

Commit 08ab062

Browse files
authored
Merge pull request #10 from Azure-Samples/openaisdk
Port from azure-ai-inference to openAI package
2 parents e58ad7c + 1b9a1b7 commit 08ab062

File tree

8 files changed

+136
-89
lines changed

8 files changed

+136
-89
lines changed

src/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ dependencies = [
1313
"httptools",
1414
# Used by uvicorn for reload functionality
1515
"watchfiles",
16-
"azure-ai-inference",
1716
"azure-identity",
17+
"openai",
1818
"aiohttp",
1919
"python-dotenv",
2020
"pyyaml"

src/quartapp/chat.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
import json
22
import os
33

4-
from azure.ai.inference.aio import ChatCompletionsClient
5-
from azure.ai.inference.models import SystemMessage
6-
from azure.identity.aio import (
7-
AzureDeveloperCliCredential,
8-
ChainedTokenCredential,
9-
ManagedIdentityCredential,
10-
)
4+
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential, get_bearer_token_provider
5+
from openai import AsyncOpenAI
116
from quart import (
127
Blueprint,
138
Response,
@@ -22,38 +17,34 @@
2217

2318
@bp.before_app_serving
2419
async def configure_openai():
25-
# Use ManagedIdentityCredential with the client_id for user-assigned managed identities
26-
user_assigned_managed_identity_credential = ManagedIdentityCredential(client_id=os.getenv("AZURE_CLIENT_ID"))
27-
28-
# Use AzureDeveloperCliCredential with the current tenant.
29-
azure_dev_cli_credential = AzureDeveloperCliCredential(tenant_id=os.getenv("AZURE_TENANT_ID"), process_timeout=60)
30-
31-
# Create a ChainedTokenCredential with ManagedIdentityCredential and AzureDeveloperCliCredential
32-
# - ManagedIdentityCredential is used for deployment on Azure Container Apps
33-
34-
# - AzureDeveloperCliCredential is used for local development
35-
# The order of the credentials is important, as the first valid token is used
36-
# For more information check out:
37-
38-
# https://learn.microsoft.com/azure/developer/python/sdk/authentication/credential-chains?tabs=ctc#chainedtokencredential-overview
39-
azure_credential = ChainedTokenCredential(user_assigned_managed_identity_credential, azure_dev_cli_credential)
40-
current_app.logger.info("Using Azure OpenAI with credential")
41-
42-
if not os.getenv("AZURE_INFERENCE_ENDPOINT"):
43-
raise ValueError("AZURE_INFERENCE_ENDPOINT is required for Azure OpenAI")
20+
if os.getenv("RUNNING_IN_PRODUCTION"):
21+
client_id = os.environ["AZURE_CLIENT_ID"]
22+
current_app.logger.info("Using Azure OpenAI with managed identity credential for client ID: %s", client_id)
23+
bp.azure_credential = ManagedIdentityCredential(client_id=client_id)
24+
else:
25+
tenant_id = os.environ["AZURE_TENANT_ID"]
26+
current_app.logger.info("Using Azure OpenAI with Azure Developer CLI credential for tenant ID: %s", tenant_id)
27+
bp.azure_credential = AzureDeveloperCliCredential(tenant_id=tenant_id)
28+
29+
# Get the token provider for Azure OpenAI based on the selected Azure credential
30+
bp.openai_token_provider = get_bearer_token_provider(
31+
bp.azure_credential, "https://cognitiveservices.azure.com/.default"
32+
)
4433

4534
# Create the Asynchronous Azure OpenAI client
46-
bp.ai_client = ChatCompletionsClient(
47-
endpoint=os.environ["AZURE_INFERENCE_ENDPOINT"],
48-
credential=azure_credential,
49-
credential_scopes=["https://cognitiveservices.azure.com/.default"],
50-
model="DeepSeek-R1",
35+
bp.openai_client = AsyncOpenAI(
36+
base_url=os.environ["AZURE_INFERENCE_ENDPOINT"],
37+
api_key=await bp.openai_token_provider(),
38+
default_query={"api-version": "2024-05-01-preview"},
5139
)
5240

41+
# Set the model name to the Azure OpenAI model deployment name
42+
bp.openai_model = os.getenv("AZURE_DEEPSEEK_DEPLOYMENT")
43+
5344

5445
@bp.after_app_serving
5546
async def shutdown_openai():
56-
await bp.ai_client.close()
47+
await bp.openai_client.close()
5748

5849

5950
@bp.get("/")
@@ -69,15 +60,20 @@ async def chat_handler():
6960
async def response_stream():
7061
# This sends all messages, so API request may exceed token limits
7162
all_messages = [
72-
SystemMessage(content="You are a helpful assistant."),
63+
{"role": "system", "content": "You are a helpful assistant."},
7364
] + request_messages
7465

75-
client: ChatCompletionsClient = bp.ai_client
76-
result = await client.complete(messages=all_messages, max_tokens=2048, stream=True)
66+
bp.openai_client.api_key = await bp.openai_token_provider()
67+
chat_coroutine = bp.openai_client.chat.completions.create(
68+
# Azure Open AI takes the deployment name as the model name
69+
model=bp.openai_model,
70+
messages=all_messages,
71+
stream=True,
72+
)
7773

7874
try:
7975
is_thinking = False
80-
async for update in result:
76+
async for update in await chat_coroutine:
8177
if update.choices:
8278
content = update.choices[0].delta.content
8379
if content == "<think>":
@@ -103,4 +99,4 @@ async def response_stream():
10399
current_app.logger.error(e)
104100
yield json.dumps({"error": str(e)}, ensure_ascii=False) + "\n"
105101

106-
return Response(response_stream(), mimetype="application/json")
102+
return Response(response_stream())

src/requirements.txt

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# This file is autogenerated by pip-compile with Python 3.11
2+
# This file is autogenerated by pip-compile with Python 3.12
33
# by the following command:
44
#
55
# pip-compile --output-file=requirements.txt pyproject.toml
@@ -12,24 +12,28 @@ aiohttp==3.10.11
1212
# via quartapp (pyproject.toml)
1313
aiosignal==1.3.1
1414
# via aiohttp
15+
annotated-types==0.7.0
16+
# via pydantic
1517
anyio==4.6.0
16-
# via watchfiles
18+
# via
19+
# httpx
20+
# openai
21+
# watchfiles
1722
attrs==24.2.0
1823
# via aiohttp
19-
azure-ai-inference==1.0.0b8
20-
# via quartapp (pyproject.toml)
2124
azure-core==1.31.0
22-
# via
23-
# azure-ai-inference
24-
# azure-identity
25+
# via azure-identity
2526
azure-identity==1.19.0
2627
# via quartapp (pyproject.toml)
2728
blinker==1.8.2
2829
# via
2930
# flask
3031
# quart
3132
certifi==2024.8.30
32-
# via requests
33+
# via
34+
# httpcore
35+
# httpx
36+
# requests
3337
cffi==1.17.1
3438
# via cryptography
3539
charset-normalizer==3.4.0
@@ -44,6 +48,8 @@ cryptography==44.0.1
4448
# azure-identity
4549
# msal
4650
# pyjwt
51+
distro==1.9.0
52+
# via openai
4753
flask==3.0.3
4854
# via quart
4955
frozenlist==1.4.1
@@ -54,26 +60,30 @@ gunicorn==23.0.0
5460
# via quartapp (pyproject.toml)
5561
h11==0.14.0
5662
# via
63+
# httpcore
5764
# hypercorn
5865
# uvicorn
5966
# wsproto
6067
h2==4.1.0
6168
# via hypercorn
6269
hpack==4.0.0
6370
# via h2
71+
httpcore==1.0.7
72+
# via httpx
6473
httptools==0.6.4
6574
# via quartapp (pyproject.toml)
75+
httpx==0.28.1
76+
# via openai
6677
hypercorn==0.17.3
6778
# via quart
6879
hyperframe==6.0.1
6980
# via h2
7081
idna==3.10
7182
# via
7283
# anyio
84+
# httpx
7385
# requests
7486
# yarl
75-
isodate==0.7.2
76-
# via azure-ai-inference
7787
itsdangerous==2.2.0
7888
# via
7989
# flask
@@ -82,6 +92,8 @@ jinja2==3.1.6
8292
# via
8393
# flask
8494
# quart
95+
jiter==0.9.0
96+
# via openai
8597
markupsafe==3.0.1
8698
# via
8799
# jinja2
@@ -97,6 +109,8 @@ multidict==6.1.0
97109
# via
98110
# aiohttp
99111
# yarl
112+
openai==1.66.2
113+
# via quartapp (pyproject.toml)
100114
packaging==24.1
101115
# via gunicorn
102116
portalocker==2.10.1
@@ -107,8 +121,14 @@ propcache==0.2.0
107121
# via yarl
108122
pycparser==2.22
109123
# via cffi
124+
pydantic==2.10.6
125+
# via openai
126+
pydantic-core==2.27.2
127+
# via pydantic
110128
pyjwt[crypto]==2.9.0
111-
# via msal
129+
# via
130+
# msal
131+
# pyjwt
112132
python-dotenv==1.0.1
113133
# via quartapp (pyproject.toml)
114134
pyyaml==6.0.2
@@ -122,12 +142,18 @@ requests==2.32.3
122142
six==1.16.0
123143
# via azure-core
124144
sniffio==1.3.1
125-
# via anyio
145+
# via
146+
# anyio
147+
# openai
148+
tqdm==4.67.1
149+
# via openai
126150
typing-extensions==4.12.2
127151
# via
128-
# azure-ai-inference
129152
# azure-core
130153
# azure-identity
154+
# openai
155+
# pydantic
156+
# pydantic-core
131157
urllib3==2.2.3
132158
# via requests
133159
uvicorn==0.32.0

tests/conftest.py

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import azure.ai.inference.models
1+
import openai
22
import pytest
33
import pytest_asyncio
44

@@ -13,54 +13,70 @@ class AsyncChatCompletionIterator:
1313
def __init__(self, answer: str):
1414
self.chunk_index = 0
1515
self.chunks = [
16-
azure.ai.inference.models.StreamingChatCompletionsUpdate(
17-
id="test-123",
18-
created=1703462735,
19-
model="DeepSeek-R1",
20-
choices=[
21-
azure.ai.inference.models.StreamingChatChoiceUpdate(
22-
delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(
23-
content=None, role="assistant"
24-
),
25-
index=0,
26-
finish_reason=None,
27-
)
16+
openai.types.chat.ChatCompletionChunk(
17+
object="chat.completion.chunk",
18+
choices=[],
19+
id="",
20+
created=0,
21+
model="",
22+
prompt_filter_results=[
23+
{
24+
"prompt_index": 0,
25+
"content_filter_results": {
26+
"hate": {"filtered": False, "severity": "safe"},
27+
"self_harm": {"filtered": False, "severity": "safe"},
28+
"sexual": {"filtered": False, "severity": "safe"},
29+
"violence": {"filtered": False, "severity": "safe"},
30+
},
31+
}
2832
],
29-
),
33+
)
3034
]
3135
answer_deltas = answer.split(" ")
3236
for answer_index, answer_delta in enumerate(answer_deltas):
33-
# Completion chunks include whitespace, so we need to add it back in
34-
if answer_index > 0:
37+
# Text completion chunks include whitespace, so we need to add it back in
38+
if answer_index > 0 and answer_delta != "</think>":
3539
answer_delta = " " + answer_delta
3640
self.chunks.append(
37-
azure.ai.inference.models.StreamingChatCompletionsUpdate(
41+
openai.types.chat.ChatCompletionChunk(
3842
id="test-123",
39-
created=1703462735,
40-
model="DeepSeek-R1",
43+
object="chat.completion.chunk",
4144
choices=[
42-
azure.ai.inference.models.StreamingChatChoiceUpdate(
43-
delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(
44-
content=answer_delta, role=None
45+
openai.types.chat.chat_completion_chunk.Choice(
46+
delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
47+
role=None, content=answer_delta
4548
),
46-
index=0,
4749
finish_reason=None,
50+
index=0,
51+
logprobs=None,
52+
# Only Azure includes content_filter_results
53+
content_filter_results={
54+
"hate": {"filtered": False, "severity": "safe"},
55+
"self_harm": {"filtered": False, "severity": "safe"},
56+
"sexual": {"filtered": False, "severity": "safe"},
57+
"violence": {"filtered": False, "severity": "safe"},
58+
},
4859
)
4960
],
61+
created=1703462735,
62+
model="DeepSeek-R1",
5063
)
5164
)
5265
self.chunks.append(
53-
azure.ai.inference.models.StreamingChatCompletionsUpdate(
66+
openai.types.chat.ChatCompletionChunk(
5467
id="test-123",
55-
created=1703462735,
56-
model="DeepSeek-R1",
68+
object="chat.completion.chunk",
5769
choices=[
58-
azure.ai.inference.models.StreamingChatChoiceUpdate(
59-
delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(content=None, role=None),
70+
openai.types.chat.chat_completion_chunk.Choice(
71+
delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(content=None, role=None),
6072
index=0,
6173
finish_reason="stop",
74+
# Only Azure includes content_filter_results
75+
content_filter_results={},
6276
)
6377
],
78+
created=1703462735,
79+
model="DeepSeek-R1",
6480
)
6581
)
6682

@@ -75,28 +91,29 @@ async def __anext__(self):
7591
else:
7692
raise StopAsyncIteration
7793

78-
async def mock_complete(*args, **kwargs):
94+
async def mock_acreate(*args, **kwargs):
7995
# Only mock a stream=True completion
8096
last_message = kwargs.get("messages")[-1]["content"]
8197
if last_message == "What is the capital of France?":
82-
return AsyncChatCompletionIterator("The capital of France is Paris.")
98+
return AsyncChatCompletionIterator("<think> hmm </think> The capital of France is Paris.")
8399
elif last_message == "What is the capital of Germany?":
84-
return AsyncChatCompletionIterator("The capital of Germany is Berlin.")
100+
return AsyncChatCompletionIterator("<think> hmm </think> The capital of Germany is Berlin.")
85101
else:
86102
raise ValueError(f"Unexpected message: {last_message}")
87103

88-
monkeypatch.setattr("azure.ai.inference.aio.ChatCompletionsClient.complete", mock_complete)
104+
monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate)
89105

90106

91107
@pytest.fixture
92108
def mock_defaultazurecredential(monkeypatch):
93-
monkeypatch.setattr("azure.identity.aio.DefaultAzureCredential", mock_cred.MockAzureCredential)
109+
monkeypatch.setattr("azure.identity.aio.AzureDeveloperCliCredential", mock_cred.MockAzureCredential)
94110
monkeypatch.setattr("azure.identity.aio.ManagedIdentityCredential", mock_cred.MockAzureCredential)
95111

96112

97113
@pytest_asyncio.fixture
98114
async def client(monkeypatch, mock_openai_chatcompletion, mock_defaultazurecredential):
99115
monkeypatch.setenv("AZURE_INFERENCE_ENDPOINT", "test-deepseek-service.ai.azure.com")
116+
monkeypatch.setenv("AZURE_TENANT_ID", "test-tenant-id")
100117

101118
quart_app = quartapp.create_app(testing=True)
102119

0 commit comments

Comments
 (0)