From cd26dd19329a2e3f56ce530d4a062a438c20ea06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9D=98=EC=A7=84?= Date: Mon, 27 Oct 2025 17:20:14 +0900 Subject: [PATCH 1/3] fix triton token2wav model cache thread unsafety --- runtime/triton_trtllm/model_repo/token2wav/1/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index 1e380520..53d5a54a 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -28,6 +28,7 @@ import os import logging +from uuid import uuid4 import torch from torch.utils.dlpack import to_dlpack @@ -235,17 +236,17 @@ def execute(self, requests): stream = True else: stream = False - request_id = request.request_id() + uuid = uuid4().hex audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, prompt_token=prompt_speech_tokens, prompt_feat=prompt_speech_feat, embedding=prompt_spk_embedding, token_offset=token_offset, - uuid=request_id, + uuid=uuid, stream=stream, finalize=finalize) if finalize: - self.token2wav_model.model.hift_cache_dict.pop(request_id) + self.token2wav_model.model.hift_cache_dict.pop(uuid) else: tts_mel, _ = self.token2wav_model.model.flow.inference( From fa2781405f8df3cc0bc02593400559c24f26e0de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9D=98=EC=A7=84?= Date: Mon, 27 Oct 2025 18:07:30 +0900 Subject: [PATCH 2/3] Revert "fix triton token2wav model cache thread unsafety" This reverts commit cd26dd19329a2e3f56ce530d4a062a438c20ea06. --- runtime/triton_trtllm/model_repo/token2wav/1/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/runtime/triton_trtllm/model_repo/token2wav/1/model.py b/runtime/triton_trtllm/model_repo/token2wav/1/model.py index 53d5a54a..1e380520 100644 --- a/runtime/triton_trtllm/model_repo/token2wav/1/model.py +++ b/runtime/triton_trtllm/model_repo/token2wav/1/model.py @@ -28,7 +28,6 @@ import os import logging -from uuid import uuid4 import torch from torch.utils.dlpack import to_dlpack @@ -236,17 +235,17 @@ def execute(self, requests): stream = True else: stream = False - uuid = uuid4().hex + request_id = request.request_id() audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens, prompt_token=prompt_speech_tokens, prompt_feat=prompt_speech_feat, embedding=prompt_spk_embedding, token_offset=token_offset, - uuid=uuid, + uuid=request_id, stream=stream, finalize=finalize) if finalize: - self.token2wav_model.model.hift_cache_dict.pop(uuid) + self.token2wav_model.model.hift_cache_dict.pop(request_id) else: tts_mel, _ = self.token2wav_model.model.flow.inference( From e8bf717333d0418c4421c456a14df92cda5a46bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9D=98=EC=A7=84?= Date: Mon, 27 Oct 2025 18:12:17 +0900 Subject: [PATCH 3/3] Fix: generate token2wav_request_id from cosyvoice2 - Since all token2wav requests within a single cosyvoice2 request must share the same request_id, modify the logic so that a new request_id is generated only if it does not already exist, and ensure that the same request_id is sent consistently. --- runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py index a2bfb307..94c82c6f 100644 --- a/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py +++ b/runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py @@ -28,6 +28,7 @@ import os import threading import time +from uuid import uuid4 import numpy as np import torch @@ -364,6 +365,7 @@ def execute(self, requests): # Generate semantic tokens with LLM generated_ids_iter = self.forward_llm(input_ids) + token2wav_request_id = request_id or str(uuid4()) if self.decoupled: response_sender = request.get_response_sender() @@ -392,7 +394,7 @@ def execute(self, requests): this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) sub_tts_speech = self.forward_token2wav( - this_tts_speech_token, request_id, prompt_speech_tokens, + this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, False ) @@ -427,7 +429,7 @@ def execute(self, requests): time.sleep(0.02) this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device) - sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True) + sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) @@ -441,7 +443,7 @@ def execute(self, requests): if generated_ids is None or len(generated_ids) == 0: raise pb_utils.TritonModelException("Generated IDs is None or empty") - audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding) + audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding) # Prepare response audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))