Skip to content

Commit b3a2362

Browse files
osimhi213youpeshlongcwm-odsprettyprettyprettygood
authored
Sync/livekit forked (#2)
* feat(google): add VertexRAGRetrieval provider tool (#5222) * fix: ensure MCP client enter/exit run in the same task (#5223) * feat(assemblyai): add domain parameter for Medical Mode (#5208) * fix: Nova Sonic interactive context bugs and dynamic tool support (#5220) Co-authored-by: Pavas Kant <pavkan@amazon.com> * (google realtime): add gemini-3.1-flash-live-preview model (#5233) * fix(utils): improve type annotation for deprecate_params decorator (#5244) * fix: expose endpointing_opts in AgentSession.update_options() (#5243) * Fix/stt fallback adapter propagate aligned transcript (#5237) * feat(mistral): add voxtral TTS support (#5245) * feat(anthropic): support strict tool use schema (#5259) * Baseten Plugin Update: fix metadata schema, add chain_id support, and improve response parsing (#4889) --------- Co-authored-by: Yousuf Bukhari <25112850+youpesh@users.noreply.github.com> Co-authored-by: Long Chen <longch1024@gmail.com> Co-authored-by: Martin Schweiger <34636718+m-ods@users.noreply.github.com> Co-authored-by: Osman-AGI <uyguripek@gmail.com> Co-authored-by: Pavas Kant <pavkan@amazon.com> Co-authored-by: Tina Nguyen <72938484+tinalenguyen@users.noreply.github.com> Co-authored-by: Milad <129620931+miladmnasr@users.noreply.github.com> Co-authored-by: Jean Perbet <jeanperbet@icloud.com> Co-authored-by: Shaik Faizan Roshan Ali <roshan.shaik.ml@gmail.com> Co-authored-by: jiegong-fde <jie.gong@baseten.co>
1 parent 02e7616 commit b3a2362

File tree

23 files changed

+919
-147
lines changed

23 files changed

+919
-147
lines changed

livekit-agents/livekit/agents/llm/_provider_format/anthropic.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,30 @@ def _to_image_content(image: llm.ImageContent) -> dict[str, Any]:
125125
}
126126

127127

128-
def to_fnc_ctx(tool_ctx: llm.ToolContext) -> list[dict[str, Any]]:
128+
def to_fnc_ctx(tool_ctx: llm.ToolContext, *, strict: bool = True) -> list[dict[str, Any]]:
129129
schemas: list[dict[str, Any]] = []
130130
for tool in tool_ctx.function_tools.values():
131131
if isinstance(tool, llm.FunctionTool):
132-
fnc = llm.utils.build_legacy_openai_schema(tool, internally_tagged=True)
133-
schemas.append(
134-
{
135-
"name": fnc["name"],
136-
"description": fnc["description"] or "",
137-
"input_schema": fnc["parameters"],
138-
}
139-
)
132+
if strict:
133+
fnc = llm.utils.build_strict_openai_schema(tool)
134+
function_data = fnc["function"]
135+
schemas.append(
136+
{
137+
"name": function_data["name"],
138+
"description": function_data.get("description") or "",
139+
"input_schema": function_data["parameters"],
140+
"strict": True,
141+
}
142+
)
143+
else:
144+
fnc = llm.utils.build_legacy_openai_schema(tool, internally_tagged=True)
145+
schemas.append(
146+
{
147+
"name": fnc["name"],
148+
"description": fnc["description"] or "",
149+
"input_schema": fnc["parameters"],
150+
}
151+
)
140152
elif isinstance(tool, llm.RawFunctionTool):
141153
info = tool.info
142154
schemas.append(

livekit-agents/livekit/agents/llm/mcp.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
from abc import ABC, abstractmethod
88
from collections.abc import Awaitable, Callable
9-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
9+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
1010
from dataclasses import dataclass
1111
from datetime import timedelta
1212
from pathlib import Path
@@ -16,6 +16,7 @@
1616
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1717
from typing_extensions import Self
1818

19+
from ..log import logger
1920
from .tool_context import Toolset
2021

2122
try:
@@ -79,7 +80,6 @@ def __init__(
7980
tool_result_resolver: MCPToolResultResolver | None = None,
8081
) -> None:
8182
self._client: ClientSession | None = None
82-
self._exit_stack: AsyncExitStack = AsyncExitStack()
8383
self._read_timeout = client_session_timeout_seconds
8484
self._tool_result_resolver: MCPToolResultResolver = (
8585
tool_result_resolver or _default_tool_result_resolver
@@ -88,6 +88,10 @@ def __init__(
8888
self._cache_dirty = True
8989
self._lk_tools: list[MCPTool] | None = None
9090

91+
self._client_task: asyncio.Task[None] | None = None
92+
self._closing_ev = asyncio.Event()
93+
self._ready_fut: asyncio.Future[None] | None = None
94+
9195
@property
9296
def initialized(self) -> bool:
9397
return self._client is not None
@@ -96,22 +100,45 @@ def invalidate_cache(self) -> None:
96100
self._cache_dirty = True
97101

98102
async def initialize(self) -> None:
103+
if self._client_task and not self._client_task.done():
104+
logger.warning("MCPServer is already initializing")
105+
if self._ready_fut:
106+
await self._ready_fut
107+
return
108+
109+
self._ready_fut = ready_fut = asyncio.Future[None]()
110+
self._client_task = asyncio.create_task(
111+
self._run_client(ready_fut), name=f"{type(self).__name__}._run_client"
112+
)
113+
await ready_fut
114+
115+
async def _run_client(self, ready_fut: asyncio.Future[None]) -> None:
99116
try:
100-
streams = await self._exit_stack.enter_async_context(self.client_streams())
101-
receive_stream, send_stream = streams[0], streams[1]
102-
self._client = await self._exit_stack.enter_async_context(
103-
ClientSession(
117+
async with self.client_streams() as streams:
118+
receive_stream, send_stream = streams[0], streams[1]
119+
async with ClientSession(
104120
receive_stream,
105121
send_stream,
106122
read_timeout_seconds=timedelta(seconds=self._read_timeout)
107123
if self._read_timeout
108124
else None,
109-
)
110-
)
111-
await self._client.initialize() # type: ignore[union-attr]
112-
except Exception:
113-
await self.aclose()
114-
raise
125+
) as client:
126+
await client.initialize()
127+
self._client = client
128+
ready_fut.set_result(None)
129+
130+
await self._closing_ev.wait()
131+
except BaseException as e:
132+
if not ready_fut.done():
133+
ready_fut.set_exception(e) # raising from `await initialize()`
134+
else:
135+
if isinstance(e, Exception):
136+
logger.exception("MCP client connection failed with unexpected error")
137+
raise
138+
finally:
139+
self._client = None
140+
self._lk_tools = None
141+
self._closing_ev.clear()
115142

116143
async def list_tools(self) -> list[MCPTool]:
117144
if self._client is None:
@@ -171,11 +198,13 @@ async def _tool_called(raw_arguments: dict[str, Any]) -> Any:
171198
return function_tool(_tool_called, raw_schema=raw_schema)
172199

173200
async def aclose(self) -> None:
201+
self._closing_ev.set()
174202
try:
175-
await self._exit_stack.aclose()
203+
if self._client_task:
204+
await self._client_task
205+
self._client_task = None
176206
finally:
177-
self._client = None
178-
self._lk_tools = None
207+
self._closing_ev.clear()
179208

180209
@abstractmethod
181210
def client_streams(

livekit-agents/livekit/agents/llm/tool_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,9 @@ def parse_function_tools(
514514
def parse_function_tools(self, format: Literal["aws"]) -> list[dict[str, Any]]: ...
515515

516516
@overload
517-
def parse_function_tools(self, format: Literal["anthropic"]) -> list[dict[str, Any]]: ...
517+
def parse_function_tools(
518+
self, format: Literal["anthropic"], *, strict: bool = True
519+
) -> list[dict[str, Any]]: ...
518520

519521
def parse_function_tools(
520522
self,

livekit-agents/livekit/agents/stt/fallback_adapter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,18 @@ def __init__(
6868
StreamAdapter(stt=t, vad=vad) if not t.capabilities.streaming else t for t in stt
6969
]
7070

71+
# Use the primary STT's aligned_transcript if all providers support it, since
72+
# the SDK only checks truthiness, not the specific granularity.
73+
aligned_transcript: Literal["word", "chunk", False] = False
74+
if all(t.capabilities.aligned_transcript for t in stt):
75+
aligned_transcript = stt[0].capabilities.aligned_transcript
76+
7177
super().__init__(
7278
capabilities=STTCapabilities(
7379
streaming=True,
7480
interim_results=all(t.capabilities.interim_results for t in stt),
7581
diarization=all(t.capabilities.diarization for t in stt),
82+
aligned_transcript=aligned_transcript,
7683
)
7784
)
7885

livekit-agents/livekit/agents/utils/deprecation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,22 @@
44
import inspect
55
from collections import defaultdict
66
from collections.abc import Callable
7-
from typing import ParamSpec, TypeVar
7+
from typing import ParamSpec, TypeVar, cast
88

99
from ..log import logger
1010
from ..types import NOT_GIVEN
1111
from .misc import is_given
1212

1313
_P = ParamSpec("_P")
1414
_R = TypeVar("_R")
15+
_F = TypeVar("_F", bound=Callable)
1516

1617

1718
def deprecate_params(
1819
mapping: dict[str, str],
1920
*,
2021
target_version: str | None = None,
21-
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
22+
) -> Callable[[_F], _F]:
2223
"""
2324
Args:
2425
mapping: {old_param: suggestion}
@@ -59,4 +60,4 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
5960

6061
return wrapper
6162

62-
return decorator
63+
return cast(Callable[[_F], _F], decorator)

livekit-agents/livekit/agents/voice/agent_session.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,21 @@ def __call__(self, frame: rtc.VideoFrame, session: AgentSession) -> bool:
191191

192192

193193
class AgentSession(rtc.EventEmitter[EventTypes], Generic[Userdata_T]):
194+
@deprecate_params(
195+
{
196+
"min_endpointing_delay": "Use turn_handling=TurnHandlingOptions(...) instead",
197+
"max_endpointing_delay": "Use turn_handling=TurnHandlingOptions(...) instead",
198+
"false_interruption_timeout": "Use turn_handling=TurnHandlingOptions(...) instead",
199+
"resume_false_interruption": "Use turn_handling=TurnHandlingOptions(...) instead",
200+
"allow_interruptions": "Use turn_handling=TurnHandlingOptions(...) instead",
201+
"discard_audio_if_uninterruptible": "Use turn_handling=TurnHandlingOptions(...) instead",
202+
"min_interruption_duration": "Use turn_handling=TurnHandlingOptions(...) instead",
203+
"min_interruption_words": "Use turn_handling=TurnHandlingOptions(...) instead",
204+
"turn_detection": "Use turn_handling=TurnHandlingOptions(...) instead",
205+
"agent_false_interruption_timeout": "Use turn_handling=TurnHandlingOptions(...) instead",
206+
},
207+
target_version="v2.0",
208+
)
194209
def __init__(
195210
self,
196211
*,
@@ -434,23 +449,6 @@ def __init__(
434449
# ivr activity
435450
self._ivr_activity: IVRActivity | None = None
436451

437-
if not TYPE_CHECKING:
438-
__init__ = deprecate_params(
439-
{
440-
"min_endpointing_delay": "Use turn_handling=TurnHandlingOptions(...) instead",
441-
"max_endpointing_delay": "Use turn_handling=TurnHandlingOptions(...) instead",
442-
"false_interruption_timeout": "Use turn_handling=TurnHandlingOptions(...) instead",
443-
"resume_false_interruption": "Use turn_handling=TurnHandlingOptions(...) instead",
444-
"allow_interruptions": "Use turn_handling=TurnHandlingOptions(...) instead",
445-
"discard_audio_if_uninterruptible": "Use turn_handling=TurnHandlingOptions(...) instead",
446-
"min_interruption_duration": "Use turn_handling=TurnHandlingOptions(...) instead",
447-
"min_interruption_words": "Use turn_handling=TurnHandlingOptions(...) instead",
448-
"turn_detection": "Use turn_handling=TurnHandlingOptions(...) instead",
449-
"agent_false_interruption_timeout": "Use turn_handling=TurnHandlingOptions(...) instead",
450-
},
451-
target_version="v2.0",
452-
)(__init__)
453-
454452
def on(self, event: EventTypes, callback: Callable | None = None) -> Callable:
455453
if event == "metrics_collected" and callback is not None:
456454
logger.warning(
@@ -990,31 +988,57 @@ async def aclose(self) -> None:
990988
def update_options(
991989
self,
992990
*,
991+
endpointing_opts: NotGivenOr[EndpointingOptions] = NOT_GIVEN,
992+
turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
993+
# deprecated
993994
min_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
994995
max_endpointing_delay: NotGivenOr[float] = NOT_GIVEN,
995-
turn_detection: NotGivenOr[TurnDetectionMode | None] = NOT_GIVEN,
996996
) -> None:
997997
"""
998998
Update the options for the agent session.
999999
10001000
Args:
1001-
min_endpointing_delay (NotGivenOr[float], optional): The minimum endpointing delay.
1002-
max_endpointing_delay (NotGivenOr[float], optional): The maximum endpointing delay.
1001+
endpointing_opts (NotGivenOr[EndpointingOptions], optional): Endpointing options.
10031002
turn_detection (NotGivenOr[TurnDetectionMode | None], optional): Strategy for deciding
10041003
when the user has finished speaking. ``None`` reverts to automatic selection.
1004+
min_endpointing_delay: Deprecated, use ``endpointing_opts`` instead.
1005+
max_endpointing_delay: Deprecated, use ``endpointing_opts`` instead.
10051006
"""
1006-
if is_given(min_endpointing_delay):
1007-
self._opts.endpointing["min_delay"] = min_endpointing_delay
1008-
if is_given(max_endpointing_delay):
1009-
self._opts.endpointing["max_delay"] = max_endpointing_delay
1007+
if is_given(min_endpointing_delay) or is_given(max_endpointing_delay):
1008+
logger.warning(
1009+
"min_endpointing_delay and max_endpointing_delay are deprecated, "
1010+
"use endpointing_opts instead"
1011+
)
1012+
endpointing_opts = EndpointingOptions(
1013+
mode=self._opts.endpointing["mode"],
1014+
min_delay=(
1015+
min_endpointing_delay
1016+
if is_given(min_endpointing_delay)
1017+
else self._opts.endpointing["min_delay"]
1018+
),
1019+
max_delay=(
1020+
max_endpointing_delay
1021+
if is_given(max_endpointing_delay)
1022+
else self._opts.endpointing["max_delay"]
1023+
),
1024+
)
1025+
1026+
if is_given(endpointing_opts):
1027+
if (mode := endpointing_opts.get("mode")) is not None:
1028+
self._opts.endpointing["mode"] = mode
1029+
if (min_delay := endpointing_opts.get("min_delay")) is not None:
1030+
self._opts.endpointing["min_delay"] = min_delay
1031+
if (max_delay := endpointing_opts.get("max_delay")) is not None:
1032+
self._opts.endpointing["max_delay"] = max_delay
10101033

10111034
if is_given(turn_detection):
10121035
self._turn_detection = turn_detection
10131036

10141037
if self._activity is not None:
10151038
self._activity.update_options(
1016-
min_endpointing_delay=min_endpointing_delay,
1017-
max_endpointing_delay=max_endpointing_delay,
1039+
endpointing_opts=(
1040+
self._opts.endpointing if is_given(endpointing_opts) else NOT_GIVEN
1041+
),
10181042
turn_detection=turn_detection,
10191043
)
10201044

livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class _LLMOptions:
5656
caching: NotGivenOr[Literal["ephemeral"]]
5757
top_k: NotGivenOr[int]
5858
max_tokens: NotGivenOr[int]
59+
strict_tool_schema: bool
5960
"""If set to "ephemeral", the system prompt, tools, and chat history will be cached."""
6061

6162

@@ -74,6 +75,7 @@ def __init__(
7475
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
7576
tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
7677
caching: NotGivenOr[Literal["ephemeral"]] = NOT_GIVEN,
78+
_strict_tool_schema: bool = True,
7779
) -> None:
7880
"""
7981
Create a new instance of Anthropic LLM.
@@ -103,6 +105,7 @@ def __init__(
103105
caching=caching,
104106
top_k=top_k,
105107
max_tokens=max_tokens,
108+
strict_tool_schema=_strict_tool_schema,
106109
)
107110
anthropic_api_key = api_key if is_given(api_key) else os.environ.get("ANTHROPIC_API_KEY")
108111
if not anthropic_api_key:
@@ -164,7 +167,9 @@ def chat(
164167
from .tools import AnthropicTool
165168

166169
tool_ctx = llm.ToolContext(tools)
167-
tool_schemas = tool_ctx.parse_function_tools("anthropic")
170+
tool_schemas = tool_ctx.parse_function_tools(
171+
"anthropic", strict=self._opts.strict_tool_schema
172+
)
168173

169174
for tool in tool_ctx.provider_tools:
170175
if isinstance(tool, AnthropicTool):

livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class STTOptions:
6262
vad_threshold: NotGivenOr[float] = NOT_GIVEN
6363
speaker_labels: NotGivenOr[bool] = NOT_GIVEN
6464
max_speakers: NotGivenOr[int] = NOT_GIVEN
65+
domain: NotGivenOr[str] = NOT_GIVEN
6566

6667

6768
class STT(stt.STT):
@@ -87,6 +88,7 @@ def __init__(
8788
vad_threshold: NotGivenOr[float] = NOT_GIVEN,
8889
speaker_labels: NotGivenOr[bool] = NOT_GIVEN,
8990
max_speakers: NotGivenOr[int] = NOT_GIVEN,
91+
domain: NotGivenOr[str] = NOT_GIVEN,
9092
http_session: aiohttp.ClientSession | None = None,
9193
buffer_size_seconds: float = 0.05,
9294
base_url: str = "wss://streaming.assemblyai.com",
@@ -161,6 +163,7 @@ def __init__(
161163
vad_threshold=vad_threshold,
162164
speaker_labels=speaker_labels,
163165
max_speakers=max_speakers,
166+
domain=domain,
164167
)
165168
self._session = http_session
166169
self._streams = weakref.WeakSet[SpeechStream]()
@@ -483,6 +486,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
483486
if is_given(self._opts.speaker_labels)
484487
else None,
485488
"max_speakers": self._opts.max_speakers if is_given(self._opts.max_speakers) else None,
489+
"domain": self._opts.domain if is_given(self._opts.domain) else None,
486490
}
487491

488492
headers = {

0 commit comments

Comments
 (0)