Skip to content

Commit ba5db60

Browse files
authored
Merge pull request #1622 from GoyoUijin/Fix/token2wav-cache-thread-unsafe
fix triton token2wav model cache thread unsafety
2 parents 4c19646 + e8bf717 commit ba5db60

File tree

1 file changed

+5
-3
lines changed
  • runtime/triton_trtllm/model_repo/cosyvoice2/1

1 file changed

+5
-3
lines changed

runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import os
2929
import threading
3030
import time
31+
from uuid import uuid4
3132

3233
import numpy as np
3334
import torch
@@ -364,6 +365,7 @@ def execute(self, requests):
364365
# Generate semantic tokens with LLM
365366
generated_ids_iter = self.forward_llm(input_ids)
366367

368+
token2wav_request_id = request_id or str(uuid4())
367369
if self.decoupled:
368370
response_sender = request.get_response_sender()
369371

@@ -392,7 +394,7 @@ def execute(self, requests):
392394
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
393395

394396
sub_tts_speech = self.forward_token2wav(
395-
this_tts_speech_token, request_id, prompt_speech_tokens,
397+
this_tts_speech_token, token2wav_request_id, prompt_speech_tokens,
396398
prompt_speech_feat, prompt_spk_embedding, token_offset, False
397399
)
398400

@@ -427,7 +429,7 @@ def execute(self, requests):
427429
time.sleep(0.02)
428430

429431
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
430-
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
432+
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
431433
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
432434
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
433435
response_sender.send(inference_response)
@@ -441,7 +443,7 @@ def execute(self, requests):
441443
if generated_ids is None or len(generated_ids) == 0:
442444
raise pb_utils.TritonModelException("Generated IDs is None or empty")
443445

444-
audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
446+
audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
445447

446448
# Prepare response
447449
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

0 commit comments

Comments (0)