Skip to content

Commit d074196

Browse files
committed
feat: support dump audio by session
1 parent a127aa7 commit d074196

File tree

3 files changed

+131
-41
lines changed

3 files changed

+131
-41
lines changed

interface/ten_ai_base/asr.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,35 @@
33
# Licensed under the Apache License, Version 2.0.
44
# See the LICENSE file for more information.
55
#
6+
"""ASR base class: connection, buffering, metrics and result helpers."""
67
from abc import abstractmethod
8+
import asyncio
9+
from functools import wraps
10+
import json
11+
import os
712
from typing import Any, final
813
import uuid
914

10-
from .struct import ASRResult
11-
from .types import ASRBufferConfig, ASRBufferConfigModeDiscard, ASRBufferConfigModeKeep
15+
from ten_runtime import (
16+
AsyncExtension,
17+
AsyncTenEnv,
18+
Cmd,
19+
Data,
20+
AudioFrame,
21+
StatusCode,
22+
CmdResult,
23+
)
1224

25+
from .const import (
26+
DATA_IN_ASR_FINALIZE,
27+
DATA_OUT_ASR_FINALIZE_END,
28+
DATA_OUT_METRICS,
29+
PROPERTY_KEY_DUMP,
30+
PROPERTY_KEY_DUMP_PATH,
31+
PROPERTY_KEY_METADATA,
32+
PROPERTY_KEY_SESSION_ID,
33+
)
34+
from .dumper import Dumper
1335
from .message import (
1436
ModuleError,
1537
ModuleErrorCode,
@@ -18,28 +40,16 @@
1840
ModuleMetrics,
1941
ModuleType,
2042
)
43+
from .struct import ASRResult
2144
from .timeline import AudioTimeline
22-
from .const import (
23-
DATA_IN_ASR_FINALIZE,
24-
DATA_OUT_ASR_FINALIZE_END,
25-
DATA_OUT_METRICS,
26-
PROPERTY_KEY_METADATA,
27-
PROPERTY_KEY_SESSION_ID,
28-
)
29-
from ten_runtime import (
30-
AsyncExtension,
31-
AsyncTenEnv,
32-
Cmd,
33-
Data,
34-
AudioFrame,
35-
StatusCode,
36-
CmdResult,
37-
)
38-
import asyncio
39-
import json
45+
from .types import ASRBufferConfig, ASRBufferConfigModeDiscard, ASRBufferConfigModeKeep
46+
4047

48+
class AsyncASRBaseExtension(
49+
AsyncExtension
50+
): # pylint: disable=too-many-instance-attributes, too-many-public-methods
51+
"""Asynchronous base class for ASR modules."""
4152

42-
class AsyncASRBaseExtension(AsyncExtension):
4353
def __init__(self, name: str):
4454
super().__init__(name)
4555

@@ -52,6 +62,7 @@ def __init__(self, name: str):
5262
self.buffered_frames_size = 0
5363
self.audio_frames_queue = asyncio.Queue[AudioFrame]()
5464
self.audio_timeline = AudioTimeline()
65+
self.dumper: Dumper | None = None
5566
self.audio_actual_send_metrics_task: asyncio.Task[None] | None = None
5667
self.uuid = self._get_uuid() # Unique identifier for the current final turn
5768

@@ -68,6 +79,18 @@ async def on_init(self, ten_env: AsyncTenEnv) -> None:
6879
self.ten_env = ten_env
6980
asyncio.create_task(self._audio_frame_consumer())
7081

82+
enable_dump, err = await ten_env.get_property_bool(PROPERTY_KEY_DUMP)
83+
if err:
84+
ten_env.log_info(f"dump not set, disable dump: {err}")
85+
elif enable_dump:
86+
dump_path, err = await ten_env.get_property_string(PROPERTY_KEY_DUMP_PATH)
87+
if err:
88+
ten_env.log_warn(f"dump_path not set, use current directory: {err}")
89+
dump_path = "."
90+
91+
dump_file_path = os.path.join(dump_path, self.dump_file_name())
92+
self.dumper = Dumper(dump_file_path, None)
93+
7194
async def on_start(self, ten_env: AsyncTenEnv) -> None:
7295
ten_env.log_info("on_start")
7396

@@ -127,6 +150,10 @@ def vendor(self) -> str:
127150
"""Get the name of the ASR vendor."""
128151
raise NotImplementedError("This method should be implemented in subclasses.")
129152

153+
def dump_file_name(self) -> str:
154+
"""Return the base dump filename."""
155+
return f"{self.name}_out.pcm"
156+
130157
@abstractmethod
131158
async def start_connection(self) -> None:
132159
"""Start the connection to the ASR service."""
@@ -163,6 +190,26 @@ def input_audio_sample_width(self) -> int:
163190
"""
164191
return 2
165192

193+
# Automatically wrap subclass start_connection to update dumper session first
194+
def __init_subclass__(cls, **kwargs):
195+
super().__init_subclass__(**kwargs)
196+
orig = cls.__dict__.get("start_connection")
197+
if orig is None:
198+
return
199+
200+
# Only wrap coroutine functions
201+
@wraps(orig)
202+
async def wrapped(self: AsyncASRBaseExtension, *args, **kw):
203+
try:
204+
if self.dumper is not None:
205+
await self.dumper.update_session()
206+
except Exception as e: # pylint: disable=broad-exception-caught
207+
if self.ten_env is not None:
208+
self.ten_env.log_error(f"auto update_session failed: {e}")
209+
return await orig(self, *args, **kw)
210+
211+
setattr(cls, "start_connection", wrapped)
212+
166213
def buffer_strategy(self) -> ASRBufferConfig:
167214
"""
168215
Get the buffer strategy for audio frames when not connected
@@ -442,7 +489,7 @@ async def _audio_frame_consumer(self) -> None:
442489
await self._handle_audio_frame(self.ten_env, audio_frame)
443490
except asyncio.CancelledError:
444491
break
445-
except Exception as e:
492+
except Exception as e: # pylint: disable=broad-exception-caught
446493
self.ten_env.log_error(f"Error consuming audio frame: {e}")
447494

448495
async def _send_audio_actual_send_metrics_task(self) -> None:

interface/ten_ai_base/const.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,5 @@
3535

3636
PROPERTY_KEY_METADATA = "metadata"
3737
PROPERTY_KEY_SESSION_ID = "session_id"
38+
PROPERTY_KEY_DUMP = "dump"
39+
PROPERTY_KEY_DUMP_PATH = "dump_path"

interface/ten_ai_base/dumper.py

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,75 @@
33
# Licensed under the Apache License, Version 2.0.
44
# See the LICENSE file for more information.
55
#
6+
"""Simple async dumper for writing bytes to session-rotated files."""
67

8+
import asyncio
9+
from pathlib import Path
710
import aiofiles
8-
import os
911

1012

1113
class Dumper:
12-
def __init__(self, dump_file_path: str):
13-
self.dump_file_path: str = dump_file_path
14+
"""Asynchronous dumper that writes to files with session-based names."""
15+
16+
def __init__(
17+
self,
18+
base_dump_file_path: str,
19+
session_name: str | None = None,
20+
delimiter: str = "_",
21+
):
22+
self.base_dump_file_path: str = base_dump_file_path
23+
self.session_name: str | None = session_name
24+
self.delimiter: str = delimiter
1425
self._file: aiofiles.threadpool.binary.AsyncBufferedIOBase | None = None
26+
self._lock: asyncio.Lock = asyncio.Lock()
27+
28+
@property
29+
def dump_file_path(self) -> str:
30+
base_path = Path(self.base_dump_file_path)
31+
if not self.session_name:
32+
return str(base_path)
33+
34+
stem = base_path.stem
35+
suffix = base_path.suffix
36+
new_name = f"{stem}{self.delimiter}{self.session_name}{suffix}"
37+
return str(base_path.with_name(new_name))
38+
39+
async def close(self) -> None:
40+
async with self._lock:
41+
if self._file:
42+
await self._file.close()
43+
self._file = None
44+
45+
async def update_session(self) -> None:
46+
"""Rotate to a new session file and open it if needed.
1547
16-
async def start(self):
17-
if self._file:
18-
return
48+
The session name is generated from current event-loop time
49+
to ensure uniqueness and ordering. If the generated name equals
50+
the current one, this function is a no-op.
51+
"""
52+
async with self._lock:
53+
# Generate a new session name based on timestamp
54+
current_time = asyncio.get_event_loop().time()
55+
new_session_name = f"{current_time:.6f}"
1956

20-
os.makedirs(os.path.dirname(self.dump_file_path), exist_ok=True)
57+
if new_session_name == self.session_name and self._file is not None:
58+
# Already opened on the same session
59+
return
2160

22-
self._file = await aiofiles.open(self.dump_file_path, mode="wb")
61+
# Close previous file if any
62+
if self._file is not None:
63+
await self._file.close()
64+
self._file = None
2365

24-
async def stop(self):
25-
if self._file:
26-
await self._file.close()
27-
self._file = None
66+
# Update session and open the new file
67+
self.session_name = new_session_name
68+
Path(self.dump_file_path).parent.mkdir(parents=True, exist_ok=True)
69+
self._file = await aiofiles.open(self.dump_file_path, mode="wb")
2870

29-
async def push_bytes(self, data: bytes):
30-
if not self._file:
31-
raise RuntimeError(
32-
"Dumper for {} is not opened. Please start the Dumper first.".format(
33-
self.dump_file_path
71+
async def push_bytes(self, data: bytes) -> int:
72+
async with self._lock:
73+
if not self._file:
74+
raise RuntimeError(
75+
f"Dumper for {self.dump_file_path} is not opened. Call update_session() first."
3476
)
35-
)
36-
_ = await self._file.write(data)
77+
return await self._file.write(data)

0 commit comments

Comments
 (0)