Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,7 +1992,7 @@ def _on_first_frame(fut: asyncio.Future[float] | asyncio.Future[None]) -> None:
if self._audio_recognition:
self._audio_recognition.on_start_of_agent_speech(started_at=started_speaking_at)
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()

audio_out: _AudioOutput | None = None
tts_gen_data: _TTSGenerationData | None = None
Expand Down Expand Up @@ -2316,7 +2316,7 @@ def _on_first_frame(fut: asyncio.Future[float] | asyncio.Future[None]) -> None:
if self._audio_recognition:
self._audio_recognition.on_start_of_agent_speech(started_at=started_speaking_at)
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()

audio_out: _AudioOutput | None = None
if audio_output is not None:
Expand Down Expand Up @@ -2705,7 +2705,7 @@ def _on_first_frame(fut: asyncio.Future[float] | asyncio.Future[None]) -> None:
if self._audio_recognition:
self._audio_recognition.on_start_of_agent_speech(started_at=started_speaking_at)
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()

tasks: list[asyncio.Task[Any]] = []
tees: list[utils.aio.itertools.Tee[Any]] = []
Expand Down Expand Up @@ -3110,7 +3110,7 @@ def _on_false_interruption() -> None:
if self._audio_recognition:
self._audio_recognition.on_start_of_agent_speech(started_at=time.time())
if self.interruption_enabled:
self._interruption_by_audio_activity_enabled = False
self._disable_vad_interruption_soon()
audio_output.resume()
resumed = True
logger.debug("resumed false interrupted speech", extra={"timeout": timeout})
Expand Down Expand Up @@ -3164,7 +3164,24 @@ async def _cancel_speech_pause(
):
self._session.output.audio.resume()

def _disable_vad_interruption_soon(self) -> None:
"""disable VAD interruption after the interruption holdoff expires."""
if self._audio_recognition and self._audio_recognition.interruption_holdoff_active:

def _disable_vad() -> None:
# only disable it if the agent is still speaking
if self._session.agent_state == "speaking":
logger.trace("interruption holdoff expired")
self._interruption_by_audio_activity_enabled = False

self._audio_recognition.interruption_holdoff_cb = _disable_vad
else:
self._interruption_by_audio_activity_enabled = False

def _restore_interruption_by_audio_activity(self) -> None:
if self._audio_recognition:
self._audio_recognition.cancel_interruption_holdoff()

self._interruption_by_audio_activity_enabled = (
self._default_interruption_by_audio_activity_enabled
)
Expand Down
52 changes: 51 additions & 1 deletion livekit-agents/livekit/agents/voice/audio_recognition.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Holdoff timer not cancelled in AudioRecognition.aclose()

The _interruption_holdoff_timer (asyncio.TimerHandle) is not cancelled in AudioRecognition.aclose() (audio_recognition.py:492). The stop() method (which would cancel it via update_interruption_detection(None)cancel_interruption_holdoff()) is never called on the aclose() path — agent_activity.py:841-842 calls aclose() directly. If the agent is mid-speech when the session closes, the holdoff timer can fire after aclose() completes, executing the _disable_vad callback which accesses self._session.agent_state and mutates self._interruption_by_audio_activity_enabled on a closed activity.

(Refers to lines 492-514)

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import time
from collections import deque
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol

Expand Down Expand Up @@ -185,6 +185,12 @@ def __init__(
self._interruption_enabled: bool = interruption_detection is not None and vad is not None
self._agent_speaking: bool = False

self._interruption_holdoff_duration: float | None = session.options.interruption.get(
"holdoff_duration"
)
self._interruption_holdoff_timer: asyncio.TimerHandle | None = None
self._interruption_holdoff_done_callback: Callable[[], None] | None = None

self._user_turn_span: trace.Span | None = None
self._closing = asyncio.Event()

Expand Down Expand Up @@ -236,14 +242,53 @@ def adaptive_interruption_active(self) -> bool:
and not self._interruption_ch.closed
)

# region: interruption holdoff

@property
def interruption_holdoff_active(self) -> bool:
return self._interruption_holdoff_timer is not None

@property
def interruption_holdoff_cb(self) -> Callable[[], None] | None:
return self._interruption_holdoff_done_callback

@interruption_holdoff_cb.setter
def interruption_holdoff_cb(self, cb: Callable[[], None] | None) -> None:
self._interruption_holdoff_done_callback = cb

def _on_interruption_holdoff_expired(self) -> None:
self._interruption_holdoff_timer = None
cb, self._interruption_holdoff_done_callback = (
self._interruption_holdoff_done_callback,
None,
)
if cb is not None:
cb()

def cancel_interruption_holdoff(self) -> None:
if self._interruption_holdoff_timer is not None:
self._interruption_holdoff_timer.cancel()
self._interruption_holdoff_timer = None
self._interruption_holdoff_done_callback = None

# endregion

def on_start_of_agent_speech(self, started_at: float) -> None:
self._agent_speaking = True
self._endpointing.on_start_of_agent_speech(started_at=started_at)

if self._interruption_holdoff_duration:
self.cancel_interruption_holdoff()
self._interruption_holdoff_timer = asyncio.get_running_loop().call_later(
self._interruption_holdoff_duration, self._on_interruption_holdoff_expired
)

if self.adaptive_interruption_active:
self._interruption_ch.send_nowait(_AgentSpeechStartedSentinel()) # type: ignore[union-attr]

def on_end_of_agent_speech(self, *, ignore_user_transcript_until: float) -> None:
self.cancel_interruption_holdoff()

if self._agent_speaking:
self._endpointing.on_end_of_agent_speech(ended_at=time.time())

Expand Down Expand Up @@ -554,6 +599,7 @@ def update_interruption_detection(
self._tasks.add(task)
self._interruption_atask = None
self._interruption_ch = None
self.cancel_interruption_holdoff()

self._interruption_enabled = (
self._interruption_detection is not None and self._vad is not None
Expand Down Expand Up @@ -886,6 +932,10 @@ async def _on_vad_event(self, ev: vad.VADEvent) -> None:
self._run_eou_detection(chat_ctx)

async def _on_overlap_speech_event(self, ev: inference.OverlappingSpeechEvent) -> None:
if self.interruption_holdoff_active:
logger.trace("ignoring overlap speech event during interruption holdoff")
return

if ev.is_interruption:
self._hooks.on_interruption(ev)

Expand Down
5 changes: 5 additions & 0 deletions livekit-agents/livekit/agents/voice/turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class InterruptionOptions(TypedDict, total=False):
false_interruption_timeout: float | None
"""Seconds of silence after an interruption before it is
classified as false. ``None`` disables. Defaults to ``2.0``."""
holdoff_duration: float | None
"""Seconds to suppress adaptive interruption handling after the agent
starts speaking each turn to allow for easier turn correction.
``None`` disables. Defaults to ``1.0``."""


_INTERRUPTION_DEFAULTS: InterruptionOptions = {
Expand All @@ -109,6 +113,7 @@ class InterruptionOptions(TypedDict, total=False):
"min_words": 0,
"resume_false_interruption": True,
"false_interruption_timeout": 2.0,
"holdoff_duration": 1.0,
}


Expand Down
42 changes: 42 additions & 0 deletions tests/test_agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,48 @@ async def test_aec_warmup() -> None:
check_timestamp(speaking_to_listening.created_at - t_origin, 5.5, speed_factor=speed)


async def test_interruption_holdoff() -> None:
"""Interruption holdoff should not interfere with VAD-based interruption when adaptive
detection is not active. The holdoff timer runs but has no effect on the VAD path.

This validates that the holdoff_duration config is properly handled and doesn't
regress normal interruption behavior. Adaptive-specific gating
(suppressing on_interruption during holdoff) requires a live inference service.
"""
speed = 5.0
actions = FakeActions()
actions.add_user_speech(0.5, 2.5, "Tell me a story.")
actions.add_llm("Here is a long story for you ... the end.")
actions.add_tts(15.0) # playout starts at ~3.5s
# user speaks at 4.0-5.0s — within the 1s warmup window (3.5 + 1.0 = 4.5s expiry)
# VAD interruption at 4.0 + 0.5 = 4.5s (warmup does NOT block VAD)
actions.add_user_speech(4.0, 5.0, "Stop!", stt_delay=0.2)

session = create_session(
actions,
speed_factor=speed,
extra_kwargs={"aec_warmup_duration": None},
)
agent = MyAgent()

agent_state_events: list[AgentStateChangedEvent] = []
playback_finished_events: list[PlaybackFinishedEvent] = []
session.on("agent_state_changed", agent_state_events.append)
session.output.audio.on("playback_finished", playback_finished_events.append)

t_origin = await asyncio.wait_for(run_session(session, agent), timeout=SESSION_TIMEOUT)

assert len(playback_finished_events) == 1
assert playback_finished_events[0].interrupted is True

assert agent_state_events[0].new_state == "listening"
assert agent_state_events[1].new_state == "thinking"
assert agent_state_events[2].new_state == "speaking"
# VAD interruption fires normally at ~4.5s (warmup doesn't block VAD path)
speaking_to_listening = next(e for e in agent_state_events[3:] if e.new_state == "listening")
check_timestamp(speaking_to_listening.created_at - t_origin, 4.5, speed_factor=speed)


@pytest.mark.parametrize(
"preemptive_generation, expected_latency",
[
Expand Down
Loading