 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
 import threading
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
+from cosyvoice.utils.file_utils import convert_onnx_to_trt


 class CosyVoiceModel:
@@ -35,6 +37,9 @@ def __init__(self,
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
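This hunk moves the half-precision cast from load() into __init__ (the matching lines are deleted from load() in the next hunk). A minimal sketch of why casting before the checkpoint is loaded is safe, using a plain nn.Linear rather than the CosyVoice modules: Module.half() casts parameters and buffers to float16 in place, and a later load_state_dict() copies checkpoint tensors into the existing parameters with a dtype cast, so the module stays fp16.

import torch

m = torch.nn.Linear(4, 4)          # parameters start as float32
m.half()                           # cast parameters/buffers to float16 in place
assert m.weight.dtype == torch.float16
# loading an fp32 checkpoint afterwards casts on copy, keeping fp16
m.load_state_dict(torch.nn.Linear(4, 4).state_dict())
assert m.weight.dtype == torch.float16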
@@ -69,9 +74,6 @@ def load(self, llm_model, flow_model, hift_model):
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()

     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
@@ -81,7 +83,10 @@ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model):
+            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
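With the new signature, load_trt() builds the TensorRT engine from the ONNX export on first use (via convert_onnx_to_trt) and deserializes the cached plan on later runs. A minimal usage sketch, assuming llm/flow/hift modules already constructed from the model config; the file paths are hypothetical placeholders, not names shipped with the repo:

model = CosyVoiceModel(llm, flow, hift, fp16=True)
model.load('llm.pt', 'flow.pt', 'hift.pt')
# builds flow.decoder.estimator.plan from the .onnx file if it does not exist yet
model.load_trt('flow.decoder.estimator.plan', 'flow.decoder.estimator.onnx', fp16=True)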
@@ -204,6 +209,7 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
+        torch.cuda.empty_cache()

     def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
@@ -257,6 +263,7 @@ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat,
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+        torch.cuda.empty_cache()


 class CosyVoice2Model(CosyVoiceModel):
@@ -273,6 +280,9 @@ def __init__(self,
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument, or use a cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
@@ -385,3 +395,4 @@ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
+        torch.cuda.empty_cache()
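Each tts()/vc() call now ends with torch.cuda.empty_cache(). This does not free live tensors; it returns blocks held by PyTorch's caching allocator to the driver, which matters when the process shares GPU memory with a non-PyTorch consumer such as the TensorRT engine above. A minimal sketch of the effect, assuming a CUDA device is available:

import torch

x = torch.empty(1024, 1024, device='cuda')
del x                                   # memory goes back to PyTorch's cache
print(torch.cuda.memory_reserved())     # still reserved by the process
torch.cuda.empty_cache()                # release cached blocks to the driver
print(torch.cuda.memory_reserved())     # now (near) zero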