33# Licensed under the Apache License, Version 2.0.
44# See the LICENSE file for more information.
55#
6+ """ASR base class: connection, buffering, metrics and result helpers."""
67from abc import abstractmethod
8+ import asyncio
9+ from functools import wraps
10+ import json
11+ import os
712from typing import Any , final
813import uuid
914
10- from .struct import ASRResult
11- from .types import ASRBufferConfig , ASRBufferConfigModeDiscard , ASRBufferConfigModeKeep
15+ from ten_runtime import (
16+ AsyncExtension ,
17+ AsyncTenEnv ,
18+ Cmd ,
19+ Data ,
20+ AudioFrame ,
21+ StatusCode ,
22+ CmdResult ,
23+ )
1224
25+ from .const import (
26+ DATA_IN_ASR_FINALIZE ,
27+ DATA_OUT_ASR_FINALIZE_END ,
28+ DATA_OUT_METRICS ,
29+ PROPERTY_KEY_DUMP ,
30+ PROPERTY_KEY_DUMP_PATH ,
31+ PROPERTY_KEY_METADATA ,
32+ PROPERTY_KEY_SESSION_ID ,
33+ )
34+ from .dumper import Dumper
1335from .message import (
1436 ModuleError ,
1537 ModuleErrorCode ,
1840 ModuleMetrics ,
1941 ModuleType ,
2042)
43+ from .struct import ASRResult
2144from .timeline import AudioTimeline
22- from .const import (
23- DATA_IN_ASR_FINALIZE ,
24- DATA_OUT_ASR_FINALIZE_END ,
25- DATA_OUT_METRICS ,
26- PROPERTY_KEY_METADATA ,
27- PROPERTY_KEY_SESSION_ID ,
28- )
29- from ten_runtime import (
30- AsyncExtension ,
31- AsyncTenEnv ,
32- Cmd ,
33- Data ,
34- AudioFrame ,
35- StatusCode ,
36- CmdResult ,
37- )
38- import asyncio
39- import json
45+ from .types import ASRBufferConfig , ASRBufferConfigModeDiscard , ASRBufferConfigModeKeep
46+
4047
48+ class AsyncASRBaseExtension (
49+ AsyncExtension
50+ ): # pylint: disable=too-many-instance-attributes, too-many-public-methods
51+ """Asynchronous base class for ASR modules."""
4152
42- class AsyncASRBaseExtension (AsyncExtension ):
4353 def __init__ (self , name : str ):
4454 super ().__init__ (name )
4555
@@ -52,6 +62,7 @@ def __init__(self, name: str):
5262 self .buffered_frames_size = 0
5363 self .audio_frames_queue = asyncio .Queue [AudioFrame ]()
5464 self .audio_timeline = AudioTimeline ()
65+ self .dumper : Dumper | None = None
5566 self .audio_actual_send_metrics_task : asyncio .Task [None ] | None = None
5667 self .uuid = self ._get_uuid () # Unique identifier for the current final turn
5768
@@ -68,6 +79,18 @@ async def on_init(self, ten_env: AsyncTenEnv) -> None:
6879 self .ten_env = ten_env
6980 asyncio .create_task (self ._audio_frame_consumer ())
7081
82+ enable_dump , err = await ten_env .get_property_bool (PROPERTY_KEY_DUMP )
83+ if err :
84+ ten_env .log_info (f"dump not set, disable dump: { err } " )
85+ elif enable_dump :
86+ dump_path , err = await ten_env .get_property_string (PROPERTY_KEY_DUMP_PATH )
87+ if err :
88+ ten_env .log_warn (f"dump_path not set, use current directory: { err } " )
89+ dump_path = "."
90+
91+ dump_file_path = os .path .join (dump_path , self .dump_file_name ())
92+ self .dumper = Dumper (dump_file_path , None )
93+
7194 async def on_start (self , ten_env : AsyncTenEnv ) -> None :
7295 ten_env .log_info ("on_start" )
7396
@@ -127,6 +150,10 @@ def vendor(self) -> str:
127150 """Get the name of the ASR vendor."""
128151 raise NotImplementedError ("This method should be implemented in subclasses." )
129152
153+ def dump_file_name (self ) -> str :
154+ """Return the base dump filename."""
155+ return f"{ self .name } _out.pcm"
156+
130157 @abstractmethod
131158 async def start_connection (self ) -> None :
132159 """Start the connection to the ASR service."""
@@ -163,6 +190,26 @@ def input_audio_sample_width(self) -> int:
163190 """
164191 return 2
165192
193+ # Automatically wrap subclass start_connection to update dumper session first
194+ def __init_subclass__ (cls , ** kwargs ):
195+ super ().__init_subclass__ (** kwargs )
196+ orig = cls .__dict__ .get ("start_connection" )
197+ if orig is None :
198+ return
199+
200+ # Only wrap coroutine functions
201+ @wraps (orig )
202+ async def wrapped (self : AsyncASRBaseExtension , * args , ** kw ):
203+ try :
204+ if self .dumper is not None :
205+ await self .dumper .update_session ()
206+ except Exception as e : # pylint: disable=broad-exception-caught
207+ if self .ten_env is not None :
208+ self .ten_env .log_error (f"auto update_session failed: { e } " )
209+ return await orig (self , * args , ** kw )
210+
211+ setattr (cls , "start_connection" , wrapped )
212+
166213 def buffer_strategy (self ) -> ASRBufferConfig :
167214 """
168215 Get the buffer strategy for audio frames when not connected
@@ -442,7 +489,7 @@ async def _audio_frame_consumer(self) -> None:
442489 await self ._handle_audio_frame (self .ten_env , audio_frame )
443490 except asyncio .CancelledError :
444491 break
445- except Exception as e :
492+ except Exception as e : # pylint: disable=broad-exception-caught
446493 self .ten_env .log_error (f"Error consuming audio frame: { e } " )
447494
448495 async def _send_audio_actual_send_metrics_task (self ) -> None :
0 commit comments