Skip to content

Commit 4fc7f7d

Browse files
yxf0314gitlawr
authored andcommitted
feat: Control which GPU the model selects by the CUDA_VISIBLE_DEVICES environment variable
1 parent 04f4d9e commit 4fc7f7d

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

vox_box/backends/stt/faster_whisper.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
preprocessor_config_path = os.path.join(
2929
self._cfg.model, "preprocessor_config.json"
3030
)
31+
3132
if os.path.exists(preprocessor_config_path):
3233
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
3334
self._preprocessor_config_json = json.load(f)
@@ -40,22 +41,17 @@ def load(self):
4041
if self._cfg.device == "cpu":
4142
cpu_threads = 8
4243

44+
device = self._cfg.device
45+
if device.startswith("cuda:"):
46+
device = device.split(":")[0]
47+
4348
compute_type = "default"
4449
if platform.system() == "Darwin":
4550
compute_type = "int8"
4651

47-
device = self._cfg.device
48-
device_index = 0
49-
if self._cfg.device != "cpu":
50-
arr = device.split(":")
51-
device = arr[0]
52-
if len(arr) > 1:
53-
device_index = int(arr[1])
54-
5552
self._model = WhisperModel(
5653
self._cfg.model,
5754
device=device,
58-
device_index=device_index,
5955
cpu_threads=cpu_threads,
6056
compute_type=compute_type,
6157
)

vox_box/cmd/start.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from vox_box.config import Config
88
from vox_box.server.model import ModelInstance
99
from vox_box.server.server import Server
10+
from vox_box.utils.model import parse_and_set_cuda_visible_devices
1011

1112

1213
logger = logging.getLogger(__name__)
@@ -91,6 +92,7 @@ def run(args: argparse.Namespace):
9192
try:
9293
cfg = parse_args(args)
9394
setup_logging(cfg.debug)
95+
parse_and_set_cuda_visible_devices(cfg)
9496

9597
logger.info("Starting with arguments: %s", args._get_kwargs())
9698

vox_box/utils/model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
import logging
2+
import os
13
import time
24
from typing import Dict
35

6+
from vox_box.config import Config
7+
8+
logger = logging.getLogger(__name__)
9+
410

511
def create_model_dict(id: str, **kwargs) -> Dict:
612
d = {
@@ -16,3 +22,16 @@ def create_model_dict(id: str, **kwargs) -> Dict:
1622
d[k] = v
1723

1824
return d
25+
26+
27+
def parse_and_set_cuda_visible_devices(cfg: Config):
28+
"""
29+
Parse CUDA device in format cuda:1 and set CUDA_VISIBLE_DEVICES accordingly.
30+
"""
31+
if cfg.device.startswith("cuda:"):
32+
device_index = cfg.device.split(":")[1]
33+
if device_index.isdigit():
34+
os.environ["CUDA_VISIBLE_DEVICES"] = device_index
35+
logger.info(f"Set CUDA_VISIBLE_DEVICES = {device_index}")
36+
else:
37+
raise ValueError(f"Invalid CUDA device index: {device_index}")

0 commit comments

Comments
 (0)