Skip to content

Commit c4c9f59

Browse files
committed
Improved cuda detection for whisper
1 parent 43f1aee commit c4c9f59

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

whisper_mp_worker.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,24 @@ def plog(level, msg):
4343
except Exception:
4444
# Safe fallback: leave i18n defaults; keys may pass through
4545
pass
46+
47+
# determine device
48+
device = args.get("device", "")
49+
if device != 'cpu':
50+
if platform.system() == "Darwin": # MAC
51+
device = 'auto'
52+
elif platform.system() in ('Windows', 'Linux'):
53+
try:
54+
device = 'cuda' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu'
55+
except:
56+
device = 'cpu'
57+
else:
58+
raise Exception('Platform not supported yet.')
4659

4760
# Build model in child using provided options
4861
model = WhisperModel(
4962
args["model_name_or_path"],
50-
device=args.get("device", "auto"),
63+
device=device,
5164
compute_type=args.get("compute_type", "float16"),
5265
cpu_threads=args.get("cpu_threads", 4),
5366
local_files_only=args.get("local_files_only", True),

0 commit comments

Comments
 (0)