Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions livekit-agents/livekit/agents/beta/workflows/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class GetAddressTask(AgentTask[GetAddressResult]):
def __init__(
self,
extra_instructions: str = "",
require_confirmation: bool = True,
chat_ctx: NotGivenOr[llm.ChatContext] = NOT_GIVEN,
turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool | llm.Toolset]] = NOT_GIVEN,
Expand All @@ -32,6 +31,7 @@ def __init__(
llm: NotGivenOr[llm.LLM | llm.RealtimeModel | None] = NOT_GIVEN,
tts: NotGivenOr[tts.TTS | None] = NOT_GIVEN,
allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
require_confirmation: NotGivenOr[bool] = NOT_GIVEN,
) -> None:
super().__init__(
instructions=(
Expand All @@ -54,7 +54,7 @@ def __init__(
"Don't invent new addresses, stick strictly to what the user said. \n"
+ (
"Call `confirm_address` after the user confirmed the address is correct. \n"
if require_confirmation
if require_confirmation is not False
else ""
)
+ "When reading a numerical ordinal suffix (st, nd, rd, th), the number must be verbally expanded into its full, correctly pronounced word form.\n"
Expand Down Expand Up @@ -106,22 +106,26 @@ async def update_address(
address = " ".join(address_fields)
self._current_address = address

if self._require_confirmation:
return (
f"The address has been updated to {address}\n"
f"Repeat the address field by field: {address_fields} if needed\n"
f"Prompt the user for confirmation, do not call `confirm_address` directly"
)
else:
self.complete(GetAddressResult(address=self._current_address))
if self._require_confirmation is False or ctx.speech_handle.input_source.modality == "text":
if not self.done():
self.complete(GetAddressResult(address=self._current_address))
return None

return (
f"The address has been updated to {address}\n"
f"Repeat the address field by field: {address_fields} if needed\n"
f"Prompt the user for confirmation, do not call `confirm_address` directly"
)

@function_tool(flags=ToolFlag.IGNORE_ON_ENTER)
async def confirm_address(self, ctx: RunContext) -> None:
"""Call this tool when the user confirms that the address is correct."""
await ctx.wait_for_playout()

if ctx.speech_handle == self._address_update_speech_handle:
if (
ctx.speech_handle == self._address_update_speech_handle
and ctx.speech_handle.input_source.modality == "audio"
):
raise ToolError("error: the user must confirm the address explicitly")

if not self._current_address:
Expand Down
28 changes: 16 additions & 12 deletions livekit-agents/livekit/agents/beta/workflows/email_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class GetEmailTask(AgentTask[GetEmailResult]):
def __init__(
self,
extra_instructions: str = "",
require_confirmation: bool = True,
chat_ctx: NotGivenOr[llm.ChatContext] = NOT_GIVEN,
turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool | llm.Toolset]] = NOT_GIVEN,
Expand All @@ -37,6 +36,7 @@ def __init__(
llm: NotGivenOr[llm.LLM | llm.RealtimeModel | None] = NOT_GIVEN,
tts: NotGivenOr[tts.TTS | None] = NOT_GIVEN,
allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
require_confirmation: NotGivenOr[bool] = NOT_GIVEN,
) -> None:
super().__init__(
instructions=(
Expand All @@ -59,7 +59,7 @@ def __init__(
"Don't invent new email addresses, stick strictly to what the user said. \n"
+ (
"Call `confirm_email_address` after the user confirmed the email address is correct. \n"
if require_confirmation
if require_confirmation is not False
else ""
)
+ "If the email is unclear or invalid, or it takes too much back-and-forth, prompt for it in parts: first the part before the '@', then the domain—only if needed. \n"
Expand Down Expand Up @@ -101,23 +101,27 @@ async def update_email_address(self, email: str, ctx: RunContext) -> str | None:

self._current_email = email
separated_email = " ".join(email)
if self._require_confirmation:
return (
f"The email has been updated to {email}\n"
f"Repeat the email character by character: {separated_email} if needed\n"
f"Prompt the user for confirmation, do not call `confirm_email_address` directly"
)

else:
self.complete(GetEmailResult(email_address=email))
return None
if self._require_confirmation is False or ctx.speech_handle.input_source.modality == "text":
if not self.done():
self.complete(GetEmailResult(email_address=self._current_email))
return None # no need to continue the conversation

return (
f"The email has been updated to {email}\n"
f"Repeat the email character by character: {separated_email} if needed\n"
f"Prompt the user for confirmation, do not call `confirm_email_address` directly"
)

@function_tool(flags=ToolFlag.IGNORE_ON_ENTER)
async def confirm_email_address(self, ctx: RunContext) -> None:
"""Validates/confirms the email address provided by the user."""
await ctx.wait_for_playout()

if ctx.speech_handle == self._email_update_speech_handle:
if (
ctx.speech_handle == self._email_update_speech_handle
and ctx.speech_handle.input_source.modality == "audio"
):
raise ToolError("error: the user must confirm the email address explicitly")

if not self._current_email.strip():
Expand Down
10 changes: 8 additions & 2 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
remove_instructions,
update_instructions,
)
from .speech_handle import SpeechHandle
from .speech_handle import DEFAULT_INPUT_SOURCE, InputSource, SpeechHandle

if TYPE_CHECKING:
from ..llm import mcp
Expand Down Expand Up @@ -864,6 +864,7 @@ def _generate_reply(
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
schedule_speech: bool = True,
input_source: InputSource = DEFAULT_INPUT_SOURCE,
) -> SpeechHandle:
if (
isinstance(self.llm, llm.RealtimeModel)
Expand Down Expand Up @@ -909,6 +910,7 @@ def _generate_reply(
allow_interruptions=allow_interruptions
if is_given(allow_interruptions)
else self.allow_interruptions,
input_source=input_source,
)
self._session.emit(
"speech_created",
Expand Down Expand Up @@ -1200,7 +1202,9 @@ def _on_generation_created(self, ev: llm.GenerationCreatedEvent) -> None:
logger.warning("skipping new realtime generation, the speech scheduling is not running")
return

handle = SpeechHandle.create(allow_interruptions=self.allow_interruptions)
handle = SpeechHandle.create(
allow_interruptions=self.allow_interruptions, input_source=InputSource(modality="audio")
)
self._session.emit(
"speech_created",
SpeechCreatedEvent(speech_handle=handle, user_initiated=False, source="generate_reply"),
Expand Down Expand Up @@ -1407,6 +1411,7 @@ def on_preemptive_generation(self, info: _PreemptiveGenerationInfo) -> None:
user_message=user_message,
chat_ctx=chat_ctx,
schedule_speech=False,
input_source=InputSource(modality="audio"),
)

self._preemptive_generation = _PreemptiveGeneration(
Expand Down Expand Up @@ -1613,6 +1618,7 @@ async def _user_turn_completed_task(
speech_handle = self._generate_reply(
user_message=user_message,
chat_ctx=temp_mutable_chat_ctx,
input_source=InputSource(modality="audio"),
)

if self._user_turn_completed_atask != asyncio.current_task():
Expand Down
17 changes: 14 additions & 3 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from .ivr import IVRActivity
from .recorder_io import RecorderIO
from .run_result import RunResult
from .speech_handle import SpeechHandle
from .speech_handle import InputSource, SpeechHandle

if TYPE_CHECKING:
from ..inference import LLMModels, STTModels, TTSModels
Expand Down Expand Up @@ -429,13 +429,19 @@ def current_agent(self) -> Agent:
def tools(self) -> list[llm.Tool | llm.Toolset]:
return self._tools

def run(self, *, user_input: str, output_type: type[Run_T] | None = None) -> RunResult[Run_T]:
def run(
self,
*,
user_input: str,
input_modality: Literal["text", "audio"] = "text",
output_type: type[Run_T] | None = None,
) -> RunResult[Run_T]:
if self._global_run_state is not None and not self._global_run_state.done():
raise RuntimeError("nested runs are not supported")

run_state = RunResult(user_input=user_input, output_type=output_type)
self._global_run_state = run_state
self.generate_reply(user_input=user_input)
self.generate_reply(user_input=user_input, input_modality=input_modality)
return run_state

@overload
Expand Down Expand Up @@ -932,6 +938,7 @@ def generate_reply(
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
chat_ctx: NotGivenOr[ChatContext] = NOT_GIVEN,
input_modality: Literal["text", "audio"] = "text",
) -> SpeechHandle:
"""Generate a reply for the agent to speak to the user.

Expand All @@ -942,6 +949,9 @@ def generate_reply(
tool_choice (NotGivenOr[llm.ToolChoice], optional): Specifies the external tool to use when
generating the reply. If generate_reply is invoked within a function_tool, defaults to "none".
allow_interruptions (NotGivenOr[bool], optional): Indicates whether the user can interrupt this speech.
chat_ctx (NotGivenOr[ChatContext], optional): The chat context to use for generating the reply.
Defaults to the chat context of the current agent if not provided.
input_modality (Literal["text", "audio"], optional): The input mode to use for generating the reply.

Returns:
SpeechHandle: A handle to the generated reply.
Expand Down Expand Up @@ -973,6 +983,7 @@ def generate_reply(
tool_choice=tool_choice,
allow_interruptions=allow_interruptions,
chat_ctx=chat_ctx,
input_source=InputSource(modality=input_modality),
)
if run_state:
run_state._watch_handle(handle)
Expand Down
26 changes: 23 additions & 3 deletions livekit-agents/livekit/agents/voice/speech_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import asyncio
import contextlib
from collections.abc import Callable, Generator, Sequence
from typing import Any
from dataclasses import dataclass
from typing import Any, Literal

from opentelemetry import context as otel_context

Expand All @@ -13,6 +14,14 @@
INTERRUPTION_TIMEOUT = 5.0 # seconds


@dataclass
class InputSource:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we name it InputDetails?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then it would be accessed as `speech.input_details.source` and `speech.input_details.modality`? Sounds good to me.

modality: Literal["text", "audio"]


DEFAULT_INPUT_SOURCE = InputSource(modality="audio")


class SpeechHandle:
SPEECH_PRIORITY_LOW = 0
"""Priority for messages that should be played after all other messages in the queue"""
Expand All @@ -21,9 +30,12 @@ class SpeechHandle:
SPEECH_PRIORITY_HIGH = 10
"""Priority for important messages that should be played before others."""

def __init__(self, *, speech_id: str, allow_interruptions: bool) -> None:
def __init__(
self, *, speech_id: str, allow_interruptions: bool, input_source: InputSource
) -> None:
self._id = speech_id
self._allow_interruptions = allow_interruptions
self._input_source = input_source

self._interrupt_fut = asyncio.Future[None]()
self._done_fut = asyncio.Future[None]()
Expand Down Expand Up @@ -51,10 +63,14 @@ def _on_done(_: asyncio.Future[None]) -> None:
self._maybe_run_final_output: Any = None # kept private

@staticmethod
def create(allow_interruptions: bool = True) -> SpeechHandle:
def create(
allow_interruptions: bool = True,
input_source: InputSource = DEFAULT_INPUT_SOURCE,
) -> SpeechHandle:
return SpeechHandle(
speech_id=utils.shortuuid("speech_"),
allow_interruptions=allow_interruptions,
input_source=input_source,
)

@property
Expand All @@ -65,6 +81,10 @@ def num_steps(self) -> int:
def id(self) -> str:
return self._id

@property
def input_source(self) -> InputSource:
    """Return the input source stored on this speech handle."""
    source = self._input_source
    return source

@property
def _generation_id(self) -> str:
    """Return ``<id>_<num_steps>`` built from ``self._id`` and ``self._num_steps``."""
    return "{}_{}".format(self._id, self._num_steps)
Expand Down
20 changes: 17 additions & 3 deletions tests/test_workflows.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
from unittest.mock import patch

import pytest
Expand All @@ -13,12 +14,25 @@ def _llm_model() -> llm.LLM:


@pytest.mark.asyncio
async def test_collect_email() -> None:
@pytest.mark.parametrize("input_modality", ["text", "audio"])
async def test_collect_email(input_modality: Literal["text", "audio"]) -> None:
async with _llm_model() as llm, AgentSession(llm=llm) as sess:
await sess.start(beta.workflows.GetEmailTask())

await sess.run(user_input="My email address is theo at livekit dot io?")
result = await sess.run(user_input="Yes", output_type=beta.workflows.GetEmailResult)
result = await sess.run(
user_input="My email address is theo at livekit dot io?", input_modality=input_modality
)

if input_modality == "text":
assert isinstance(result.final_output, beta.workflows.GetEmailResult)
else:
# confirmation is required for audio input
result = await sess.run(
user_input="Yes",
output_type=beta.workflows.GetEmailResult,
input_modality=input_modality,
)

assert result.final_output.email_address == "theo@livekit.io"

async with _llm_model() as llm, AgentSession(llm=llm) as sess:
Expand Down
Loading