
Commit 1cfc5dd

add online trt export
1 parent 426c400 commit 1cfc5dd

File tree

13 files changed: +100 -167 lines changed

README.md

Lines changed: 0 additions & 2 deletions

@@ -128,8 +128,6 @@ import torchaudio

 **CosyVoice2 Usage**
 ```python
-# NOTE if you want to use tensorRT to accerlate the flow matching inference, please set load_trt=True.
-# if you don't want to save tensorRT model on disk, please set environment variable `NOT_SAVE_TRT=1`.
 cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)

 # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
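The deleted NOTE is obsolete now that conversion happens online. A minimal usage sketch of the updated entry point, assuming the pretrained-model layout shown in the README (the .plan engine is built from the bundled ONNX file on first use if none exists on disk):

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # load_trt=True now triggers an online ONNX -> TensorRT conversion the
    # first time, writing flow.decoder.estimator.{fp16|fp32}.mygpu.plan into
    # the model directory; later runs deserialize that plan directly.
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=True, fp16=True)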

cosyvoice/cli/cosyvoice.py

Lines changed: 6 additions & 2 deletions

@@ -53,7 +53,9 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
                                 '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs

     def list_available_spks(self):

@@ -149,7 +151,9 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator'.format(model_dir), self.fp16)
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs

     def inference_instruct(self, *args, **kwargs):
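Both call sites now pass the ONNX source alongside the target plan path instead of expecting a prebuilt engine. A small sketch of the path scheme, with model_dir illustrative:

    fp16 = True  # mirrors the fp16 flag passed to the CLI class
    model_dir = 'pretrained_models/CosyVoice2-0.5B'  # illustrative path
    precision = 'fp16' if fp16 else 'fp32'
    trt_plan = '{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, precision)
    onnx_src = '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir)
    # the '.mygpu' infix replaces the old hard-coded '.v100' name: the plan is
    # built for (and only valid on) the local GPU rather than shipped for a V100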

cosyvoice/cli/model.py

Lines changed: 18 additions & 6 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
 import threading

@@ -19,7 +20,7 @@
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-from cosyvoice.trt.estimator_trt import EstimatorTRT
+from cosyvoice.utils.file_utils import convert_onnx_to_trt


 class CosyVoiceModel:

@@ -36,6 +37,9 @@ def __init__(self,
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20

@@ -70,9 +74,6 @@ def load(self, llm_model, flow_model, hift_model):
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()

     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)

@@ -82,9 +83,17 @@ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model, fp16):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model):
+            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
         del self.flow.decoder.estimator
-        self.flow.decoder.estimator = EstimatorTRT(flow_decoder_estimator_model, self.device, fp16)
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:

@@ -269,6 +278,9 @@ def __init__(self,
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
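The EstimatorTRT wrapper is gone; load_trt now keeps the deserialized engine on self.flow.decoder.estimator_engine and only the execution context in estimator. A standalone sketch of the same deserialize-or-build pattern, assuming the names above; note the engine reference is retained deliberately, since a context becomes invalid once its engine is garbage-collected:

    import os
    import tensorrt as trt
    from cosyvoice.utils.file_utils import convert_onnx_to_trt

    def load_engine(plan_path, onnx_path, fp16):
        # build the plan from the ONNX export on first use, as load_trt does
        if not os.path.exists(plan_path):
            convert_onnx_to_trt(plan_path, onnx_path, fp16)
        with open(plan_path, 'rb') as f:
            engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        if engine is None:
            raise ValueError('failed to load trt {}'.format(plan_path))
        # return both: the context only stays valid while the engine is alive
        return engine, engine.create_execution_context()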

cosyvoice/dataset/processor.py

Lines changed: 0 additions & 1 deletion

@@ -21,7 +21,6 @@
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F

-torchaudio.set_audio_backend('soundfile')

 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
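The deleted call is a deprecation cleanup: the global set_audio_backend was removed in torchaudio 2.x. A hedged note on the replacement, assuming torchaudio >= 2.1, where the backend is dispatched per call:

    import torchaudio

    # the global backend setter is gone; if soundfile must be forced, recent
    # torchaudio accepts a per-call backend argument ('sample.wav' illustrative)
    speech, sample_rate = torchaudio.load('sample.wav', backend='soundfile')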

cosyvoice/flow/flow_matching.py

Lines changed: 6 additions & 6 deletions

@@ -134,12 +134,12 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
         self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
         # run trt engine
         self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                  mask.contiguous().data_ptr(),
-                                  mu.contiguous().data_ptr(),
-                                  t.contiguous().data_ptr(),
-                                  spks.contiguous().data_ptr(),
-                                  cond.contiguous().data_ptr(),
-                                  x.data_ptr()])
+                                   mask.contiguous().data_ptr(),
+                                   mu.contiguous().data_ptr(),
+                                   t.contiguous().data_ptr(),
+                                   spks.contiguous().data_ptr(),
+                                   cond.contiguous().data_ptr(),
+                                   x.data_ptr()])
         return x

     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
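This hunk only realigns the continuation arguments under the opening bracket; the call itself is unchanged. For context, execute_v2 takes raw device pointers in the engine's binding order. A sketch of the contract this call relies on, as we read it, with binding names matching the ONNX inputs declared in convert_onnx_to_trt below:

    # bindings follow engine I/O order: the six inputs exported to ONNX
    # ("x", "mask", "mu", "t", "spks", "cond"), then the single output.
    # passing x.data_ptr() in the output slot writes the estimator result
    # back into x's storage, so no separate output tensor is allocated.
    bindings = [x.contiguous().data_ptr(),
                mask.contiguous().data_ptr(),
                mu.contiguous().data_ptr(),
                t.contiguous().data_ptr(),
                spks.contiguous().data_ptr(),
                cond.contiguous().data_ptr(),
                x.data_ptr()]  # output aliases x (in-place)
    self.estimator.execute_v2(bindings)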

cosyvoice/hifigan/discriminator.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from typing import List, Optional, Tuple
 from einops import rearrange
 from torchaudio.transforms import Spectrogram

cosyvoice/hifigan/f0_predictor.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm


 class ConvRNNF0Predictor(nn.Module):

cosyvoice/hifigan/generator.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 from torch.nn import Conv1d
 from torch.nn import ConvTranspose1d
 from torch.nn.utils import remove_weight_norm
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from torch.distributions.uniform import Uniform

 from cosyvoice.transformer.activation import Snake
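All three hifigan modules switch to the parametrization-based weight_norm, the replacement for the deprecated torch.nn.utils.weight_norm. A minimal sketch of the difference, assuming PyTorch >= 2.1:

    import torch.nn as nn
    from torch.nn.utils.parametrizations import weight_norm

    # same call shape as the deprecated helper, but the reparametrization is
    # stored under module.parametrizations.weight rather than as the
    # weight_g/weight_v attributes the old API created
    conv = weight_norm(nn.Conv1d(80, 256, kernel_size=3))
    print(hasattr(conv, 'parametrizations'))  # True

Modules wrapped this way are unwrapped with torch.nn.utils.parametrize.remove_parametrizations rather than the old remove_weight_norm that generator.py still imports.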

cosyvoice/trt/estimator_trt.py

Lines changed: 0 additions & 141 deletions
This file was deleted.

cosyvoice/utils/file_utils.py

Lines changed: 43 additions & 1 deletion

@@ -1,5 +1,5 @@
 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
-#               2024 Alibaba Inc (authors: Xiang Lyu)
+#               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -14,6 +14,7 @@
 # limitations under the License.

 import json
+import tensorrt as trt
 import torchaudio
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)

@@ -45,3 +46,44 @@ def load_wav(wav, target_sr):
     assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
     speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
+
+
+def convert_onnx_to_trt(trt_model, onnx_model, fp16):
+    _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
+    _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
+    _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
+    input_names = ["x", "mask", "mu", "t", "spks", "cond"]
+
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    if fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(input_names)):
+        profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
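A usage sketch of the new helper, with paths illustrative; inside the library, load_trt calls it only when no plan file exists yet:

    from cosyvoice.utils.file_utils import convert_onnx_to_trt

    # illustrative paths; the CLI derives them from model_dir and the fp16 flag
    onnx_model = 'pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp32.onnx'
    trt_model = 'pretrained_models/CosyVoice2-0.5B/flow.decoder.estimator.fp16.mygpu.plan'

    # builds an engine for the local GPU, with a dynamic time axis from 4 to
    # 6800 frames, and serializes it to trt_model; fp16=True enables half precision
    convert_onnx_to_trt(trt_model, onnx_model, fp16=True)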
