Skip to content

Commit 2f07eda

Browse files
committed
Port back to openAI SDK
1 parent a4a6682 commit 2f07eda

File tree

3 files changed

+79
-51
lines changed

3 files changed

+79
-51
lines changed

src/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ dependencies = [
1313
"httptools",
1414
# Used by uvicorn for reload functionality
1515
"watchfiles",
16-
"azure-ai-inference",
1716
"azure-identity",
17+
"openai",
1818
"aiohttp",
1919
"python-dotenv",
2020
"pyyaml"

src/quartapp/chat.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import json
22
import os
3+
import time
34

4-
from azure.ai.inference.aio import ChatCompletionsClient
5-
from azure.ai.inference.models import SystemMessage
6-
from azure.identity.aio import (
7-
AzureDeveloperCliCredential,
8-
ChainedTokenCredential,
9-
ManagedIdentityCredential,
10-
)
5+
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential
6+
from openai import AsyncOpenAI
117
from quart import (
128
Blueprint,
139
Response,
@@ -22,45 +18,47 @@
2218

2319
@bp.before_app_serving
2420
async def configure_openai():
25-
# Use ManagedIdentityCredential with the client_id for user-assigned managed identities
26-
user_assigned_managed_identity_credential = ManagedIdentityCredential(client_id=os.getenv("AZURE_CLIENT_ID"))
27-
28-
# Use AzureDeveloperCliCredential with the current tenant.
29-
azure_dev_cli_credential = AzureDeveloperCliCredential(tenant_id=os.getenv("AZURE_TENANT_ID"), process_timeout=60)
30-
31-
# Create a ChainedTokenCredential with ManagedIdentityCredential and AzureDeveloperCliCredential
32-
# - ManagedIdentityCredential is used for deployment on Azure Container Apps
33-
34-
# - AzureDeveloperCliCredential is used for local development
35-
# The order of the credentials is important, as the first valid token is used
36-
# For more information check out:
37-
38-
# https://learn.microsoft.com/azure/developer/python/sdk/authentication/credential-chains?tabs=ctc#chainedtokencredential-overview
39-
azure_credential = ChainedTokenCredential(user_assigned_managed_identity_credential, azure_dev_cli_credential)
40-
current_app.logger.info("Using Azure OpenAI with credential")
41-
42-
if not os.getenv("AZURE_INFERENCE_ENDPOINT"):
43-
raise ValueError("AZURE_INFERENCE_ENDPOINT is required for Azure OpenAI")
21+
if os.getenv("RUNNING_IN_PRODUCTION"):
22+
client_id = os.environ["AZURE_CLIENT_ID"]
23+
current_app.logger.info("Using Azure OpenAI with managed identity credential for client ID: %s", client_id)
24+
bp.azure_credential = ManagedIdentityCredential(client_id=client_id)
25+
else:
26+
tenant_id = os.environ["AZURE_TENANT_ID"]
27+
current_app.logger.info("Using Azure OpenAI with Azure Developer CLI credential for tenant ID: %s", tenant_id)
28+
bp.azure_credential = AzureDeveloperCliCredential(tenant_id=tenant_id)
29+
30+
# Get the token provider for Azure OpenAI based on the selected Azure credential
31+
bp.openai_token = await bp.azure_credential.get_token("https://cognitiveservices.azure.com/.default")
4432

4533
# Create the Asynchronous Azure OpenAI client
46-
bp.ai_client = ChatCompletionsClient(
47-
endpoint=os.environ["AZURE_INFERENCE_ENDPOINT"],
48-
credential=azure_credential,
49-
credential_scopes=["https://cognitiveservices.azure.com/.default"],
50-
model="DeepSeek-R1",
34+
bp.openai_client = AsyncOpenAI(
35+
base_url=os.environ["AZURE_INFERENCE_ENDPOINT"],
36+
api_key=bp.openai_token.token,
37+
default_query={"api-version": "2024-05-01-preview"},
5138
)
5239

40+
# Set the model name to the Azure OpenAI model deployment name
41+
bp.openai_model = os.getenv("AZURE_DEEPSEEK_DEPLOYMENT")
42+
5343

5444
@bp.after_app_serving
5545
async def shutdown_openai():
56-
await bp.ai_client.close()
46+
await bp.openai_client.close()
5747

5848

5949
@bp.get("/")
6050
async def index():
6151
return await render_template("index.html")
6252

6353

54+
@bp.before_request
async def maybe_refresh_token():
    """Refresh the cached Azure access token when it is about to expire.

    Registered via ``bp.before_request``. Compares the cached token's
    ``expires_on`` timestamp against the current time plus a 60-second
    safety margin; when the token is inside that margin, a new token is
    requested from ``bp.azure_credential`` and installed on the OpenAI client.
    """
    # Refresh slightly early (60 s margin) so a request never starts with a
    # token that expires mid-flight.
    if bp.openai_token.expires_on < (time.time() + 60):
        current_app.logger.info("Token is expired, refreshing token.")
        # Store the new token back on the blueprint: the check above reads
        # bp.openai_token.expires_on, so keeping the stale token there would
        # cause every subsequent request to fetch a fresh token again.
        bp.openai_token = await bp.azure_credential.get_token("https://cognitiveservices.azure.com/.default")
        bp.openai_client.api_key = bp.openai_token.token
60+
61+
6462
@bp.post("/chat/stream")
6563
async def chat_handler():
6664
request_messages = (await request.get_json())["messages"]
@@ -69,15 +67,19 @@ async def chat_handler():
6967
async def response_stream():
7068
# This sends all messages, so API request may exceed token limits
7169
all_messages = [
72-
SystemMessage(content="You are a helpful assistant."),
70+
{"role": "system", "content": "You are a helpful assistant."},
7371
] + request_messages
7472

75-
client: ChatCompletionsClient = bp.ai_client
76-
result = await client.complete(messages=all_messages, max_tokens=2048, stream=True)
73+
chat_coroutine = bp.openai_client.chat.completions.create(
74+
# Azure Open AI takes the deployment name as the model name
75+
model=bp.openai_model,
76+
messages=all_messages,
77+
stream=True,
78+
)
7779

7880
try:
7981
is_thinking = False
80-
async for update in result:
82+
async for update in await chat_coroutine:
8183
if update.choices:
8284
content = update.choices[0].delta.content
8385
if content == "<think>":
@@ -103,4 +105,4 @@ async def response_stream():
103105
current_app.logger.error(e)
104106
yield json.dumps({"error": str(e)}, ensure_ascii=False) + "\n"
105107

106-
return Response(response_stream(), mimetype="application/json")
108+
return Response(response_stream())

src/requirements.txt

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# This file is autogenerated by pip-compile with Python 3.11
2+
# This file is autogenerated by pip-compile with Python 3.12
33
# by the following command:
44
#
55
# pip-compile --output-file=requirements.txt pyproject.toml
@@ -12,24 +12,28 @@ aiohttp==3.10.11
1212
# via quartapp (pyproject.toml)
1313
aiosignal==1.3.1
1414
# via aiohttp
15+
annotated-types==0.7.0
16+
# via pydantic
1517
anyio==4.6.0
16-
# via watchfiles
18+
# via
19+
# httpx
20+
# openai
21+
# watchfiles
1722
attrs==24.2.0
1823
# via aiohttp
19-
azure-ai-inference==1.0.0b8
20-
# via quartapp (pyproject.toml)
2124
azure-core==1.31.0
22-
# via
23-
# azure-ai-inference
24-
# azure-identity
25+
# via azure-identity
2526
azure-identity==1.19.0
2627
# via quartapp (pyproject.toml)
2728
blinker==1.8.2
2829
# via
2930
# flask
3031
# quart
3132
certifi==2024.8.30
32-
# via requests
33+
# via
34+
# httpcore
35+
# httpx
36+
# requests
3337
cffi==1.17.1
3438
# via cryptography
3539
charset-normalizer==3.4.0
@@ -44,6 +48,8 @@ cryptography==44.0.1
4448
# azure-identity
4549
# msal
4650
# pyjwt
51+
distro==1.9.0
52+
# via openai
4753
flask==3.0.3
4854
# via quart
4955
frozenlist==1.4.1
@@ -54,26 +60,30 @@ gunicorn==23.0.0
5460
# via quartapp (pyproject.toml)
5561
h11==0.14.0
5662
# via
63+
# httpcore
5764
# hypercorn
5865
# uvicorn
5966
# wsproto
6067
h2==4.1.0
6168
# via hypercorn
6269
hpack==4.0.0
6370
# via h2
71+
httpcore==1.0.7
72+
# via httpx
6473
httptools==0.6.4
6574
# via quartapp (pyproject.toml)
75+
httpx==0.28.1
76+
# via openai
6677
hypercorn==0.17.3
6778
# via quart
6879
hyperframe==6.0.1
6980
# via h2
7081
idna==3.10
7182
# via
7283
# anyio
84+
# httpx
7385
# requests
7486
# yarl
75-
isodate==0.7.2
76-
# via azure-ai-inference
7787
itsdangerous==2.2.0
7888
# via
7989
# flask
@@ -82,6 +92,8 @@ jinja2==3.1.5
8292
# via
8393
# flask
8494
# quart
95+
jiter==0.9.0
96+
# via openai
8597
markupsafe==3.0.1
8698
# via
8799
# jinja2
@@ -97,6 +109,8 @@ multidict==6.1.0
97109
# via
98110
# aiohttp
99111
# yarl
112+
openai==1.66.2
113+
# via quartapp (pyproject.toml)
100114
packaging==24.1
101115
# via gunicorn
102116
portalocker==2.10.1
@@ -107,8 +121,14 @@ propcache==0.2.0
107121
# via yarl
108122
pycparser==2.22
109123
# via cffi
124+
pydantic==2.10.6
125+
# via openai
126+
pydantic-core==2.27.2
127+
# via pydantic
110128
pyjwt[crypto]==2.9.0
111-
# via msal
129+
# via
130+
# msal
131+
# pyjwt
112132
python-dotenv==1.0.1
113133
# via quartapp (pyproject.toml)
114134
pyyaml==6.0.2
@@ -122,12 +142,18 @@ requests==2.32.3
122142
six==1.16.0
123143
# via azure-core
124144
sniffio==1.3.1
125-
# via anyio
145+
# via
146+
# anyio
147+
# openai
148+
tqdm==4.67.1
149+
# via openai
126150
typing-extensions==4.12.2
127151
# via
128-
# azure-ai-inference
129152
# azure-core
130153
# azure-identity
154+
# openai
155+
# pydantic
156+
# pydantic-core
131157
urllib3==2.2.3
132158
# via requests
133159
uvicorn==0.32.0

0 commit comments

Comments
 (0)