5757from websockets .asyncio .client import ClientConnection
5858
5959from agents .handoffs import Handoff
60+ from agents .realtime ._default_tracker import ModelAudioTracker
6061from agents .tool import FunctionTool , Tool
6162from agents .util ._types import MaybeAwaitable
6263
7273 RealtimeModel ,
7374 RealtimeModelConfig ,
7475 RealtimeModelListener ,
76+ RealtimePlaybackState ,
77+ RealtimePlaybackTracker ,
7578)
7679from .model_events import (
7780 RealtimeModelAudioDoneEvent ,
@@ -133,11 +136,10 @@ def __init__(self) -> None:
133136 self ._websocket_task : asyncio .Task [None ] | None = None
134137 self ._listeners : list [RealtimeModelListener ] = []
135138 self ._current_item_id : str | None = None
136- self ._audio_start_time : datetime | None = None
137- self ._audio_length_ms : float = 0.0
139+ self ._audio_state_tracker : ModelAudioTracker = ModelAudioTracker ()
138140 self ._ongoing_response : bool = False
139- self ._current_audio_content_index : int | None = None
140141 self ._tracing_config : RealtimeModelTracingConfig | Literal ["auto" ] | None = None
142+ self ._playback_tracker : RealtimePlaybackTracker | None = None
141143
142144 async def connect (self , options : RealtimeModelConfig ) -> None :
143145 """Establish a connection to the model and keep it alive."""
@@ -146,6 +148,8 @@ async def connect(self, options: RealtimeModelConfig) -> None:
146148
147149 model_settings : RealtimeSessionModelSettings = options .get ("initial_model_settings" , {})
148150
151+ self ._playback_tracker = options .get ("playback_tracker" , RealtimePlaybackTracker ())
152+
149153 self .model = model_settings .get ("model_name" , self .model )
150154 api_key = await get_api_key (options .get ("api_key" ))
151155
@@ -294,47 +298,78 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
294298 if event .start_response :
295299 await self ._send_raw_message (OpenAIResponseCreateEvent (type = "response.create" ))
296300
301+ def _get_playback_state (self ) -> RealtimePlaybackState :
302+ if self ._playback_tracker :
303+ return self ._playback_tracker .get_state ()
304+
305+ if last_audio_item_id := self ._audio_state_tracker .get_last_audio_item ():
306+ item_id , item_content_index = last_audio_item_id
307+ audio_state = self ._audio_state_tracker .get_state (item_id , item_content_index )
308+ if audio_state :
309+ elapsed_ms = (
310+ datetime .now () - audio_state .initial_received_time
311+ ).total_seconds () * 1000
312+ return {
313+ "current_item_id" : item_id ,
314+ "current_item_content_index" : item_content_index ,
315+ "elapsed_ms" : elapsed_ms ,
316+ }
317+
318+ return {
319+ "current_item_id" : None ,
320+ "current_item_content_index" : None ,
321+ "elapsed_ms" : None ,
322+ }
323+
297324 async def _send_interrupt (self , event : RealtimeModelSendInterrupt ) -> None :
298- if not self ._current_item_id or not self ._audio_start_time :
325+ playback_state = self ._get_playback_state ()
326+ print (f"zzz playback_state: { playback_state } " )
327+ current_item_id = playback_state .get ("current_item_id" )
328+ current_item_content_index = playback_state .get ("current_item_content_index" )
329+ elapsed_ms = playback_state .get ("elapsed_ms" )
330+ if current_item_id is None or elapsed_ms is None :
331+ print ("zzz skipping interrupt" )
332+ logger .info (
333+ "Skipping interrupt. "
334+ f"Item id: { current_item_id } , "
335+ f"elapsed ms: { elapsed_ms } , "
336+ f"content index: { current_item_content_index } "
337+ )
299338 return
300339
301- await self ._cancel_response ()
302-
303- elapsed_time_ms = (datetime .now () - self ._audio_start_time ).total_seconds () * 1000
304- if elapsed_time_ms > 0 and elapsed_time_ms < self ._audio_length_ms :
340+ current_item_content_index = current_item_content_index or 0
341+ if elapsed_ms > 0 :
342+ print ("zzz sending interrupt" )
305343 await self ._emit_event (
306344 RealtimeModelAudioInterruptedEvent (
307- item_id = self . _current_item_id ,
308- content_index = self . _current_audio_content_index or 0 ,
345+ item_id = current_item_id ,
346+ content_index = current_item_content_index ,
309347 )
310348 )
311349 converted = _ConversionHelper .convert_interrupt (
312- self . _current_item_id ,
313- self . _current_audio_content_index or 0 ,
314- int (elapsed_time_ms ),
350+ current_item_id ,
351+ current_item_content_index ,
352+ int (elapsed_ms ),
315353 )
316354 await self ._send_raw_message (converted )
355+ await self ._cancel_response ()
317356
318- self ._current_item_id = None
319- self ._audio_start_time = None
320- self ._audio_length_ms = 0.0
321- self ._current_audio_content_index = None
357+ self ._audio_state_tracker .on_interrupted ()
358+ if self ._playback_tracker :
359+ self ._playback_tracker .on_interrupted ()
322360
323361 async def _send_session_update (self , event : RealtimeModelSendSessionUpdate ) -> None :
324362 """Send a session update to the model."""
325363 await self ._update_session_config (event .session_settings )
326364
327365 async def _handle_audio_delta (self , parsed : ResponseAudioDeltaEvent ) -> None :
328366 """Handle audio delta events and update audio tracking state."""
329- self ._current_audio_content_index = parsed .content_index
330367 self ._current_item_id = parsed .item_id
331- if self ._audio_start_time is None :
332- self ._audio_start_time = datetime .now ()
333- self ._audio_length_ms = 0.0
334368
335369 audio_bytes = base64 .b64decode (parsed .delta )
336- # Calculate audio length in ms using 24KHz pcm16le
337- self ._audio_length_ms += self ._calculate_audio_length_ms (audio_bytes )
370+
371+ self ._audio_state_tracker .on_audio_delta (parsed .item_id , parsed .content_index , audio_bytes )
372+
338373 await self ._emit_event (
339374 RealtimeModelAudioEvent (
340375 data = audio_bytes ,
@@ -344,10 +379,6 @@ async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
344379 )
345380 )
346381
347- def _calculate_audio_length_ms (self , audio_bytes : bytes ) -> float :
348- """Calculate audio length in milliseconds for 24KHz PCM16LE format."""
349- return len (audio_bytes ) / 24 / 2
350-
351382 async def _handle_output_item (self , item : ConversationItem ) -> None :
352383 """Handle response output item events (function calls and messages)."""
353384 if item .type == "function_call" and item .status == "completed" :
0 commit comments