-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Open
Description
模型初始化代码
cosyvoice3代码:
cosyvoice = AutoModel(
model_dir='CosyVoice/pretrained_models/Fun-CosyVoice3-0.5B',
load_trt=True,
load_vllm=True,
fp16=False,
trt_concurrent=2
)
cosyvoice2代码:
cosyvoice = AutoModel(
model_dir='CosyVoice/pretrained_models/CosyVoice2-0.5B',
load_jit=True,
load_trt=True,
load_vllm=True,
trt_concurrent=2,
)
推理的文本内容
啊!我喜欢做很多很多事情哦~比如画画、唱歌、跳舞
指导音频
cosyvoice2音频效果:
cosyvoice2-pcm_1767153166.pcm.wav
cosyvoice3音频最终效果
cosyvoice3-pcm_1767150614.pcm.wav
完整代码
推理代码使用grpc提供服务,双流式推理,第一次推理使用instruct,之后的推理使用zero-shot,从第二次推理开始使用指导音频的embedding保留音色,使用第一次instruct推理结果音频做指导音频,取它的韵律。
上面的推理文本内容是通过grpc两次传递过来:
golang客户端:
ttsTextList = []string{"啊!我喜欢做很多很多事情哦~", "比如画画、唱歌、跳舞"}
python推理服务端:
import io
import os
import re
import tempfile
import threading
from datetime import datetime
from typing import Iterator
import grpc
import logging
import numpy as np
import requests
import torch
import torchaudio
from CosyVoice.cosyvoice.cli.cosyvoice import AutoModel, CosyVoice3, CosyVoice2
from CosyVoice.cosyvoice.utils.common import set_all_random_seed
from protos.cosy_voice_grpc_pb2 import InferenceStreamReq, InferenceResp, InferenceConfig, ZeroShot, Instruct2
from protos.cosy_voice_grpc_pb2_grpc import CosyVoiceServiceServicer
from utils import common
from utils.common import normalize_audio_to_pcm_bytes, tts_text_generator_incr
from vllm import ModelRegistry
from CosyVoice.cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
# Register the CosyVoice2 causal-LM class with vLLM's model registry under
# the name "CosyVoice2ForCausalLM".
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
# Module-level logger used by every function and method below.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO) # Only emit log records at INFO level or above
def load_wav(file_path: str, cache_dir: str = '/tmp', timeout: float = 30.0) -> str:
    """Return a local path for a reference-audio file, downloading it if needed.

    A plain local path is returned unchanged. An ``http://`` / ``https://``
    URL is downloaded once into ``cache_dir`` and the cached path is returned
    on every later call.

    Args:
        file_path: Local path or HTTP(S) URL of the reference audio.
        cache_dir: Directory that holds downloaded files.
        timeout: Seconds to wait for the server before giving up, so that a
            stalled download cannot block the caller forever.

    Returns:
        A path on the local file system.

    Raises:
        Exception: If the server answers with a non-2xx status code.
        requests.RequestException: If the request itself fails or times out.
    """
    import hashlib
    from urllib.parse import urlsplit

    # Only real URLs are fetched; a local path that merely begins with the
    # letters "http" is still a local path.
    if not file_path.startswith(('http://', 'https://')):
        return file_path

    # Key the cache on the whole URL so that two different URLs sharing the
    # same basename cannot be mistaken for each other. The basename (without
    # any query string) is kept as a suffix to preserve the file extension.
    url_digest = hashlib.sha256(file_path.encode('utf-8')).hexdigest()[:16]
    base_name = os.path.basename(urlsplit(file_path).path)
    cache_name = f'{url_digest}_{base_name}' if base_name else url_digest
    cache_file = os.path.join(cache_dir, cache_name)
    if os.path.exists(cache_file):
        return cache_file

    log.info(f'从网络下载参考音频:{file_path}')
    response = requests.get(file_path, timeout=timeout)
    if response.status_code // 100 != 2:
        raise Exception("参考音频下载失败")

    # Write to a private temporary file and rename it into place, so that a
    # concurrent caller never sees a partially written cache file.
    fd, partial_path = tempfile.mkstemp(dir=cache_dir, suffix='.part')
    try:
        with os.fdopen(fd, 'wb') as f:
            f.write(response.content)
        os.replace(partial_path, cache_file)
    finally:
        # Still present only when the write or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return cache_file
def CosyvoiceZeroShot(cosyvoice: CosyVoice2, tts_text, zero_shot_spk_id, stream=True):
    """Yield PCM byte chunks synthesized for ``tts_text`` in zero-shot mode.

    Calls ``cosyvoice.inference_zero_shot`` with the registered speaker id
    and converts the ``'tts_speech'`` entry of every produced item to bytes
    with ``normalize_audio_to_pcm_bytes``.

    Args:
        cosyvoice: Model object providing ``inference_zero_shot`` and
            ``sample_rate``.
        tts_text: Text to synthesize.
        zero_shot_spk_id: Id of the speaker to synthesize with.
        stream: Forwarded to the model's ``stream`` argument.

    Yields:
        One normalized audio chunk per item produced by the model.
    """
    log.info(
        f"ZeroShot-推理文本:text:{tts_text}, spk_id:{zero_shot_spk_id}, stream:{stream}")
    # The second and third positional arguments are passed empty; presumably
    # they are the prompt text / prompt audio, which are not needed once a
    # speaker id is supplied -- TODO confirm against the CosyVoice API.
    synthesis = cosyvoice.inference_zero_shot(
        tts_text,
        '',
        None,
        zero_shot_spk_id=zero_shot_spk_id,
        stream=stream,
        text_frontend=False,
    )
    for chunk in synthesis:
        # Loudness-normalize towards -16 LUFS; per the original author's note
        # the helper also suppresses low-frequency noise.
        yield normalize_audio_to_pcm_bytes(
            chunk['tts_speech'],
            target_lufs=-16,
            sample_rate=cosyvoice.sample_rate,
        )
def CosyvoiceInstruct(cosyvoice: CosyVoice2, tts_text, zero_shot_spk_id, stream=True):
    """Yield PCM byte chunks synthesized for ``tts_text`` in instruct mode.

    Calls ``cosyvoice.inference_instruct2`` with the registered speaker id
    and converts the ``'tts_speech'`` entry of every produced item to bytes
    with ``normalize_audio_to_pcm_bytes``.

    Args:
        cosyvoice: Model object providing ``inference_instruct2`` and
            ``sample_rate``.
        tts_text: Text to synthesize.
        zero_shot_spk_id: Id of the speaker to synthesize with.
        stream: Forwarded to the model's ``stream`` argument.

    Yields:
        One normalized audio chunk per item produced by the model.
    """
    log.info(
        f"Instruct2-推理文本:text:{tts_text}, spk_id:{zero_shot_spk_id}, stream:{stream}")
    # The second and third positional arguments are passed empty; presumably
    # they are the instruction text / prompt audio, already attached to the
    # registered speaker -- TODO confirm against the CosyVoice API.
    synthesis = cosyvoice.inference_instruct2(
        tts_text,
        '',
        None,
        zero_shot_spk_id=zero_shot_spk_id,
        stream=stream,
        text_frontend=False,
    )
    for chunk in synthesis:
        # Loudness-normalize towards -16 LUFS; per the original author's note
        # the helper also suppresses low-frequency noise.
        yield normalize_audio_to_pcm_bytes(
            chunk['tts_speech'],
            target_lufs=-16,
            sample_rate=cosyvoice.sample_rate,
        )
# Process-wide counter used to derive a distinct id for every handled stream.
handle_id = np.uint32(0)
# Serializes updates of ``handle_id`` across threads.
handle_id_lock = threading.Lock()


def handle_id_increment():
    """Advance the global ``handle_id`` by one and return the new value.

    The read-modify-write is performed while holding ``handle_id_lock``.
    """
    global handle_id
    with handle_id_lock:
        next_id = handle_id + 1
        handle_id = next_id
    return next_id
class CosyVoiceGrpcServicer(CosyVoiceServiceServicer):
    """gRPC servicer that serves CosyVoice text-to-speech over a request stream."""

    def __init__(self):
        """Load the CosyVoice2-0.5B model with JIT, TensorRT and vLLM enabled."""
        super().__init__()
        # set_all_random_seed(1761032004)
        cosyvoice = AutoModel(
            model_dir='CosyVoice/pretrained_models/CosyVoice2-0.5B',
            load_jit=True,
            load_trt=True,
            load_vllm=True,
            trt_concurrent=2,
        )
        self.cosyvoice = cosyvoice

    def ModelPreHot(self):
        """Warm the model up by running one non-streaming instruct inference.

        Registers a speaker named ``spk_pre_hot`` from a fixed reference
        audio URL and synthesizes a fixed sentence; the audio is discarded.
        The warm-up speaker is not removed afterwards.
        """
        tts_text = "你对得起我吗?昨天刚给你买的玩具你居然给我弄丢了"
        instruct_text = "优雅的说"
        ref_audio = "https://files.quantekeji.com/timbre/illustrate_200018_zh_cn.wav"
        ref_audio = load_wav(ref_audio)
        prompt_text = f'{instruct_text}<|endofprompt|>'
        zero_shot_spk_id = 'spk_pre_hot'
        common.add_zero_shot_spk(self.cosyvoice, prompt_text, ref_audio, zero_shot_spk_id)
        # Run the inference; only the chunk index is of interest here.
        for i, _ in enumerate(CosyvoiceInstruct(self.cosyvoice, tts_text, zero_shot_spk_id, stream=False)):
            log.info(f"模型预热:chunk_{i}, sample_rate:{self.cosyvoice.sample_rate}")
        log.debug("模型预热 complete")

    def InferenceStream(self, request_iterator: Iterator[InferenceStreamReq], context: grpc.ServicerContext):
        """Synthesize a stream of text requests and stream PCM audio back.

        The first request must carry a ``config`` with a ``ref_audio``. Every
        later request contributes ``text``. After each synthesized segment the
        audio just produced is saved as the speaker's new reference audio and
        the mode is switched to ``ZeroShot``, so that following segments are
        conditioned on the previous output.

        Args:
            request_iterator: Incoming requests; config first, then text.
            context: gRPC context, polled to detect a disconnected client.

        Yields:
            ``InferenceResp`` messages carrying either ``pcm`` audio bytes or
            an ``err`` description.
        """
        tts_config: InferenceConfig
        # NOTE(review): hard-coded; assumed to equal self.cosyvoice.sample_rate
        # for the loaded model -- confirm before switching models.
        sr = 24000
        # The follow-up reference audio is capped at the last 20 seconds.
        max_ref_samples = sr * 20
        tmp_ref_audio = ''
        current_handle_id = handle_id_increment()
        zero_shot_spk_id = f'spk_id_{current_handle_id}'
        try:
            # An empty request stream is reported to the client like any
            # other missing-config case instead of surfacing StopIteration.
            first_req = next(request_iterator, None)
            if first_req is None or not first_req.HasField("config"):
                yield InferenceResp(err='first tts data must is config')
                return
            tts_config = first_req.config
            if not tts_config.ref_audio:
                yield InferenceResp(err='must give ref_audio')
                return

            ref_audio = load_wav(tts_config.ref_audio)
            if tts_config.kind == Instruct2:
                prompt_text = f'{tts_config.instruct_text}<|endofprompt|>'
            else:
                prompt_text = f'{tts_config.instruct_text}'
            common.add_zero_shot_spk(self.cosyvoice, prompt_text, ref_audio, zero_shot_spk_id)

            # set_all_random_seed(int(datetime.now().timestamp()))
            inference_first = True
            for req in request_iterator:
                # Strip every non-word character (punctuation and whitespace
                # included) before synthesis.
                req.text = re.sub(r'\W', '', req.text)
                if not req.text.strip():
                    continue

                # Streaming output is used for at most one segment: the first
                # one synthesized while the config asks for streaming.
                stream = bool(inference_first and tts_config.stream)
                if stream:
                    inference_first = False

                if tts_config.kind == Instruct2:
                    tts_fn = CosyvoiceInstruct(self.cosyvoice, req.text, zero_shot_spk_id, stream=stream)
                else:
                    tts_fn = CosyvoiceZeroShot(self.cosyvoice, req.text, zero_shot_spk_id, stream=stream)

                # Accumulate this segment's audio; a bytearray keeps the
                # repeated appends linear in the total size.
                current_audio = bytearray()
                for audio_bytes in tts_fn:
                    current_audio += audio_bytes
                    yield InferenceResp(pcm=audio_bytes)
                    if not context.is_active():
                        log.info("InferenceStream 推理中断,本次执行结束")
                        return

                if current_audio:
                    # Keep this segment as the reference audio for the next
                    # one; the bytes are interpreted as 16-bit samples.
                    current_pcm_data = np.frombuffer(current_audio, dtype=np.int16)
                    if len(current_pcm_data) > max_ref_samples:
                        log.info("新的指导音频-超过20秒触发截断")
                        current_pcm_data = current_pcm_data[-max_ref_samples:]
                    else:
                        log.info(f"新的指导音频-长度:{len(current_pcm_data) / sr}秒")
                    current_pcm_tensor = torch.from_numpy(current_pcm_data.copy()).unsqueeze(0)
                    tmp_ref_audio = f'output/ref_audio_{current_handle_id}.wav'
                    # The target directory may not exist on a fresh deployment.
                    os.makedirs(os.path.dirname(tmp_ref_audio), exist_ok=True)
                    torchaudio.save(tmp_ref_audio, current_pcm_tensor, sr)
                    tts_config.kind = ZeroShot
                    prompt_text = f'{req.text}'
                    common.update_zero_shot_spk(self.cosyvoice, prompt_text, tmp_ref_audio, zero_shot_spk_id)
        except Exception as e:
            log.exception(e)
            # Tell the client as well; otherwise the stream just ends and a
            # failure looks like a successful, shorter result.
            yield InferenceResp(err=f'inference failed: {e}')
        finally:
            try:
                common.del_zero_shot_spk(self.cosyvoice, zero_shot_spk_id)
            finally:
                # Runs even when removing the speaker fails, so the temporary
                # reference audio is not left behind.
                if tmp_ref_audio:
                    try:
                        os.remove(tmp_ref_audio)
                    except FileNotFoundError:
                        # Saving failed before the file was created.
                        pass
                log.info("InferenceStream 推理结束")
pass
Metadata
Metadata
Assignees
Labels
No labels