Skip to content

Commit 39dbf33

Browse files
committed
fix
1 parent 1dd7a5f commit 39dbf33

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

src/agents/realtime/openai_realtime.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
import os
88
from datetime import datetime
9-
from typing import Annotated, Any, Callable, Literal, Union, cast
9+
from typing import Annotated, Any, Callable, Literal, Optional, Union, cast
1010

1111
import pydantic
1212
import websockets
@@ -811,15 +811,21 @@ def _get_session_config(
811811
for value in [input_audio_format, input_audio_transcription, turn_detection]
812812
):
813813
input_audio_config = OpenAIRealtimeAudioInput(
814-
format=cast(Literal["pcm16", "g711_ulaw", "g711_alaw"] | None, input_audio_format),
814+
format=cast(
815+
Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]],
816+
input_audio_format,
817+
),
815818
transcription=cast(Any, input_audio_transcription),
816819
turn_detection=cast(Any, turn_detection),
817820
)
818821

819822
output_audio_config = None
820823
if any(value is not None for value in [output_audio_format, speed, voice]):
821824
output_audio_config = OpenAIRealtimeAudioOutput(
822-
format=cast(Literal["pcm16", "g711_ulaw", "g711_alaw"] | None, output_audio_format),
825+
format=cast(
826+
Optional[Literal["pcm16", "g711_ulaw", "g711_alaw"]],
827+
output_audio_format,
828+
),
823829
speed=speed,
824830
voice=voice,
825831
)
@@ -838,6 +844,7 @@ def _get_session_config(
838844
instructions=model_settings.get("instructions"),
839845
output_modalities=modalities,
840846
audio=audio_config,
847+
max_output_tokens=cast(Any, model_settings.get("max_output_tokens")),
841848
tool_choice=cast(Any, model_settings.get("tool_choice")),
842849
tools=cast(
843850
Any,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from typing import Any
5+
6+
import pytest
7+
8+
from agents.realtime.model_inputs import RealtimeModelSendSessionUpdate
9+
from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel
10+
11+
12+
class _DummyWS:
    """In-memory websocket double: captures every outbound frame in ``sent``."""

    def __init__(self) -> None:
        # Each payload handed to send() is recorded here for later inspection.
        self.sent: list[str] = []

    async def send(self, data: str) -> None:  # type: ignore[override]
        """Record *data* instead of transmitting it over a network."""
        self.sent.append(data)
18+
19+
20+
@pytest.mark.asyncio
async def test_session_update_flattens_audio_and_modalities() -> None:
    """The session.update sent over the wire must use the GA flattened schema."""
    model = OpenAIRealtimeWebSocketModel()
    # Swap in an in-memory websocket so send() needs no network connection.
    fake_ws = _DummyWS()
    model._websocket = fake_ws  # type: ignore[attr-defined]

    session_settings = {
        "model_name": "gpt-realtime",
        "modalities": ["text", "audio"],
        "input_audio_format": "pcm16",
        "input_audio_transcription": {"model": "gpt-4o-mini-transcribe"},
        "output_audio_format": "pcm16",
        "turn_detection": {"type": "semantic_vad", "threshold": 0.5},
        "voice": "ash",
        "speed": 1.0,
        "max_output_tokens": 2048,
    }

    await model.send_event(
        RealtimeModelSendSessionUpdate(session_settings=session_settings)
    )

    # Exactly the last frame is inspected; at least one must have gone out.
    assert fake_ws.sent, "no websocket messages were sent"
    message = json.loads(fake_ws.sent[-1])
    assert message["type"] == "session.update"
    session_body = message["session"]

    # GA expects flattened fields, not session.audio or session.type
    assert "audio" not in session_body
    assert "type" not in session_body
    # Modalities field is named 'modalities' in GA
    assert session_body.get("modalities") == ["text", "audio"]
    # Audio fields flattened
    assert session_body.get("input_audio_format") == "pcm16"
    assert session_body.get("output_audio_format") == "pcm16"
    assert isinstance(session_body.get("input_audio_transcription"), dict)
    assert isinstance(session_body.get("turn_detection"), dict)
    # Token field name normalized
    assert session_body.get("max_response_output_tokens") == 2048
59+
60+
61+
@pytest.mark.asyncio
async def test_no_auto_interrupt_on_vad_speech_started(monkeypatch: Any) -> None:
    """VAD speech_started events must not trigger an automatic interrupt."""
    model = OpenAIRealtimeWebSocketModel()

    # Accumulate any interrupt invocations instead of flipping a flag.
    interrupt_calls: list[Any] = []

    async def _record_interrupt(event: Any) -> None:
        interrupt_calls.append(event)

    # Prevent network use; _websocket only needed for other paths
    model._websocket = _DummyWS()  # type: ignore[attr-defined]
    monkeypatch.setattr(model, "_send_interrupt", _record_interrupt)

    # This event previously triggered an interrupt; now it should be ignored
    await model._handle_ws_event({"type": "input_audio_buffer.speech_started"})

    assert not interrupt_calls

0 commit comments

Comments
 (0)