Skip to content

Commit 8028dcc

Browse files
committed
Fix SpeechHandle._wait_for_generation hanging forever when the speech is interrupted before playout
1 parent ff7c9dd commit 8028dcc

File tree

2 files changed

+229
-1
lines changed

2 files changed

+229
-1
lines changed

livekit-agents/livekit/agents/voice/speech_handle.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,12 @@ async def _wait_for_generation(self, step_idx: int = -1) -> None:
192192
if not self._generations:
193193
raise RuntimeError("cannot use wait_for_generation: no active generation is running.")
194194

195-
await asyncio.shield(self._generations[step_idx])
195+
# Race against the interrupt future to avoid hanging when speech is interrupted
196+
# before the generation completes
197+
await asyncio.wait(
198+
[asyncio.shield(self._generations[step_idx]), self._interrupt_fut],
199+
return_when=asyncio.FIRST_COMPLETED,
200+
)
196201

197202
async def _wait_for_scheduled(self) -> None:
198203
await asyncio.shield(self._scheduled_fut)

tests/test_speech_handle.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
"""Tests for SpeechHandle
2+
3+
These tests verify that _wait_for_generation does not hang when a speech is
4+
interrupted before the generation completes.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import asyncio
10+
11+
import pytest
12+
13+
from livekit.agents.voice.speech_handle import SpeechHandle
14+
15+
16+
class TestSpeechHandleWaitForGeneration:
    """Regression tests for the ``SpeechHandle._wait_for_generation`` hang.

    Every wait is wrapped in ``asyncio.wait_for`` so that a regression shows
    up as an explicit test failure instead of a stuck test run.
    """

    @pytest.mark.asyncio
    async def test_wait_for_generation_returns_when_interrupted(self) -> None:
        """The wait returns promptly when the speech was interrupted beforehand.

        The generation is authorized but never marked done, so the only thing
        that can release the waiter is the interrupt.
        """
        speech = SpeechHandle.create()

        # Authorizing is what makes a generation available to wait on.
        speech._authorize_generation()

        # Interrupt while that generation is still unresolved.
        speech.interrupt()

        try:
            await asyncio.wait_for(speech._wait_for_generation(), timeout=1.0)
        except asyncio.TimeoutError:
            pytest.fail("_wait_for_generation hung after interrupt")

    @pytest.mark.asyncio
    async def test_wait_for_generation_returns_when_generation_done(self) -> None:
        """The wait returns once the generation is marked done (normal path)."""
        speech = SpeechHandle.create()
        speech._authorize_generation()

        async def mark_done_later() -> None:
            await asyncio.sleep(0.1)
            speech._mark_generation_done()

        # Keep a reference to the task: the event loop only holds weak
        # references, so a discarded task may be garbage-collected before it
        # ever runs to completion.
        marker = asyncio.create_task(mark_done_later())
        try:
            await asyncio.wait_for(speech._wait_for_generation(), timeout=2.0)
        except asyncio.TimeoutError:
            pytest.fail("_wait_for_generation did not return after generation was done")
        finally:
            # Do not let the helper outlive the test, whatever the outcome.
            marker.cancel()
            await asyncio.gather(marker, return_exceptions=True)

    @pytest.mark.asyncio
    async def test_wait_for_generation_interrupt_during_wait(self) -> None:
        """The wait returns when the interrupt arrives while already waiting."""
        speech = SpeechHandle.create()
        speech._authorize_generation()

        async def interrupt_later() -> None:
            await asyncio.sleep(0.1)
            speech.interrupt()

        # Keep a reference so the helper task cannot be garbage-collected
        # mid-flight (the event loop only holds weak references to tasks).
        interrupter = asyncio.create_task(interrupt_later())
        try:
            await asyncio.wait_for(speech._wait_for_generation(), timeout=2.0)
        except asyncio.TimeoutError:
            pytest.fail("_wait_for_generation hung - did not respond to interrupt")
        finally:
            # Do not let the helper outlive the test, whatever the outcome.
            interrupter.cancel()
            await asyncio.gather(interrupter, return_exceptions=True)

        assert speech.interrupted

    @pytest.mark.asyncio
    async def test_multiple_speeches_with_interrupts(self) -> None:
        """Several speeches are processed in order; an interrupted one must not block.

        Simulates the mainTask queue processing scenario.
        """
        speeches = [SpeechHandle.create() for _ in range(3)]

        # The middle speech is interrupted before it is ever authorized.
        speeches[1].interrupt()

        for speech in speeches:
            speech._authorize_generation()

            # Only non-interrupted speeches get their generation completed;
            # the interrupted one has to be released by the interrupt alone.
            if not speech.interrupted:
                speech._mark_generation_done()

            try:
                await asyncio.wait_for(speech._wait_for_generation(), timeout=1.0)
            except asyncio.TimeoutError:
                pytest.fail(
                    f"_wait_for_generation hung for speech {speech.id} "
                    f"(interrupted={speech.interrupted})"
                )

    @pytest.mark.asyncio
    async def test_wait_for_generation_raises_without_authorization(self) -> None:
        """Waiting without any authorized generation raises ``RuntimeError``."""
        speech = SpeechHandle.create()

        with pytest.raises(RuntimeError, match="no active generation is running"):
            await speech._wait_for_generation()

    @pytest.mark.asyncio
    async def test_scheduling_task_simulation(self) -> None:
        """Simulate the scheduling task flow that was hanging.

        This reproduces the scenario from agent_activity._scheduling_task: a
        queue of speeches is drained one at a time, and one of them was
        interrupted while it was still sitting in the queue.
        """
        speech1 = SpeechHandle.create()
        speech2 = SpeechHandle.create()
        speech3 = SpeechHandle.create()

        # speech2 is interrupted before the loop below reaches it.
        speech2.interrupt()

        speech_queue: list[tuple[int, int, SpeechHandle]] = [
            (5, 1, speech1),
            (5, 2, speech2),
            (5, 3, speech3),
        ]
        processed_speeches: list[str] = []

        async def scheduling_task() -> None:
            while speech_queue:
                _, _, speech = speech_queue.pop(0)

                if speech.done():
                    continue

                speech._authorize_generation()

                # Only non-interrupted speeches have their generation completed.
                if not speech.interrupted:
                    speech._mark_generation_done()

                # This is the await that used to hang for the interrupted speech.
                await speech._wait_for_generation()

                processed_speeches.append(speech.id)

        try:
            await asyncio.wait_for(scheduling_task(), timeout=2.0)
        except asyncio.TimeoutError:
            pytest.fail("scheduling_task simulation hung")

        # All speeches should have been processed without hanging.
        assert len(processed_speeches) == 3
170+
171+
172+
class TestSpeechHandleInterrupt:
    """Tests for SpeechHandle interrupt behavior."""

    @pytest.mark.asyncio
    async def test_interrupt_sets_interrupted_flag(self) -> None:
        """``interrupt()`` flips the ``interrupted`` property from False to True."""
        speech = SpeechHandle.create()

        assert not speech.interrupted
        speech.interrupt()
        assert speech.interrupted

    @pytest.mark.asyncio
    async def test_interrupt_disallowed_by_default(self) -> None:
        """A plain ``interrupt()`` raises when created with ``allow_interruptions=False``."""
        speech = SpeechHandle.create(allow_interruptions=False)

        with pytest.raises(RuntimeError, match="does not allow interruptions"):
            speech.interrupt()

    @pytest.mark.asyncio
    async def test_force_interrupt(self) -> None:
        """``force=True`` interrupts even when ``allow_interruptions=False``."""
        speech = SpeechHandle.create(allow_interruptions=False)

        speech.interrupt(force=True)
        assert speech.interrupted

    @pytest.mark.asyncio
    async def test_wait_if_not_interrupted(self) -> None:
        """``wait_if_not_interrupted`` returns on interrupt even if its awaitables never finish."""
        speech = SpeechHandle.create()

        # A future that nobody resolves: only the interrupt can end the wait.
        # loop.create_future() is the recommended way to build a Future.
        never_done: asyncio.Future[None] = asyncio.get_running_loop().create_future()

        async def interrupt_later() -> None:
            await asyncio.sleep(0.1)
            speech.interrupt()

        # Keep a reference so the helper task cannot be garbage-collected
        # mid-flight (the event loop only holds weak references to tasks).
        interrupter = asyncio.create_task(interrupt_later())
        try:
            await asyncio.wait_for(
                speech.wait_if_not_interrupted([never_done]),
                timeout=2.0,
            )
        except asyncio.TimeoutError:
            pytest.fail("wait_if_not_interrupted hung despite interrupt")
        finally:
            # Do not let the helper outlive the test, whatever the outcome.
            interrupter.cancel()
            await asyncio.gather(interrupter, return_exceptions=True)

        assert speech.interrupted

0 commit comments

Comments
 (0)