Skip to content

Commit 26874a5

Browse files
authored
[voice agent] Fix RTVI missing bot message (#15068)
* fix RTVI missing bot message, fix diar not passing VAD frames Signed-off-by: stevehuang52 <[email protected]> * revert change to diar Signed-off-by: stevehuang52 <[email protected]> --------- Signed-off-by: stevehuang52 <[email protected]>
1 parent 10327b8 commit 26874a5

File tree

3 files changed

+25
-17
lines changed

3 files changed

+25
-17
lines changed

examples/voice_agent/server/bot_websocket_server.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
import sys
2121

2222
from loguru import logger
23-
from omegaconf import OmegaConf
2423

25-
from pipecat.audio.vad.silero import SileroVADAnalyzer, VADParams
24+
from pipecat.audio.vad.silero import SileroVADAnalyzer
2625
from pipecat.frames.frames import EndTaskFrame
2726
from pipecat.pipeline.pipeline import Pipeline
2827
from pipecat.pipeline.runner import PipelineRunner
@@ -112,7 +111,10 @@ def signal_handler(signum, frame):
112111
shutdown_event.set()
113112

114113

115-
async def run_bot_websocket_server():
114+
async def run_bot_websocket_server(host: str = "0.0.0.0", port: int = 8765):
115+
logger.info(f"Starting websocket server on {host}:{port}")
116+
logger.info(f"Server configured to run indefinitely with no timeouts, use Ctrl+C to quit.")
117+
116118
# Set up signal handlers for graceful shutdown
117119
signal.signal(signal.SIGINT, signal_handler)
118120
signal.signal(signal.SIGTERM, signal_handler)
@@ -147,8 +149,8 @@ async def run_bot_websocket_server():
147149
is None, # if backchannel phrases are disabled, we can use VAD to interrupt the bot immediately
148150
audio_out_10ms_chunks=TRANSPORT_AUDIO_OUT_10MS_CHUNKS,
149151
),
150-
host="0.0.0.0", # Bind to all interfaces
151-
port=8765,
152+
host=host,
153+
port=port,
152154
)
153155

154156
logger.info("Initializing STT service...")
@@ -279,7 +281,7 @@ async def reset_context_handler(rtvi_processor: RTVIProcessor, service: str, arg
279281

280282
pipeline = Pipeline(pipeline)
281283

282-
rtvi_text_aggregator = SimpleSegmentedTextAggregator("\n?!.", min_sentence_length=5)
284+
rtvi_text_aggregator = SimpleSegmentedTextAggregator(punctuation_marks=".!?\n")
283285
task = PipelineTask(
284286
pipeline,
285287
params=PipelineParams(

examples/voice_agent/server/server_configs/default.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ vad:
1515

1616
stt:
1717
type: nemo # choices in ['nemo'] currently only NeMo is supported
18-
model: "stt_en_fastconformer_hybrid_large_streaming_80ms"
19-
# model: "nvidia/parakeet_realtime_eou_120m-v1"
18+
# model: "stt_en_fastconformer_hybrid_large_streaming_80ms"
19+
model: "nvidia/parakeet_realtime_eou_120m-v1"
2020
model_config: "./server_configs/stt_configs/nemo_cache_aware_streaming.yaml"
2121
device: "cuda"
2222

nemo/agents/voice_agent/pipecat/utils/text/simple_text_aggregator.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def find_last_period_index(text: str) -> int:
101101
class SimpleSegmentedTextAggregator(SimpleTextAggregator):
102102
def __init__(
103103
self,
104-
punctuation_marks: str | list[str] = ".,!?;:",
104+
punctuation_marks: str | list[str] = ".,!?;:\n",
105105
ignore_marks: str | list[str] = "*",
106106
min_sentence_length: int = 0,
107107
use_legacy_eos_detection: bool = False,
@@ -130,9 +130,8 @@ def __init__(
130130
)
131131
if "." in punctuation_marks:
132132
punctuation_marks.remove(".")
133-
punctuation_marks += [
134-
"."
135-
] # put period at the end of the list to ensure it's the last punctuation mark to be matched
133+
# put period at the end of the list to ensure it's the last punctuation mark to be matched
134+
punctuation_marks += ["."]
136135
self._punctuation_marks = punctuation_marks
137136

138137
def _find_segment_end(self, text: str) -> Optional[int]:
@@ -144,7 +143,12 @@ def _find_segment_end(self, text: str) -> Optional[int]:
144143
Returns:
145144
The index of the end of the segment, or None if the text is too short.
146145
"""
147-
if len(text.strip()) < self._min_sentence_length:
146+
# drop leading whitespace but keep trailing whitespace to
147+
# allow "\n" to trigger the end of the sentence
148+
text_len = len(text)
149+
text = text.lstrip()
150+
offset = text_len - len(text)
151+
if len(text) < self._min_sentence_length:
148152
return None
149153

150154
for punc in self._punctuation_marks:
@@ -153,12 +157,12 @@ def _find_segment_end(self, text: str) -> Optional[int]:
153157
else:
154158
idx = text.find(punc)
155159
if idx != -1:
156-
return idx + 1
160+
# add the offset to the index to account for the leading whitespace
161+
return idx + 1 + offset
157162
return None
158163

159164
async def aggregate(self, text: str) -> Optional[str]:
160165
result: Optional[str] = None
161-
162166
self._text += str(text)
163167

164168
for ignore_mark in self._ignore_marks:
@@ -174,10 +178,12 @@ async def aggregate(self, text: str) -> Optional[str]:
174178
if eos_end_index:
175179
result = self._text[:eos_end_index]
176180
if len(result.strip()) < self._min_sentence_length:
181+
logger.debug(
182+
f"Text is too short, skipping: `{result}`, full text: `{self._text}`, input text: `{text}`"
183+
)
177184
result = None
178-
logger.debug(f"Text is too short, skipping: `{result}`, full text: `{self._text}`")
179185
else:
180-
logger.debug(f"Text Aggregator Result: `{result}`, full text: `{self._text}`")
186+
logger.debug(f"Text Aggregator Result: `{result}`, full text: `{self._text}`, input text: `{text}`")
181187
self._text = self._text[eos_end_index:]
182188

183189
return result

0 commit comments

Comments
 (0)