from __future__ import annotations

import json
from typing import Any

import pytest

from agents.realtime.model_inputs import RealtimeModelSendSessionUpdate
from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel

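# Minimal stand-in for the websockets client: it records outgoing frames
# instead of writing to a socket, so the tests can inspect exactly what
# the model serialized.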
class _DummyWS:
    def __init__(self) -> None:
        self.sent: list[str] = []

    async def send(self, data: str) -> None:  # type: ignore[override]
        self.sent.append(data)

| 19 | + |
| 20 | +@pytest.mark.asyncio |
| 21 | +async def test_session_update_flattens_audio_and_modalities() -> None: |
| 22 | + model = OpenAIRealtimeWebSocketModel() |
| 23 | + # Inject a dummy websocket so send() works without a network |
| 24 | + dummy = _DummyWS() |
| 25 | + model._websocket = dummy # type: ignore[attr-defined] |
| 26 | + |
    settings = {
        "model_name": "gpt-realtime",
        "modalities": ["text", "audio"],
        "input_audio_format": "pcm16",
        "input_audio_transcription": {"model": "gpt-4o-mini-transcribe"},
        "output_audio_format": "pcm16",
        "turn_detection": {"type": "semantic_vad", "threshold": 0.5},
        "voice": "ash",
        "speed": 1.0,
        "max_output_tokens": 2048,
    }

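    # Sending the update should serialize these settings into a GA-format
    # "session.update" event, with the audio options flattened at the top level
    # (verified by the assertions below).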
    await model.send_event(RealtimeModelSendSessionUpdate(session_settings=settings))

    # A session.update event should have been sent; inspect the last frame
    assert dummy.sent, "no websocket messages were sent"
    payload = json.loads(dummy.sent[-1])
    assert payload["type"] == "session.update"
    session = payload["session"]

    # GA expects flattened fields, not a nested session.audio or a session.type
    assert "audio" not in session
    assert "type" not in session
    # The modalities field is named 'modalities' in GA
    assert session.get("modalities") == ["text", "audio"]
    # Audio fields are flattened to the top level of the session object
    assert session.get("input_audio_format") == "pcm16"
    assert session.get("output_audio_format") == "pcm16"
    assert isinstance(session.get("input_audio_transcription"), dict)
    assert isinstance(session.get("turn_detection"), dict)
    # The token-limit field name is normalized for GA
    assert session.get("max_response_output_tokens") == 2048


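# Regression guard: the handler used to respond to VAD "speech started" events
# by sending an interrupt; that automatic interrupt has been removed.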
@pytest.mark.asyncio
async def test_no_auto_interrupt_on_vad_speech_started(monkeypatch: pytest.MonkeyPatch) -> None:
    model = OpenAIRealtimeWebSocketModel()

    called = {"interrupt": False}

    async def _fake_interrupt(event: Any) -> None:
        called["interrupt"] = True

    # Prevent network use; the websocket is only touched on other code paths
    model._websocket = _DummyWS()  # type: ignore[attr-defined]
    monkeypatch.setattr(model, "_send_interrupt", _fake_interrupt)

    # This event previously triggered an automatic interrupt; it should now be ignored
    await model._handle_ws_event({"type": "input_audio_buffer.speech_started"})

    assert called["interrupt"] is False