File tree Expand file tree Collapse file tree 3 files changed +26
-9
lines changed
Expand file tree Collapse file tree 3 files changed +26
-9
lines changed Original file line number Diff line number Diff line change @@ -28,6 +28,7 @@ def __init__(
2828 preprocessor_config_path = os .path .join (
2929 self ._cfg .model , "preprocessor_config.json"
3030 )
31+
3132 if os .path .exists (preprocessor_config_path ):
3233 with open (preprocessor_config_path , "r" , encoding = "utf-8" ) as f :
3334 self ._preprocessor_config_json = json .load (f )
@@ -40,22 +41,17 @@ def load(self):
4041 if self ._cfg .device == "cpu" :
4142 cpu_threads = 8
4243
44+ device = self ._cfg .device
45+ if device .startswith ("cuda:" ):
46+ device = device .split (":" )[0 ]
47+
4348 compute_type = "default"
4449 if platform .system () == "Darwin" :
4550 compute_type = "int8"
4651
47- device = self ._cfg .device
48- device_index = 0
49- if self ._cfg .device != "cpu" :
50- arr = device .split (":" )
51- device = arr [0 ]
52- if len (arr ) > 1 :
53- device_index = int (arr [1 ])
54-
5552 self ._model = WhisperModel (
5653 self ._cfg .model ,
5754 device = device ,
58- device_index = device_index ,
5955 cpu_threads = cpu_threads ,
6056 compute_type = compute_type ,
6157 )
Original file line number Diff line number Diff line change 77from vox_box .config import Config
88from vox_box .server .model import ModelInstance
99from vox_box .server .server import Server
10+ from vox_box .utils .model import parse_and_set_cuda_visible_devices
1011
1112
1213logger = logging .getLogger (__name__ )
@@ -91,6 +92,7 @@ def run(args: argparse.Namespace):
9192 try :
9293 cfg = parse_args (args )
9394 setup_logging (cfg .debug )
95+ parse_and_set_cuda_visible_devices (cfg )
9496
9597 logger .info ("Starting with arguments: %s" , args ._get_kwargs ())
9698
Original file line number Diff line number Diff line change 1+ import logging
2+ import os
13import time
24from typing import Dict
35
6+ from vox_box .config import Config
7+
8+ logger = logging .getLogger (__name__ )
9+
410
511def create_model_dict (id : str , ** kwargs ) -> Dict :
612 d = {
@@ -16,3 +22,16 @@ def create_model_dict(id: str, **kwargs) -> Dict:
1622 d [k ] = v
1723
1824 return d
25+
26+
27+ def parse_and_set_cuda_visible_devices (cfg : Config ):
28+ """
29+ Parse CUDA device in format cuda:1 and set CUDA_VISIBLE_DEVICES accordingly.
30+ """
31+ if cfg .device .startswith ("cuda:" ):
32+ device_index = cfg .device .split (":" )[1 ]
33+ if device_index .isdigit ():
34+ os .environ ["CUDA_VISIBLE_DEVICES" ] = device_index
35+ logger .info (f"Set CUDA_VISIBLE_DEVICES = { device_index } " )
36+ else :
37+ raise ValueError (f"Invalid CUDA device index: { device_index } " )
You can’t perform that action at this time.
0 commit comments