diff --git a/cosyvoice/bin/inference.py b/cosyvoice/bin/inference.py
index 00b33729b..70b445bb6 100644
--- a/cosyvoice/bin/inference.py
+++ b/cosyvoice/bin/inference.py
@@ -16,7 +16,8 @@
 import argparse
 import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
+
+logging.getLogger("matplotlib").setLevel(logging.WARNING)
 import os
 import torch
 from torch.utils.data import DataLoader
@@ -53,13 +54,20 @@ def get_args():
 def main():
     args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+    logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
     # Init cosyvoice models from configs
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
+
+    if args.gpu >= 0 and torch.cuda.is_available():
+        # CUDA_VISIBLE_DEVICES above hides every other GPU, so the selected card is
+        # always visible as index 0; "cuda:{}".format(args.gpu) would point past it
+        # whenever args.gpu > 0, and args.gpu < 0 must still force CPU
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    elif torch.xpu.is_available():
+        device = torch.device("xpu")
+    else:
+        device = torch.device("cpu")
+
     try:
         with open(args.config, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
@@ -74,15 +82,14 @@ def main():
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
 
-    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
-                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs["data_pipeline"], mode="inference", shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
     sample_rate = configs['sample_rate']
     del configs
     os.makedirs(args.result_dir, exist_ok=True)
-    fn = os.path.join(args.result_dir, 'wav.scp')
-    f = open(fn, 'w')
+    fn = os.path.join(args.result_dir, "wav.scp")
+    f = open(fn, "w")
     with torch.no_grad():
         for _, batch in tqdm(enumerate(test_data_loader)):
             utts = batch["utts"]
@@ -98,19 +105,26 @@ def main():
             speech_feat_len = batch["speech_feat_len"].to(device)
             utt_embedding = batch["utt_embedding"].to(device)
             spk_embedding = batch["spk_embedding"].to(device)
-            if args.mode == 'sft':
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+            if args.mode == "sft":
+                model_input = {"text": tts_text_token, "text_len": tts_text_token_len, "llm_embedding": spk_embedding, "flow_embedding": spk_embedding}
             else:
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
-                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
-                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+                model_input = {
+                    "text": tts_text_token,
+                    "text_len": tts_text_token_len,
+                    "prompt_text": text_token,
+                    "prompt_text_len": text_token_len,
+                    "llm_prompt_speech_token": speech_token,
+                    "llm_prompt_speech_token_len": speech_token_len,
+                    "flow_prompt_speech_token": speech_token,
+                    "flow_prompt_speech_token_len": speech_token_len,
+                    "prompt_speech_feat": speech_feat,
"prompt_speech_feat_len": speech_feat_len, + "llm_embedding": utt_embedding, + "flow_embedding": utt_embedding, + } tts_speeches = [] for model_output in model.tts(**model_input): - tts_speeches.append(model_output['tts_speech']) + tts_speeches.append(model_output["tts_speech"]) tts_speeches = torch.concat(tts_speeches, dim=1) tts_key = '{}_{}'.format(utts[0], tts_index[0]) tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key)) @@ -118,8 +132,8 @@ def main(): f.write('{} {}\n'.format(tts_key, tts_fn)) f.flush() f.close() - logging.info('Result wav.scp saved in {}'.format(fn)) + logging.info("Result wav.scp saved in {}".format(fn)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index a7bfab4f6..30d7decae 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -23,7 +23,6 @@ from cosyvoice.utils.file_utils import logging from cosyvoice.utils.class_utils import get_model_type - class CosyVoice: def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): @@ -45,7 +44,8 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): '{}/spk2info.pt'.format(model_dir), configs['allowed_special']) self.sample_rate = configs['sample_rate'] - if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True): + + if self.gpu_available() and (load_jit is True or load_trt is True or fp16 is True): load_jit, load_trt, fp16 = False, False, False logging.warning('no cuda device, set load_jit/load_trt/fp16 to False') self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16) @@ -62,6 +62,12 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): self.fp16) del configs + def gpu_available(self) -> bool: + """check torch GPU device""" + if torch.cuda.is_available() or torch.backends.mps.is_available() or torch.xpu.is_available(): + return True + return False + def list_available_spks(self): spks = list(self.frontend.spk2info.keys()) return spks diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py index 8770e3120..4ae797ee5 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -35,6 +35,16 @@ from cosyvoice.utils.file_utils import logging from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation +def set_device() -> torch.device: + """Assign GPU device if possible""" + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + elif torch.xpu.is_available(): + return torch.device("xpu") + else: + return torch.device("cpu") class CosyVoiceFrontEnd: @@ -45,9 +55,12 @@ def __init__(self, speech_tokenizer_model: str, spk2info: str = '', allowed_special: str = 'all'): + + self.device = set_device() + self.tokenizer = get_tokenizer() self.feat_extractor = feat_extractor - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 20ddad035..9e9f8b2e1 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -24,14 +24,38 @@ from cosyvoice.utils.file_utils import convert_onnx_to_trt +def set_device() -> torch.device: + """Assign GPU device 
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.xpu.is_available():
+        return torch.device("xpu")
+    else:
+        return torch.device("cpu")
+
+
+def clear_cache() -> None:
+    """Collect garbage, then empty the active accelerator's memory cache."""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+    elif torch.xpu.is_available():
+        torch.xpu.empty_cache()
 
 
 class CosyVoiceModel:
 
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool = False):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        self.device = set_device()
+
         self.llm = llm
         self.flow = flow
         self.hift = hift
@@ -81,7 +105,7 @@ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         self.flow.encoder = flow_encoder
 
     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
-        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        assert torch.cuda.is_available(), 'tensorrt requires a CUDA gpu!'
         if not os.path.exists(flow_decoder_estimator_model):
             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
         if os.path.getsize(flow_decoder_estimator_model) == 0:
@@ -231,7 +255,10 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
                 self.mel_overlap_dict.pop(this_uuid)
                 self.hift_cache_dict.pop(this_uuid)
                 self.flow_cache_dict.pop(this_uuid)
-            torch.cuda.empty_cache()
+            clear_cache()
 
 
 class CosyVoice2Model(CosyVoiceModel):
@@ -242,7 +269,16 @@ def __init__(self,
                  hift: torch.nn.Module,
                  fp16: bool = False,
                  use_flow_cache: bool = False):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        self.device = set_device()
+
         self.llm = llm
         self.flow = flow
         self.hift = hift
@@ -405,4 +441,5 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+
+        clear_cache()
diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py
index 3e61a8c06..c575a25c6 100644
--- a/cosyvoice/utils/common.py
+++ b/cosyvoice/utils/common.py
@@ -152,7 +152,13 @@ def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    elif torch.backends.mps.is_available():
+        torch.mps.manual_seed(seed)
+    elif torch.xpu.is_available():
+        torch.xpu.manual_seed(seed)
 
 
 def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
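For reviewers, here is a minimal, self-contained sketch of the device-selection and seeding behavior the hunks above introduce. It is illustrative only: `pick_device` and `seed_everything` are hypothetical stand-ins for the patch's `set_device` and `set_all_random_seed`, and the `hasattr` guard is an extra assumption for PyTorch builds older than 2.5, which lack `torch.xpu`.

# Illustrative sketch (not part of the patch): CUDA -> MPS -> XPU -> CPU fallback.
import torch

def pick_device() -> torch.device:
    # Mirrors set_device() in cosyvoice/cli/model.py and cosyvoice/cli/frontend.py.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # torch.xpu needs PyTorch 2.5+
        return torch.device("xpu")
    return torch.device("cpu")

def seed_everything(seed: int) -> None:
    # Mirrors set_all_random_seed(): seed the CPU RNG, then the active accelerator.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    elif torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.manual_seed(seed)

if __name__ == "__main__":
    device = pick_device()
    seed_everything(0)
    print(device, torch.randn(2, 3, device=device).device)

Note that the order matters on machines exposing more than one backend: CUDA wins over MPS and XPU, matching the patch.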
diff --git a/cosyvoice/utils/executor.py b/cosyvoice/utils/executor.py
index 8c38bf016..efef6fea2 100644
--- a/cosyvoice/utils/executor.py
+++ b/cosyvoice/utils/executor.py
@@ -30,7 +30,13 @@ def __init__(self, gan: bool = False):
         self.step = 0
         self.epoch = 0
         self.rank = int(os.environ.get('RANK', 0))
-        self.device = torch.device('cuda:{}'.format(self.rank))
+        self.device = torch.device("cpu")
+        if torch.cuda.is_available():
+            self.device = torch.device('cuda:{}'.format(self.rank))
+        elif torch.backends.mps.is_available():
+            self.device = torch.device('mps')
+        elif torch.xpu.is_available():
+            self.device = torch.device('xpu')
 
     def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
         ''' Train one epoch
diff --git a/requirements_xpu.txt b/requirements_xpu.txt
new file mode 100644
index 000000000..8812ccaa9
--- /dev/null
+++ b/requirements_xpu.txt
@@ -0,0 +1,173 @@
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/bmg/us/
+aiofiles==23.2.1
+aiohappyeyeballs==2.4.4
+aiohttp==3.11.12
+aiosignal==1.3.2
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.8.0
+attrs==25.1.0
+audioread==3.0.1
+beautifulsoup4==4.13.3
+certifi==2025.1.31
+cffi==1.17.1
+charset-normalizer==3.4.1
+click==8.1.8
+colorama==0.4.6
+coloredlogs==15.0.1
+conformer==0.3.2
+contourpy==1.3.1
+cycler==0.12.1
+decorator==5.1.1
+deepspeed==0.16.3
+diffusers==0.32.2
+dpcpp-cpp-rt==2025.0.4
+einops==0.8.0
+fastapi==0.115.8
+fastapi-cli==0.0.7
+ffmpy==0.5.0
+filelock==3.17.0
+flatbuffers==25.1.24
+fonttools==4.55.8
+frozenlist==1.5.0
+fsspec==2025.2.0
+gdown==5.2.0
+gradio==5.15.0
+gradio_client==1.7.0
+grpcio==1.70.0
+grpcio-tools==1.70.0
+h11==0.14.0
+h5py==3.12.1
+hjson==3.1.0
+httpcore==1.0.7
+httptools==0.6.4
+httpx==0.28.1
+huggingface-hub==0.28.1
+humanfriendly==10.0
+hydra-core==1.3.2
+HyperPyYAML==1.2.2
+idna==3.10
+importlib_metadata==8.6.1
+inflect==7.5.0
+intel-cmplr-lib-rt==2025.0.4
+intel-cmplr-lib-ur==2025.0.4
+intel-cmplr-lic-rt==2025.0.4
+intel-opencl-rt==2025.0.4
+intel-openmp==2025.0.4
+intel-sycl-rt==2025.0.4
+intel_extension_for_pytorch==2.5.10+xpu
+Jinja2==3.1.5
+joblib==1.4.2
+kaldifst==1.7.13
+kiwisolver==1.4.8
+lazy_loader==0.4
+librosa==0.10.2.post1
+lightning==2.5.0.post0
+lightning-utilities==0.12.0
+llvmlite==0.44.0
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.10.0
+mdurl==0.1.2
+mkl==2025.0.1
+mkl-dpcpp==2025.0.1
+modelscope==1.22.3
+more-itertools==10.6.0
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.1.0
+networkx==3.4.2
+ninja==1.11.1.3
+numba==0.61.0
+numexpr==2.10.2
+numpy==1.26.4
+nvidia-ml-py==12.570.86
+omegaconf==2.3.0
+onemkl-sycl-blas==2025.0.1
+onemkl-sycl-datafitting==2025.0.1
+onemkl-sycl-dft==2025.0.1
+onemkl-sycl-lapack==2025.0.1
+onemkl-sycl-rng==2025.0.1
+onemkl-sycl-sparse==2025.0.1
+onemkl-sycl-stats==2025.0.1
+onemkl-sycl-vm==2025.0.1
+onnx==1.17.0
+onnxruntime-directml==1.20.1
+openai-whisper==20240930
+opencv-contrib-python==4.11.0.86
+opencv-python==4.11.0.86
+orjson==3.10.15
+packaging==24.2
+pandas==2.2.3
+pillow==11.1.0
+platformdirs==4.3.6
+pooch==1.8.2
+propcache==0.2.1
+protobuf==5.29.3
+psutil==6.1.1
+py-cpuinfo==9.0.0
+py3nvml==0.2.7
+pyarrow==19.0.0
+pycparser==2.22
+pydantic==2.10.6
+pydantic_core==2.27.2
+pydub==0.25.1
+Pygments==2.19.1
+pyparsing==3.2.1
+pyreadline3==3.5.4
+PySocks==1.7.1
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-multipart==0.0.20
+pytorch-lightning==2.5.0.post0
+pytz==2025.1
+pyworld==0.3.5
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+rich-toolkit==0.13.2
+ruamel.yaml==0.18.10
+ruamel.yaml.clib==0.2.12
+ruff==0.9.5
+safehttpx==0.1.6
+safetensors==0.5.2
+scikit-learn==1.6.1
+scipy==1.15.1
+semantic-version==2.10.0
+setuptools==75.8.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+soundfile==0.13.1
+soupsieve==2.6
+soxr==0.5.0.post1
+starlette==0.45.3
+sympy==1.13.1
+tbb==2022.0.0
+tcmlib==1.2.0
+threadpoolctl==3.5.0
+tiktoken==0.8.0
+tokenizers==0.21.0
+tomlkit==0.13.2
+torch==2.5.1+cxx11.abi
+torchaudio==2.5.1+cxx11.abi
+torchmetrics==1.6.1
+torchvision==0.20.1+cxx11.abi
+tqdm==4.67.1
+transformers==4.48.2
+typeguard==4.4.1
+typer==0.15.1
+typing_extensions==4.12.2
+tzdata==2025.1
+umf==0.9.1
+urllib3==2.3.0
+uvicorn==0.34.0
+watchfiles==1.0.4
+websockets==14.2
+wetext==0.0.3
+wget==3.2
+wheel==0.45.1
+xmltodict==0.14.2
+yarl==1.18.3
+zipp==3.21.0
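Finally, a quick post-install sanity check for the XPU stack pinned above. This is an assumption-laden sketch, not something the patch ships: it assumes the `+xpu` IPEX wheel registers the XPU backend when `intel_extension_for_pytorch` is imported, while newer PyTorch releases (2.5+, as pinned here) also expose `torch.xpu` natively.

# Hypothetical sanity check after `pip install -r requirements_xpu.txt`.
import torch

try:
    # Assumption: the IPEX +xpu wheel registers the XPU backend on import.
    import intel_extension_for_pytorch  # noqa: F401
except ImportError:
    pass  # native torch.xpu support (PyTorch 2.5+) may still be available

print("torch:", torch.__version__)
if hasattr(torch, "xpu") and torch.xpu.is_available():
    print("xpu device:", torch.xpu.get_device_name(0))
else:
    print("no XPU detected; CosyVoice will fall back to CUDA, MPS, or CPU")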