Skip to content

Commit 87569b4

Browse files
authored
enh(model_loaders): Add GPU support. (#75)
Set n_gpu_layers based on torch.cuda.is_available
1 parent 09faf73 commit 87569b4

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/document_to_podcast/inference/model_loaders.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Tuple
22

3+
import torch
34
from huggingface_hub import hf_hub_download
45
from llama_cpp import Llama
56
from outetts import GGUFModelConfig_v1, InterfaceGGUF
@@ -30,13 +31,12 @@ def load_llama_cpp_model(
3031
# 0 means that the model limit will be used, instead of the default (512) or other hardcoded value
3132
n_ctx=0,
3233
verbose=False,
34+
n_gpu_layers=-1 if torch.cuda.is_available() else 0,
3335
)
3436
return model
3537

3638

37-
def load_outetts_model(
38-
model_id: str, language: str = "en", device: str = "cpu"
39-
) -> InterfaceGGUF:
39+
def load_outetts_model(model_id: str, language: str = "en") -> InterfaceGGUF:
4040
"""
4141
Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS
4242
@@ -47,43 +47,43 @@ def load_outetts_model(
4747
model_id (str): The model id to load.
4848
Format is expected to be `{org}/{repo}/{filename}`.
4949
language (str): Supported languages in 0.2-500M: en, zh, ja, ko.
50-
device (str): The device to load the model on, such as "cuda:0" or "cpu".
5150
5251
Returns:
5352
PreTrainedModel: The loaded model.
5453
"""
55-
n_layers_on_gpu = 0 if device == "cpu" else -1
5654
model_version = model_id.split("-")[1]
5755

5856
org, repo, filename = model_id.split("/")
5957
local_path = hf_hub_download(repo_id=f"{org}/{repo}", filename=filename)
6058
model_config = GGUFModelConfig_v1(
61-
model_path=local_path, language=language, n_gpu_layers=n_layers_on_gpu
59+
model_path=local_path,
60+
language=language,
61+
n_gpu_layers=-1 if torch.cuda.is_available() else 0,
62+
additional_model_config={"verbose": False},
6263
)
6364

6465
return InterfaceGGUF(model_version=model_version, cfg=model_config)
6566

6667

6768
def load_parler_tts_model_and_tokenizer(
68-
model_id: str, device: str = "cpu"
69+
model_id: str,
6970
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
7071
"""
7172
Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts
7273
7374
Examples:
74-
>>> model, tokenizer = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu")
75+
>>> model, tokenizer = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1")
7576
7677
Args:
7778
model_id (str): The model id to load.
7879
Format is expected to be `{repo}/{filename}`.
79-
device (str): The device to load the model on, such as "cuda:0" or "cpu".
8080
8181
Returns:
8282
PreTrainedModel: The loaded model.
8383
"""
8484
from parler_tts import ParlerTTSForConditionalGeneration
8585

86-
model = ParlerTTSForConditionalGeneration.from_pretrained(model_id).to(device)
86+
model = ParlerTTSForConditionalGeneration.from_pretrained(model_id)
8787
tokenizer = AutoTokenizer.from_pretrained(model_id)
8888

8989
return model, tokenizer

0 commit comments

Comments
 (0)