Skip to content

Commit 476271b

Browse files
jeradf and davidzhao authored
multilingual turn detector (#1736)
Co-authored-by: David Zhao <[email protected]>
1 parent cdb85d5 commit 476271b

File tree

10 files changed

+150
-53
lines changed

10 files changed

+150
-53
lines changed

.changeset/many-emus-allow.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"livekit-plugins-turn-detector": patch
3+
---
4+
5+
added a multilingual turn detector option

.changeset/yellow-ways-dance.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"livekit-plugins-deepgram": patch
3+
---
4+
5+
support multilingual with Nova-3 model

examples/voice-pipeline-agent/turn_detector.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
metrics,
1212
)
1313
from livekit.agents.pipeline import VoicePipelineAgent
14-
from livekit.plugins import deepgram, openai, silero, turn_detector
14+
from livekit.plugins import deepgram, openai, silero
15+
from livekit.plugins.turn_detector.multilingual import MultilingualModel
1516

1617
load_dotenv()
1718
logger = logging.getLogger("voice-assistant")
@@ -42,11 +43,11 @@ async def entrypoint(ctx: JobContext):
4243

4344
agent = VoicePipelineAgent(
4445
vad=ctx.proc.userdata["vad"],
45-
stt=deepgram.STT(),
46+
stt=deepgram.STT(model="nova-3", language="multi"),
4647
llm=openai.LLM(model="gpt-4o-mini"),
4748
tts=openai.TTS(),
4849
chat_ctx=initial_ctx,
49-
turn_detector=turn_detector.EOUModel(),
50+
turn_detector=MultilingualModel(),
5051
)
5152

5253
agent.start(ctx.room, participant)

livekit-agents/livekit/agents/pipeline/pipeline_agent.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class AgentTranscriptionOptions:
164164
class _TurnDetector(Protocol):
165165
# When endpoint probability is below this threshold we think the user is not finished speaking
166166
# so we will use a long delay
167-
def unlikely_threshold(self) -> float: ...
167+
def unlikely_threshold(self, language: str | None) -> float: ...
168168
def supports_language(self, language: str | None) -> bool: ...
169169
async def predict_end_of_turn(self, chat_ctx: ChatContext) -> float: ...
170170

@@ -1314,6 +1314,10 @@ def _compute_delay(self) -> float | None:
13141314

13151315
def on_human_final_transcript(self, transcript: str, language: str | None) -> None:
13161316
self._last_final_transcript += " " + transcript.strip() # type: ignore
1317+
logger.debug(
1318+
"last language updated",
1319+
extra={"from": self._last_language, "to": language},
1320+
)
13171321
self._last_language = language
13181322
self._last_recv_transcript_time = time.perf_counter()
13191323

@@ -1355,21 +1359,28 @@ def _run(self, delay: float) -> None:
13551359
@utils.log_exceptions(logger=logger)
13561360
async def _run_task(chat_ctx: ChatContext, delay: float) -> None:
13571361
use_turn_detector = self._last_final_transcript and not self._speaking
1358-
if (
1359-
use_turn_detector
1360-
and self._turn_detector is not None
1361-
and self._turn_detector.supports_language(self._last_language)
1362-
):
1363-
start_time = time.perf_counter()
1364-
try:
1365-
eot_prob = await self._turn_detector.predict_end_of_turn(chat_ctx)
1366-
unlikely_threshold = self._turn_detector.unlikely_threshold()
1367-
elasped = time.perf_counter() - start_time
1368-
if eot_prob < unlikely_threshold:
1369-
delay = self._max_endpointing_delay
1370-
delay = max(0, delay - elasped)
1371-
except (TimeoutError, AssertionError):
1372-
pass # inference process is unresponsive
1362+
1363+
if use_turn_detector and self._turn_detector is not None:
1364+
if not self._turn_detector.supports_language(self._last_language):
1365+
logger.debug(
1366+
"turn detector does not support language",
1367+
extra={"language": self._last_language},
1368+
)
1369+
else:
1370+
start_time = time.perf_counter()
1371+
try:
1372+
eot_prob = await self._turn_detector.predict_end_of_turn(
1373+
chat_ctx
1374+
)
1375+
unlikely_threshold = self._turn_detector.unlikely_threshold(
1376+
self._last_language
1377+
)
1378+
elasped = time.perf_counter() - start_time
1379+
if eot_prob < unlikely_threshold:
1380+
delay = self._max_endpointing_delay
1381+
delay = max(0, delay - elasped)
1382+
except (TimeoutError, AssertionError):
1383+
pass # inference process is unresponsive
13731384

13741385
await asyncio.sleep(delay)
13751386

livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -698,16 +698,19 @@ def live_transcription_to_speech_data(
698698
) -> List[stt.SpeechData]:
699699
dg_alts = data["channel"]["alternatives"]
700700

701-
return [
702-
stt.SpeechData(
701+
speech_data = []
702+
for alt in dg_alts:
703+
sd = stt.SpeechData(
703704
language=language,
704705
start_time=alt["words"][0]["start"] if alt["words"] else 0,
705706
end_time=alt["words"][-1]["end"] if alt["words"] else 0,
706707
confidence=alt["confidence"],
707708
text=alt["transcript"],
708709
)
709-
for alt in dg_alts
710-
]
710+
if language == "multi" and "languages" in alt:
711+
sd.language = alt["languages"][0] # TODO: handle multiple languages
712+
speech_data.append(sd)
713+
return speech_data
711714

712715

713716
def prerecorded_transcription_to_speech_event(
@@ -774,7 +777,6 @@ def _validate_model(
774777
"nova-2-drivethru",
775778
"nova-2-automotive",
776779
# nova-3 will support more languages, but english-only for now
777-
"nova-3",
778780
"nova-3-general",
779781
}
780782
if language not in ("en-US", "en") and model in en_only_models:

livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/__init__.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
# limitations under the License.
1414

1515
from livekit.agents import Plugin
16-
from livekit.agents.inference_runner import _InferenceRunner
1716

18-
from .eou import EOUModel, _EUORunner
17+
from .english import EnglishModel
1918
from .log import logger
2019
from .version import __version__
2120

22-
__all__ = ["EOUModel", "__version__"]
21+
__all__ = ["EOUModel", "english", "multilingual", "__version__"]
2322

2423

2524
class EOUPlugin(Plugin):
@@ -29,13 +28,16 @@ def __init__(self):
2928
def download_files(self) -> None:
3029
from transformers import AutoTokenizer
3130

32-
from .eou import HG_MODEL, MODEL_REVISION, ONNX_FILENAME, _download_from_hf_hub
31+
from .base import _download_from_hf_hub
32+
from .models import HG_MODEL, MODEL_REVISIONS, ONNX_FILENAME
3333

34-
AutoTokenizer.from_pretrained(HG_MODEL, revision=MODEL_REVISION)
35-
_download_from_hf_hub(
36-
HG_MODEL, ONNX_FILENAME, subfolder="onnx", revision=MODEL_REVISION
37-
)
34+
for revision in MODEL_REVISIONS.values():
35+
AutoTokenizer.from_pretrained(HG_MODEL, revision=revision)
36+
_download_from_hf_hub(
37+
HG_MODEL, ONNX_FILENAME, subfolder="onnx", revision=revision
38+
)
39+
_download_from_hf_hub(HG_MODEL, "languages.json", revision=revision)
3840

3941

4042
Plugin.register_plugin(EOUPlugin())
41-
_InferenceRunner.register_runner(_EUORunner)
43+
EOUModel = EnglishModel

livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py renamed to livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/base.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
import asyncio
44
import json
55
import time
6+
from abc import ABC, abstractmethod
67

78
from livekit.agents import llm
89
from livekit.agents.inference_runner import _InferenceRunner
910
from livekit.agents.ipc.inference_executor import InferenceExecutor
1011
from livekit.agents.job import get_current_job_context
1112

1213
from .log import logger
14+
from .models import HG_MODEL, MODEL_REVISIONS, ONNX_FILENAME, EOUModelType
1315

14-
HG_MODEL = "livekit/turn-detector"
15-
ONNX_FILENAME = "model_q8.onnx"
16-
MODEL_REVISION = "v1.2.1"
1716
MAX_HISTORY_TOKENS = 512
1817
MAX_HISTORY_TURNS = 6
1918

@@ -25,8 +24,10 @@ def _download_from_hf_hub(repo_id, filename, **kwargs):
2524
return local_path
2625

2726

28-
class _EUORunner(_InferenceRunner):
29-
INFERENCE_METHOD = "lk_end_of_utterance"
27+
class _EUORunnerBase(_InferenceRunner):
28+
def __init__(self, model_type: EOUModelType):
29+
super().__init__()
30+
self._model_revision = MODEL_REVISIONS[model_type]
3031

3132
def _format_chat_ctx(self, chat_ctx: dict):
3233
new_chat_ctx = []
@@ -60,7 +61,7 @@ def initialize(self) -> None:
6061
HG_MODEL,
6162
ONNX_FILENAME,
6263
subfolder="onnx",
63-
revision=MODEL_REVISION,
64+
revision=self._model_revision,
6465
local_files_only=True,
6566
)
6667
self._session = ort.InferenceSession(
@@ -69,19 +70,20 @@ def initialize(self) -> None:
6970

7071
self._tokenizer = AutoTokenizer.from_pretrained(
7172
HG_MODEL,
72-
revision=MODEL_REVISION,
73+
revision=self._model_revision,
7374
local_files_only=True,
7475
truncation_side="left",
7576
)
77+
7678
except (errors.LocalEntryNotFoundError, OSError):
7779
logger.error(
7880
(
79-
f"Could not find model {HG_MODEL}. Make sure you have downloaded the model before running the agent. "
81+
f"Could not find model {HG_MODEL} with revision {self._model_revision}. Make sure you have downloaded the model before running the agent. "
8082
"Use `python3 your_agent.py download-files` to download the models."
8183
)
8284
)
8385
raise RuntimeError(
84-
f"livekit-plugins-turn-detector initialization failed. Could not find model {HG_MODEL}."
86+
f"livekit-plugins-turn-detector initialization failed. Could not find model {HG_MODEL} with revision {self._model_revision}."
8587
) from None
8688

8789
def run(self, data: bytes) -> bytes | None:
@@ -116,26 +118,44 @@ def run(self, data: bytes) -> bytes | None:
116118
return json.dumps(data).encode()
117119

118120

119-
class EOUModel:
121+
class EOUModelBase(ABC):
120122
def __init__(
121123
self,
124+
model_type: EOUModelType = "en", # default to smaller, english-only model
122125
inference_executor: InferenceExecutor | None = None,
123-
unlikely_threshold: float = 0.0289,
124126
) -> None:
127+
self._model_type = model_type
125128
self._executor = (
126129
inference_executor or get_current_job_context().inference_executor
127130
)
128-
self._unlikely_threshold = unlikely_threshold
129131

130-
def unlikely_threshold(self) -> float:
131-
return self._unlikely_threshold
132+
config_fname = _download_from_hf_hub(
133+
HG_MODEL,
134+
"languages.json",
135+
revision=MODEL_REVISIONS[self._model_type],
136+
local_files_only=True,
137+
)
138+
with open(config_fname, "r") as f:
139+
self._languages = json.load(f)
132140

133-
def supports_language(self, language: str | None) -> bool:
141+
@abstractmethod
142+
def _inference_method(self): ...
143+
144+
def unlikely_threshold(self, language: str | None) -> float | None:
134145
if language is None:
135-
return False
136-
parts = language.lower().split("-")
137-
# certain models use language codes (DG, AssemblyAI), others use full names (like OAI)
138-
return parts[0] == "en" or parts[0] == "english"
146+
return None
147+
lang = language.lower()
148+
if lang in self._languages:
149+
return self._languages[lang]["threshold"]
150+
if "-" in lang:
151+
part = lang.split("-")[0]
152+
if part in self._languages:
153+
return self._languages[part]["threshold"]
154+
logger.warning(f"Language {language} not supported by EOU model")
155+
return None
156+
157+
def supports_language(self, language: str | None) -> bool:
158+
return self.unlikely_threshold(language) is not None
139159

140160
async def predict_eou(self, chat_ctx: llm.ChatContext) -> float:
141161
return await self.predict_end_of_turn(chat_ctx)
@@ -173,7 +193,7 @@ async def predict_end_of_turn(
173193
json_data = json.dumps({"chat_ctx": messages}).encode()
174194

175195
result = await asyncio.wait_for(
176-
self._executor.do_inference(_EUORunner.INFERENCE_METHOD, json_data),
196+
self._executor.do_inference(self._inference_method(), json_data),
177197
timeout=timeout,
178198
)
179199

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from livekit.agents.inference_runner import _InferenceRunner
2+
3+
from .base import EOUModelBase, _EUORunnerBase
4+
5+
6+
class _EUORunnerEn(_EUORunnerBase):
7+
INFERENCE_METHOD = "lk_end_of_utterance_en"
8+
9+
def __init__(self):
10+
super().__init__("en")
11+
12+
13+
class EnglishModel(EOUModelBase):
14+
def __init__(self):
15+
super().__init__(model_type="en")
16+
17+
def _inference_method(self) -> str:
18+
return _EUORunnerEn.INFERENCE_METHOD
19+
20+
21+
_InferenceRunner.register_runner(_EUORunnerEn)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from typing import Literal
2+
3+
EOUModelType = Literal["en", "multilingual"]
4+
MODEL_REVISIONS: dict[EOUModelType, str] = {
5+
"en": "v1.2.2-en",
6+
"multilingual": "v0.1.0-intl",
7+
}
8+
HG_MODEL = "livekit/turn-detector"
9+
ONNX_FILENAME = "model_q8.onnx"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from livekit.agents.inference_runner import _InferenceRunner
2+
3+
from .base import EOUModelBase, _EUORunnerBase
4+
5+
6+
class _EUORunnerMultilingual(_EUORunnerBase):
7+
INFERENCE_METHOD = "lk_end_of_utterance_multilingual"
8+
9+
def __init__(self):
10+
super().__init__("multilingual")
11+
12+
13+
class MultilingualModel(EOUModelBase):
14+
def __init__(self):
15+
super().__init__(model_type="multilingual")
16+
17+
def _inference_method(self) -> str:
18+
return _EUORunnerMultilingual.INFERENCE_METHOD
19+
20+
21+
_InferenceRunner.register_runner(_EUORunnerMultilingual)

0 commit comments

Comments (0)