Skip to content

Commit 5dc6829

Browse files
wangyimin-agoraYiminW
authored andcommitted
feat: tts base set metadata in audio frame
1 parent c32f783 commit 5dc6829

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

integration_tests/test_async_tts2_base/extension.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ async def request_tts(
4545
This method is called when the TTS request is made.
4646
It should yield audio data bytes.
4747
"""
48+
# Send audio_start to set current_audio_request_id (required for metadata)
49+
await self.send_tts_audio_start(request_id=t.request_id)
50+
4851
audio_data_bytes = [3, 100, 7]
4952
for b in audio_data_bytes:
5053
await self.send_tts_audio_data(bytearray(b))

integration_tests/test_async_tts2_base/tests/test_basic_tts.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def __init__(self, sample_rate) -> None:
2929
self.target_sample_rate = sample_rate
3030
self.received_frames = 0
3131
self.received_text_result:TTSTextResult = None
32+
self.expected_metadata = {
33+
"session_id": "test_session",
34+
"turn_id": 1
35+
}
3236

3337
async def on_start(self, ten_env: AsyncTenEnvTester) -> None:
3438
await asyncio.sleep(0.1)
@@ -75,6 +79,18 @@ async def on_audio_frame(
7579
== audio_frame.get_samples_per_channel() * 2
7680
)
7781

82+
# Verify metadata is attached to audio frame
83+
metadata_json, err = audio_frame.get_property_to_json("metadata")
84+
assert not err, f"Failed to get metadata from audio frame: {err}"
85+
86+
metadata = json.loads(metadata_json)
87+
ten_env.log_info(f"Audio frame metadata: {metadata}")
88+
89+
# Verify metadata matches what was sent in the request
90+
assert metadata == self.expected_metadata, (
91+
f"Metadata mismatch! Expected: {self.expected_metadata}, "
92+
f"Got: {metadata}"
93+
)
7894

7995
self.received_frames += 1
8096

@@ -85,7 +101,7 @@ async def on_audio_frame(
85101
f"Number of Channels: {audio_frame.get_number_of_channels()}"
86102
f"Received Frames: {self.received_frames}"
87103
)
88-
104+
89105
self.check_received(ten_env)
90106

91107
def check_received(self, ten_env: AsyncTenEnvTester):

interface/ten_ai_base/tts2.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def __init__(self, name: str):
9797
self.total_recv_audio_duration = 0
9898
self.total_recv_audio_chunks_len = 0
9999

100+
# Tracks which request_id's audio is currently being sent
101+
# Set in send_tts_audio_start(), reset in send_tts_audio_end() and flush
102+
# Used by send_tts_audio_data() to attach correct metadata to audio frames
103+
self.current_audio_request_id = None
104+
100105
def _can_transition_to(self, request_id: str, new_state: RequestState) -> bool:
101106
"""Check if state transition is valid."""
102107
current_state = self.request_states.get(request_id)
@@ -298,8 +303,9 @@ async def _flush_input_items(self):
298303
self.request_states.clear()
299304
self.metadatas.clear()
300305

301-
# Reset processing request ID
306+
# Reset processing request ID and current audio request ID
302307
self._processing_request_id = None
308+
self.current_audio_request_id = None
303309

304310
self.ten_env.log_debug("Cleared all request states, metadata, and pending messages after flush")
305311

@@ -405,6 +411,7 @@ async def send_tts_audio_data(self, audio_data: bytes, timestamp: int = 0) -> No
405411
)
406412
f.alloc_buf(len(combined_data))
407413
f.set_timestamp(timestamp)
414+
f.set_property_from_json("metadata", json.dumps(self.metadatas.get(self.current_audio_request_id, {})))
408415
buff = f.lock_buf()
409416
buff[:] = combined_data
410417
f.unlock_buf(buff)
@@ -443,6 +450,9 @@ async def send_tts_ttfb_metrics(
443450
async def send_tts_audio_start(
444451
self, request_id: str, turn_id: int = -1, extra_metadata: dict | None = None
445452
) -> None:
453+
# Set current_audio_request_id to track which request's audio is being sent
454+
self.current_audio_request_id = request_id
455+
446456
new_metadata = self.update_metadata(request_id, extra_metadata)
447457

448458
data = Data.create("tts_audio_start")
@@ -489,6 +499,10 @@ async def send_tts_audio_end(
489499
# Clean up metadata when audio_end is sent
490500
self.metadatas.pop(request_id, None)
491501

502+
# Reset current_audio_request_id (audio phase complete)
503+
if self.current_audio_request_id == request_id:
504+
self.current_audio_request_id = None
505+
492506
async def send_tts_error(
493507
self,
494508
request_id: str | None,
@@ -670,6 +684,10 @@ async def finish_request(
670684
# This is a defensive cleanup in case audio_end wasn't sent
671685
self.metadatas.pop(request_id, None)
672686

687+
# Defensive reset of current_audio_request_id for error paths
688+
if self.current_audio_request_id == request_id:
689+
self.current_audio_request_id = None
690+
673691
# Handle request completion and buffered messages release
674692
# Only process if this is the currently processing request
675693
if self._processing_request_id == request_id:

0 commit comments

Comments
 (0)