diff --git a/samples/middle-tier/python-fastapi/rt-middle-tier/main.py b/samples/middle-tier/python-fastapi/rt-middle-tier/main.py index b1db119..4be2467 100644 --- a/samples/middle-tier/python-fastapi/rt-middle-tier/main.py +++ b/samples/middle-tier/python-fastapi/rt-middle-tier/main.py @@ -1,23 +1,26 @@ +import asyncio +import json +import os +import uuid +from typing import Literal, TypedDict, Union + +import uvicorn +from azure.core.credentials import AzureKeyCredential +from azure.identity import DefaultAzureCredential +from dotenv import load_dotenv from fastapi import FastAPI, WebSocket from fastapi.middleware.cors import CORSMiddleware from fastapi.websockets import WebSocketState -import uvicorn -import uuid -import json -from typing import Union, Literal, TypedDict -import asyncio from loguru import logger -import os -from dotenv import load_dotenv -from azure.identity import DefaultAzureCredential -from azure.core.credentials import AzureKeyCredential from rtclient import ( InputAudioTranscription, + InputTextContentPart, + RTAudioContent, RTClient, - ServerVAD, RTInputAudioItem, RTResponse, - RTAudioContent, + ServerVAD, + UserMessageItem, ) load_dotenv() @@ -71,11 +74,20 @@ def _initialize_client(self, backend: str | None): self.logger.debug(f"Initializing RT client with backend: {backend}") if backend == "azure": - return RTClient( - url=os.getenv("AZURE_OPENAI_ENDPOINT"), - token_credential=DefaultAzureCredential(), - deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"), - ) + azure_openai_api_key = os.getenv("AZURE_OPENAI_API_KEY") + # If the Azure OpenAI API key is not provided, use the DefaultAzureCredential + if not azure_openai_api_key: + return RTClient( + url=os.getenv("AZURE_OPENAI_ENDPOINT"), + token_credential=DefaultAzureCredential(), + azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"), + ) + else: + return RTClient( + url=os.getenv("AZURE_OPENAI_ENDPOINT"), + key_credential=AzureKeyCredential(azure_openai_api_key), + azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"), + ) return RTClient( key_credential=AzureKeyCredential(os.getenv("OPENAI_API_KEY")), model=os.getenv("OPENAI_MODEL"), @@ -91,7 +103,7 @@ async def initialize(self): self.logger.debug("Configuring realtime session") await self.client.configure( modalities={"text", "audio"}, - voice="coral", + voice="alloy", input_audio_format="pcm16", input_audio_transcription=InputAudioTranscription(model="whisper-1"), turn_detection=ServerVAD(), @@ -120,13 +132,13 @@ async def handle_text_message(self, message: str): if parsed["type"] == "user_message": await self.client.send_item( - { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": parsed["text"]}], - } + UserMessageItem( + content=[InputTextContentPart(text=parsed["text"])], + ) ) + # Trigger the response generation and wait for the response await self.client.generate_response() + self.logger.debug("User message processed successfully") except Exception as error: self.logger.error(f"Failed to process user message: {error}")