Skip to content

Commit 6279741

Browse files
authored
make openai client open and close during api call instead of waiting until server start up and shut down (#199)
1 parent c17826b commit 6279741

File tree

3 files changed

+103
-102
lines changed

3 files changed

+103
-102
lines changed

azure.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
name: azd-get-started-with-ai-agents
66
metadata:
7-
template: azd-get-started-with-ai-agents@2.0.2
7+
template: azd-get-started-with-ai-agents@2.0.3
88
requiredVersions:
99
azd: ">=1.14.0"
1010

src/api/main.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ async def lifespan(app: fastapi.FastAPI):
3030
async with (
3131
DefaultAzureCredential() as credential,
3232
AIProjectClient(endpoint=proj_endpoint, credential=credential) as project_client,
33-
project_client.get_openai_client() as openai_client,
3433
):
3534
logger.info("Created AIProjectClient")
3635

@@ -73,7 +72,6 @@ async def lifespan(app: fastapi.FastAPI):
7372

7473
app.state.ai_project = project_client
7574
app.state.agent_version_obj = agent_version_obj
76-
app.state.openai_client = openai_client
7775
yield
7876

7977
except Exception as e:

src/api/routes.py

Lines changed: 102 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def cleanup_created_at_metadata(metadata: Mapping[str, str]) -> None:
8787
min_key = min(created_at_keys, key=metadata.get)
8888
del metadata[min_key]
8989

90-
def get_ai_project(request: Request) -> AIProjectClient:
90+
def get_project_client(request: Request) -> AIProjectClient:
9191
return request.app.state.ai_project
9292

9393
def get_agent_version_obj(request: Request) -> AgentVersionObject:
9494
return request.app.state.agent_version_obj
9595

9696
def get_openai_client(request: Request) -> AsyncOpenAI:
97-
return request.app.state.openai_client
97+
return get_project_client(request).get_openai_client()
9898

9999
def get_created_at_label(message_id: str) -> str:
100100
return f"{message_id}_created_at"
@@ -200,47 +200,48 @@ async def get_result(
200200
agent: AgentVersionObject,
201201
conversation: Conversation,
202202
user_message: str,
203-
openAI: AsyncOpenAI,
203+
project_client: AIProjectClient,
204204
carrier: Dict[str, str]
205205
) -> AsyncGenerator[str, None]:
206206
ctx = TraceContextTextMapPropagator().extract(carrier=carrier)
207207
with tracer.start_as_current_span('get_result', context=ctx):
208-
logger.info(f"get_result invoked for conversation={conversation.id}")
209-
input_created_at = datetime.now(timezone.utc).timestamp()
210-
try:
211-
response = await openAI.responses.create(
212-
conversation=conversation.id,
213-
input=user_message,
214-
extra_body={"agent": AgentReference(name=agent.name, version=agent.version).as_dict()},
215-
stream=True
216-
)
217-
logger.info("Successfully created stream; starting to process events")
218-
async for event in response:
219-
if event.type == "response.created":
220-
logger.info(f"Stream response created with ID: {event.response.id}")
221-
elif event.type == "response.output_text.delta":
222-
logger.info(f"Delta: {event.delta}")
223-
stream_data = {'content': event.delta, 'type': "message"}
224-
yield serialize_sse_event(stream_data)
225-
elif event.type == "response.output_item.done" and event.item.type == "message":
226-
stream_data = await get_message_and_annotations(event.item)
227-
stream_data['type'] = "completed_message"
228-
yield serialize_sse_event(stream_data)
229-
elif event.type == "response.completed":
230-
logger.info(f"Response completed with full message: {event.response.output_text}")
231-
232-
except Exception as e:
233-
logger.exception(f"Exception in get_result: {e}")
234-
error_data = {
235-
'content': str(e),
236-
'annotations': [],
237-
'type': "completed_message"
238-
}
239-
yield serialize_sse_event(error_data)
240-
finally:
241-
stream_data = {'type': "stream_end"}
242-
await save_user_message_created_at(openAI, conversation, input_created_at)
243-
yield serialize_sse_event(stream_data)
208+
async with project_client.get_openai_client() as openai_client:
209+
logger.info(f"get_result invoked for conversation={conversation.id}")
210+
input_created_at = datetime.now(timezone.utc).timestamp()
211+
try:
212+
response = await openai_client.responses.create(
213+
conversation=conversation.id,
214+
input=user_message,
215+
extra_body={"agent": AgentReference(name=agent.name, version=agent.version).as_dict()},
216+
stream=True
217+
)
218+
logger.info("Successfully created stream; starting to process events")
219+
async for event in response:
220+
if event.type == "response.created":
221+
logger.info(f"Stream response created with ID: {event.response.id}")
222+
elif event.type == "response.output_text.delta":
223+
logger.info(f"Delta: {event.delta}")
224+
stream_data = {'content': event.delta, 'type': "message"}
225+
yield serialize_sse_event(stream_data)
226+
elif event.type == "response.output_item.done" and event.item.type == "message":
227+
stream_data = await get_message_and_annotations(event.item)
228+
stream_data['type'] = "completed_message"
229+
yield serialize_sse_event(stream_data)
230+
elif event.type == "response.completed":
231+
logger.info(f"Response completed with full message: {event.response.output_text}")
232+
233+
except Exception as e:
234+
logger.exception(f"Exception in get_result: {e}")
235+
error_data = {
236+
'content': str(e),
237+
'annotations': [],
238+
'type': "completed_message"
239+
}
240+
yield serialize_sse_event(error_data)
241+
finally:
242+
stream_data = {'type': "stream_end"}
243+
await save_user_message_created_at(openai_client, conversation, input_created_at)
244+
yield serialize_sse_event(stream_data)
244245

245246

246247

@@ -252,36 +253,37 @@ async def history(
252253
_ = auth_dependency
253254
):
254255
with tracer.start_as_current_span("chat_history"):
255-
conversation_id = request.cookies.get('conversation_id')
256-
agent_id = request.cookies.get('agent_id')
256+
async with openai_client:
257+
conversation_id = request.cookies.get('conversation_id')
258+
agent_id = request.cookies.get('agent_id')
257259

258-
# Get or create conversation using the reusable function
259-
conversation = await get_or_create_conversation(
260-
openai_client, conversation_id, agent_id, agent.id
261-
)
262-
agent_id = agent.id
263-
# Create a new message from the user's input.
264-
try:
265-
content = []
266-
items = await openai_client.conversations.items.list(conversation_id=conversation.id, order="desc", limit=16)
267-
async for item in items:
268-
if item.type == "message":
269-
formatteded_message = await get_message_and_annotations(item)
270-
formatteded_message['role'] = item.role
271-
formatteded_message['created_at'] = conversation.metadata.get(get_created_at_label(item.id), "")
272-
content.append(formatteded_message)
273-
274-
275-
logger.info(f"List message, conversation ID: {conversation_id}")
276-
response = JSONResponse(content=content)
277-
278-
# Update cookies to persist the conversation IDs.
279-
response.set_cookie("conversation_id", conversation_id)
280-
response.set_cookie("agent_id", agent_id)
281-
return response
282-
except Exception as e:
283-
logger.error(f"Error listing message: {e}")
284-
raise HTTPException(status_code=500, detail=f"Error list message: {e}")
260+
# Get or create conversation using the reusable function
261+
conversation = await get_or_create_conversation(
262+
openai_client, conversation_id, agent_id, agent.id
263+
)
264+
agent_id = agent.id
265+
# Create a new message from the user's input.
266+
try:
267+
content = []
268+
items = await openai_client.conversations.items.list(conversation_id=conversation.id, order="desc", limit=16)
269+
async for item in items:
270+
if item.type == "message":
271+
formatteded_message = await get_message_and_annotations(item)
272+
formatteded_message['role'] = item.role
273+
formatteded_message['created_at'] = conversation.metadata.get(get_created_at_label(item.id), "")
274+
content.append(formatteded_message)
275+
276+
277+
logger.info(f"List message, conversation ID: {conversation_id}")
278+
response = JSONResponse(content=content)
279+
280+
# Update cookies to persist the conversation IDs.
281+
response.set_cookie("conversation_id", conversation_id)
282+
response.set_cookie("agent_id", agent_id)
283+
return response
284+
except Exception as e:
285+
logger.error(f"Error listing message: {e}")
286+
raise HTTPException(status_code=500, detail=f"Error list message: {e}")
285287

286288
@router.get("/agent")
287289
async def get_chat_agent(
@@ -292,7 +294,7 @@ async def get_chat_agent(
292294
@router.post("/chat")
293295
async def chat(
294296
request: Request,
295-
openai_client : AsyncOpenAI = Depends(get_openai_client),
297+
project_client: AIProjectClient = Depends(get_project_client),
296298
agent: AgentVersionObject = Depends(get_agent_version_obj),
297299

298300
_ = auth_dependency
@@ -301,40 +303,41 @@ async def chat(
301303
conversation_id = request.cookies.get('conversation_id')
302304
agent_id = request.cookies.get('agent_id')
303305

304-
with tracer.start_as_current_span("chat_request"):
305-
carrier = {}
306-
TraceContextTextMapPropagator().inject(carrier)
306+
carrier = {}
307+
TraceContextTextMapPropagator().inject(carrier)
307308

308-
# if the connection no longer exist or agent is changed, create a new one
309-
conversation = await get_or_create_conversation(
310-
openai_client, conversation_id, agent_id, agent.id
311-
)
312-
conversation_id = conversation.id
313-
agent_id = agent.id
309+
with tracer.start_as_current_span("chat_request"):
310+
async with project_client.get_openai_client() as openai_client:
311+
# if the connection no longer exist or agent is changed, create a new one
312+
conversation = await get_or_create_conversation(
313+
openai_client, conversation_id, agent_id, agent.id
314+
)
315+
conversation_id = conversation.id
316+
agent_id = agent.id
314317

315-
# Parse the JSON from the request.
316-
try:
317-
user_message = await request.json()
318-
except Exception as e:
319-
logger.error(f"Invalid JSON in request: {e}")
320-
raise HTTPException(status_code=400, detail=f"Invalid JSON in request: {e}")
321-
# Create a new message from the user's input.
322-
323-
# Set the Server-Sent Events (SSE) response headers.
324-
headers = {
325-
"Cache-Control": "no-cache",
326-
"Connection": "keep-alive",
327-
"Content-Type": "text/event-stream"
328-
}
329-
logger.info(f"Starting streaming response for conversation ID {conversation_id}")
318+
# Parse the JSON from the request.
319+
try:
320+
user_message = await request.json()
321+
except Exception as e:
322+
logger.error(f"Invalid JSON in request: {e}")
323+
raise HTTPException(status_code=400, detail=f"Invalid JSON in request: {e}")
324+
# Create a new message from the user's input.
325+
326+
# Set the Server-Sent Events (SSE) response headers.
327+
headers = {
328+
"Cache-Control": "no-cache",
329+
"Connection": "keep-alive",
330+
"Content-Type": "text/event-stream"
331+
}
332+
logger.info(f"Starting streaming response for conversation ID {conversation_id}")
330333

331-
# Create the streaming response using the generator.
332-
response = StreamingResponse(get_result(agent, conversation, user_message.get('message', ''), openai_client, carrier), headers=headers)
334+
# Create the streaming response using the generator.
335+
response = StreamingResponse(get_result(agent, conversation, user_message.get('message', ''), project_client, carrier), headers=headers)
333336

334-
# Update cookies to persist the conversation and agent IDs.
335-
response.set_cookie("conversation_id", conversation_id)
336-
response.set_cookie("agent_id", agent_id)
337-
return response
337+
# Update cookies to persist the conversation and agent IDs.
338+
response.set_cookie("conversation_id", conversation_id)
339+
response.set_cookie("agent_id", agent_id)
340+
return response
338341

339342
def read_file(path: str) -> str:
340343
with open(path, 'r') as file:

0 commit comments

Comments (0)