2525
2626class CosyVoice :
2727
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
    """Load a CosyVoice (v1) model from *model_dir*.

    Args:
        model_dir: local path to the model, or a model id resolved via
            ``snapshot_download`` when the path does not exist locally.
        load_jit: also load the TorchScript (.zip) exports of the text
            encoder, llm and flow encoder.
        load_trt: also load the TensorRT plan for the flow decoder estimator.
        fp16: run in half precision; selects the fp16 artifact variants.
            CUDA-only — forced off (with load_jit/load_trt) on CPU hosts.

    Raises:
        AssertionError: if the config resolves to a CosyVoice2 checkpoint.
    """
    # '-Instruct' in the model id marks an instruction-tuned checkpoint.
    self.instruct = True if '-Instruct' in model_dir else False
    self.model_dir = model_dir
    self.fp16 = fp16
    if not os.path.exists(model_dir):
        model_dir = snapshot_download(model_dir)
    with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
        configs = load_hyperpyyaml(f)
    # Guard against loading a CosyVoice2 checkpoint through the v1 wrapper.
    assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
    self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                      configs['feat_extractor'],
                                      '{}/campplus.onnx'.format(model_dir),
                                      '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                      '{}/spk2info.pt'.format(model_dir),
                                      configs['allowed_special'])
    self.sample_rate = configs['sample_rate']
    # JIT / TRT / fp16 all require CUDA; fall back cleanly on CPU-only hosts.
    if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
        load_jit, load_trt, fp16 = False, False, False
        logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
    self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
    self.model.load('{}/llm.pt'.format(model_dir),
                    '{}/flow.pt'.format(model_dir),
                    '{}/hift.pt'.format(model_dir))
    if load_jit:
        # Pick the fp16 or fp32 TorchScript artifacts to match self.fp16.
        self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                            '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                            '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
    if load_trt:
        self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
    # configs holds the full instantiated module graph; drop it eagerly.
    del configs
5858
5959 def list_available_spks (self ):
@@ -123,9 +123,10 @@ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed
123123
124124class CosyVoice2 (CosyVoice ):
125125
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
    """Load a CosyVoice2 model from *model_dir*.

    Args:
        model_dir: local path to the model, or a model id resolved via
            ``snapshot_download`` when the path does not exist locally.
        load_jit: also load the TorchScript (.zip) export of the flow encoder.
        load_trt: also load the TensorRT plan for the flow decoder estimator
            (replaces the removed ONNX path; ``load_onnx`` is gone).
        fp16: run in half precision; selects the fp16 artifact variants.
            CUDA-only — forced off (with load_jit/load_trt) on CPU hosts.
    """
    self.instruct = True if '-Instruct' in model_dir else False
    self.model_dir = model_dir
    self.fp16 = fp16
    if not os.path.exists(model_dir):
        model_dir = snapshot_download(model_dir)
    with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
        configs = load_hyperpyyaml(f)
    # NOTE(review): the diff hunk at this point hides the lines between the
    # yaml load and the tail of the frontend kwargs; the span below is
    # reconstructed from the upstream file — confirm against the repository.
    self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                      configs['feat_extractor'],
                                      '{}/campplus.onnx'.format(model_dir),
                                      '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                      '{}/spk2info.pt'.format(model_dir),
                                      configs['allowed_special'])
    self.sample_rate = configs['sample_rate']
    # JIT / TRT / fp16 all require CUDA; fall back cleanly on CPU-only hosts.
    if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
        load_jit, load_trt, fp16 = False, False, False
        logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
    # fp16 is now threaded through to the model, matching CosyVoice (v1).
    self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
    self.model.load('{}/llm.pt'.format(model_dir),
                    '{}/flow.pt'.format(model_dir),
                    '{}/hift.pt'.format(model_dir))
    if load_jit:
        self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
    if load_trt:
        self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
    # configs holds the full instantiated module graph; drop it eagerly.
    del configs
158154
159155 def inference_instruct (self , * args , ** kwargs ):
0 commit comments