1- import json
21import asyncio
32import base64
3+ import json
4+ import logging
45import os
56
67from dotenv import load_dotenv
7-
8- from google .genai .types import (
9- Part ,
10- Content ,
11- Blob ,
12- )
13-
14- from google .adk .runners import InMemoryRunner
8+ from example_agent .agent import root_agent
9+ from fastapi import FastAPI , WebSocket
1510from google .adk .agents import LiveRequestQueue
1611from google .adk .agents .run_config import RunConfig
12+ from google .adk .runners import InMemoryRunner
1713from google .genai import types
18-
19- from fastapi import FastAPI , WebSocket
20-
21-
22- import logging
14+ from google . genai . types import (
15+ Blob ,
16+ Content ,
17+ Part ,
18+ )
2319from starlette .websockets import WebSocketDisconnect
2420
25- from example_agent .agent import root_agent
26-
2721load_dotenv ()
2822
23+
2924async def start_agent_session (user_id : str ):
3025 """Starts an agent session"""
3126
3227 # Create a Runner
33- runner = InMemoryRunner (
34- app_name = os .getenv ("APP_NAME" ),
35- agent = root_agent
36- )
28+ runner = InMemoryRunner (app_name = os .getenv ("APP_NAME" ), agent = root_agent )
3729
3830 # Create a Session
3931 session = await runner .session_service .create_session (
@@ -44,7 +36,7 @@ async def start_agent_session(user_id: str):
4436 # Create a LiveRequestQueue for this session
4537 live_request_queue = LiveRequestQueue ()
4638
47- # Setup RunConfig
39+ # Setup RunConfig
4840 run_config = RunConfig (
4941 streaming_mode = "bidi" ,
5042 session_resumption = types .SessionResumptionConfig (transparent = True ),
@@ -56,17 +48,17 @@ async def start_agent_session(user_id: str):
5648 silence_duration_ms = 0 ,
5749 )
5850 ),
59- response_modalities = ["AUDIO" ],
51+ response_modalities = ["AUDIO" ],
6052 speech_config = types .SpeechConfig (
6153 voice_config = types .VoiceConfig (
6254 prebuilt_voice_config = types .PrebuiltVoiceConfig (
6355 voice_name = os .getenv ("AGENT_VOICE" )
6456 )
6557 ),
66- language_code = os .getenv ("AGENT_LANGUAGE" )
58+ language_code = os .getenv ("AGENT_LANGUAGE" ),
6759 ),
68- output_audio_transcription = {},
69- input_audio_transcription = {},
60+ output_audio_transcription = {},
61+ input_audio_transcription = {},
7062 )
7163
7264 # Start agent session
@@ -89,67 +81,84 @@ async def agent_to_client_messaging(websocket: WebSocket, live_events):
8981 "interrupted" : event .interrupted or False ,
9082 "parts" : [],
9183 "input_transcription" : None ,
92- "output_transcription" : None
84+ "output_transcription" : None ,
9385 }
9486
9587 if not event .content :
96- if ( message_to_send ["turn_complete" ] or message_to_send ["interrupted" ]) :
88+ if message_to_send ["turn_complete" ] or message_to_send ["interrupted" ]:
9789 await websocket .send_text (json .dumps (message_to_send ))
98- continue
90+ continue
91+
92+ transcription_text = "" .join (
93+ part .text for part in event .content .parts if part .text
94+ )
9995
100- transcription_text = "" .join (part .text for part in event .content .parts if part .text )
101-
10296 if hasattr (event .content , "role" ) and event .content .role == "user" :
10397 if transcription_text :
10498 message_to_send ["input_transcription" ] = {
10599 "text" : transcription_text ,
106- "is_final" : not event .partial
100+ "is_final" : not event .partial ,
107101 }
108-
102+
109103 elif hasattr (event .content , "role" ) and event .content .role == "model" :
110104 if transcription_text :
111105 message_to_send ["output_transcription" ] = {
112106 "text" : transcription_text ,
113- "is_final" : not event .partial
107+ "is_final" : not event .partial ,
114108 }
115- message_to_send ["parts" ].append ({"type" : "text" , "data" : transcription_text })
109+ message_to_send ["parts" ].append (
110+ {"type" : "text" , "data" : transcription_text }
111+ )
116112
117113 for part in event .content .parts :
118- if part .inline_data and part .inline_data .mime_type .startswith ("audio/pcm" ):
114+ if part .inline_data and part .inline_data .mime_type .startswith (
115+ "audio/pcm"
116+ ):
119117 audio_data = part .inline_data .data
120118 encoded_audio = base64 .b64encode (audio_data ).decode ("ascii" )
121- message_to_send ["parts" ].append ({"type" : "audio/pcm" , "data" : encoded_audio })
122-
119+ message_to_send ["parts" ].append (
120+ {"type" : "audio/pcm" , "data" : encoded_audio }
121+ )
122+
123123 elif part .function_call :
124- message_to_send ["parts" ].append ({
125- "type" : "function_call" ,
126- "data" : {
127- "name" : part .function_call .name ,
128- "args" : part .function_call .args or {}
124+ message_to_send ["parts" ].append (
125+ {
126+ "type" : "function_call" ,
127+ "data" : {
128+ "name" : part .function_call .name ,
129+ "args" : part .function_call .args or {},
130+ },
129131 }
130- } )
131-
132+ )
133+
132134 elif part .function_response :
133- message_to_send ["parts" ].append ({
134- "type" : "function_response" ,
135- "data" : {
136- "name" : part .function_response .name ,
137- "response" : part .function_response .response or {}
135+ message_to_send ["parts" ].append (
136+ {
137+ "type" : "function_response" ,
138+ "data" : {
139+ "name" : part .function_response .name ,
140+ "response" : part .function_response .response or {},
141+ },
138142 }
139- })
140-
141- if (message_to_send ["parts" ] or
142- message_to_send ["turn_complete" ] or
143- message_to_send ["interrupted" ] or
144- message_to_send ["input_transcription" ] or
145- message_to_send ["output_transcription" ]):
146-
143+ )
144+
145+ if (
146+ message_to_send ["parts" ]
147+ or message_to_send ["turn_complete" ]
148+ or message_to_send ["interrupted" ]
149+ or message_to_send ["input_transcription" ]
150+ or message_to_send ["output_transcription" ]
151+ ):
152+
147153 await websocket .send_text (json .dumps (message_to_send ))
148154
149155 except Exception as e :
150156 logging .error (f"Error in agent_to_client_messaging: { e } " )
151157
152- async def client_to_agent_messaging (websocket : WebSocket , live_request_queue : LiveRequestQueue ):
158+
159+ async def client_to_agent_messaging (
160+ websocket : WebSocket , live_request_queue : LiveRequestQueue
161+ ):
153162 """Client to agent communication"""
154163 while True :
155164 try :
@@ -165,13 +174,17 @@ async def client_to_agent_messaging(websocket: WebSocket, live_request_queue: Li
165174 elif mime_type == "audio/pcm" :
166175 data = message ["data" ]
167176 decoded_data = base64 .b64decode (data )
168- live_request_queue .send_realtime (Blob (data = decoded_data , mime_type = mime_type ))
177+ live_request_queue .send_realtime (
178+ Blob (data = decoded_data , mime_type = mime_type )
179+ )
169180
170181 elif mime_type == "image/jpeg" :
171182 data = message ["data" ]
172183 decoded_data = base64 .b64decode (data )
173- live_request_queue .send_realtime (Blob (data = decoded_data , mime_type = mime_type ))
174-
184+ live_request_queue .send_realtime (
185+ Blob (data = decoded_data , mime_type = mime_type )
186+ )
187+
175188 else :
176189 logging .warning (f"Mime type not supported: { mime_type } " )
177190
@@ -185,6 +198,7 @@ async def client_to_agent_messaging(websocket: WebSocket, live_request_queue: Li
185198
186199app = FastAPI ()
187200
201+
188202@app .websocket ("/ws/{user_id}" )
189203async def websocket_endpoint (websocket : WebSocket , user_id : str ):
190204 """Client websocket endpoint"""
@@ -211,4 +225,3 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str):
211225 # Close LiveRequestQueue
212226 live_request_queue .close ()
213227 print (f"Client #{ user_id } disconnected" )
214-
0 commit comments