Skip to content

Commit a290c93

Browse files
authored
Merge pull request #41 from TEN-framework/dev/tts_feature
Dev/tts feature
2 parents 163ae9b + 2705cb6 commit a290c93

File tree

1 file changed

+87
-2
lines changed

1 file changed

+87
-2
lines changed

interface/ten_ai_base/tts2.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import asyncio
88
import json
99
import traceback
10+
import uuid
1011

1112
from .helper import AsyncQueue
1213
from .message import ModuleError, ModuleMetricKey, ModuleMetrics, ModuleType, TTSAudioEndReason
@@ -46,6 +47,18 @@ def __init__(self, name: str):
4647
self.leftover_bytes = b""
4748
self.session_id = None
4849

50+
# billing metrics every 5 seconds
51+
self.output_characters = 0
52+
self.input_characters = 0
53+
self.recv_audio_duration = 0
54+
self.recv_audio_chunks = b""
55+
#billing metricstotal
56+
self.total_output_characters = 0
57+
self.total_input_characters = 0
58+
self.total_recv_audio_duration = 0
59+
self.total_recv_audio_chunks = b""
60+
self.timer_task = None
61+
4962
async def on_init(self, ten_env: AsyncTenEnv) -> None:
5063
await super().on_init(ten_env)
5164
self.ten_env = ten_env
@@ -55,13 +68,16 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
5568
if self.loop_task is None:
5669
self.loop = asyncio.get_event_loop()
5770
self.loop_task = self.loop.create_task(self._process_input_queue(ten_env))
71+
self.timer_task = asyncio.create_task(self.timer_logger(ten_env))
5872

5973
async def on_stop(self, ten_env: AsyncTenEnv) -> None:
6074
await super().on_stop(ten_env)
6175
await self._flush_input_items()
6276
if self.loop_task:
6377
self.loop_task.cancel()
6478
await self.input_queue.put(None) # Signal the loop to stop processing
79+
if self.timer_task:
80+
self.timer_task.cancel()
6581

6682
async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
6783
await super().on_deinit(ten_env)
@@ -199,7 +215,6 @@ async def send_tts_audio_data(
199215
buff = f.lock_buf()
200216
buff[:] = combined_data
201217
f.unlock_buf(buff)
202-
self.ten_env.log_debug(f"send audio frame, timestamp: {timestamp}, length: {len(combined_data)}")
203218
await self.ten_env.send_audio_frame(f)
204219
except Exception as e:
205220
self.ten_env.log_error(
@@ -295,6 +310,70 @@ async def send_tts_error(
295310

296311
await self.ten_env.send_data(error_data)
297312

313+
# Start timer task to print log every 5 seconds
314+
async def timer_logger(self, ten_env: AsyncTenEnv):
315+
while True:
316+
try:
317+
await asyncio.sleep(5)
318+
self.billing_metrics_calculate_duration()
319+
data = Data.create("metrics")
320+
metrics = ModuleMetrics(
321+
id=self.get_uuid(),
322+
module=ModuleType.TTS,
323+
vendor=self.vendor(),
324+
metrics={
325+
"output_characters": self.output_characters,
326+
"input_characters": self.input_characters,
327+
"recv_audio_duration": self.recv_audio_duration,
328+
"total_output_characters": self.total_output_characters,
329+
"total_input_characters": self.total_input_characters,
330+
"total_recv_audio_duration": self.total_recv_audio_duration,
331+
}
332+
)
333+
ten_env.log_debug(f"billing_metrics: {metrics}")
334+
data.set_property_from_json(None, metrics.model_dump_json())
335+
await ten_env.send_data(data)
336+
self.billing_metrics_reset()
337+
338+
except asyncio.CancelledError:
339+
break
340+
341+
def billing_metrics_calculate_duration(self) -> None:
342+
self.recv_audio_duration = (float(len(self.recv_audio_chunks)) / self.synthesize_audio_channels() * 1000 / self.synthesize_audio_sample_width()) / self.synthesize_audio_sample_rate()
343+
self.total_recv_audio_duration = (float(len(self.total_recv_audio_chunks)) / self.synthesize_audio_channels() * 1000 / self.synthesize_audio_sample_width()) / self.synthesize_audio_sample_rate()
344+
345+
def billing_metrics_add_output_characters(self, characters: int) -> None:
346+
self.output_characters += characters
347+
self.total_output_characters += characters
348+
349+
def billing_metrics_add_input_characters(self, characters: int) -> None:
350+
self.input_characters += characters
351+
self.total_input_characters += characters
352+
353+
def billing_metrics_add_recv_audio_chunks(self, chunks: bytes) -> None:
354+
self.total_recv_audio_chunks += chunks
355+
self.recv_audio_chunks += chunks
356+
357+
def billing_metrics_reset(self) -> None:
358+
self.output_characters = 0
359+
self.input_characters = 0
360+
self.recv_audio_duration = 0
361+
self.recv_audio_chunks = b""
362+
363+
async def metrics_connect_delay(self, connect_delay_ms: int):
364+
data = Data.create("metrics")
365+
metrics = ModuleMetrics(
366+
id=self.get_uuid(),
367+
module=ModuleType.TTS,
368+
vendor=self.vendor(),
369+
metrics={
370+
"connect_delay": connect_delay_ms,
371+
}
372+
)
373+
self.ten_env.log_debug(f"metrics_connect_delay: {metrics}")
374+
data.set_property_from_json(None, metrics.model_dump_json())
375+
await self.ten_env.send_data(data)
376+
298377

299378
@abstractmethod
300379
def vendor(self) -> str:
@@ -339,4 +418,10 @@ def synthesize_audio_sample_width(self) -> int:
339418
Get the sample width in bytes for input audio.
340419
Default is 2 (16-bit PCM).
341420
"""
342-
return 2
421+
return 2
422+
423+
def get_uuid(self) -> str:
424+
"""
425+
Get a unique identifier
426+
"""
427+
return uuid.uuid4().hex

0 commit comments

Comments
 (0)