11from typing import Tuple
22
3+ import torch
34from huggingface_hub import hf_hub_download
45from llama_cpp import Llama
56from outetts import GGUFModelConfig_v1 , InterfaceGGUF
@@ -30,13 +31,12 @@ def load_llama_cpp_model(
3031 # 0 means that the model limit will be used, instead of the default (512) or other hardcoded value
3132 n_ctx = 0 ,
3233 verbose = False ,
34+ n_gpu_layers = - 1 if torch .cuda .is_available () else 0 ,
3335 )
3436 return model
3537
3638
def load_outetts_model(model_id: str, language: str = "en") -> InterfaceGGUF:
    """
    Load a GGUF TTS model through the OuteTTS interface.

    For more info: https://github.com/edwko/OuteTTS

    Args:
        model_id (str): The model id to load.
            Format is expected to be `{org}/{repo}/{filename}`.
        language (str): Supported languages in 0.2-500M: en, zh, ja, ko.

    Returns:
        InterfaceGGUF: The loaded model interface.
    """
    # Model version (e.g. "0.2") is encoded in the repo name after the first dash.
    model_version = model_id.split("-")[1]

    org, repo, filename = model_id.split("/")
    local_path = hf_hub_download(repo_id=f"{org}/{repo}", filename=filename)
    model_config = GGUFModelConfig_v1(
        model_path=local_path,
        language=language,
        # FIX: is_available must be *called*; the bare function object is always
        # truthy, which silently forced GPU offload (-1) even on CPU-only hosts.
        n_gpu_layers=-1 if torch.cuda.is_available() else 0,
        additional_model_config={"verbose": False},
    )

    return InterfaceGGUF(model_version=model_version, cfg=model_config)
6566
6667
def load_parler_tts_model_and_tokenizer(
    model_id: str,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """
    Load a Parler-TTS checkpoint together with its matching tokenizer.

    For more info: https://github.com/huggingface/parler-tts

    Examples:
        >>> model, tokenizer = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1")

    Args:
        model_id (str): The model id to load, e.g. `parler-tts/parler-tts-mini-v1`.

    Returns:
        PreTrainedModel: The loaded model (paired with its tokenizer).
    """
    # Imported lazily so parler_tts is only required when this loader is used.
    from parler_tts import ParlerTTSForConditionalGeneration

    tts_model = ParlerTTSForConditionalGeneration.from_pretrained(model_id)
    tts_tokenizer = AutoTokenizer.from_pretrained(model_id)

    return tts_model, tts_tokenizer
0 commit comments