Skip to content

Commit 95d5db8

Browse files
authored
Add files via upload
1 parent 31b3a85 commit 95d5db8

File tree

4 files changed

+66
-11
lines changed

4 files changed

+66
-11
lines changed

settings.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@ def initUI(self):
1313

1414
hbox1_layout = QHBoxLayout()
1515

16-
# Replace size and quantization combo boxes with a single model combo box
1716
modelLabel = QLabel("Model:")
1817
hbox1_layout.addWidget(modelLabel)
1918

2019
self.modelComboBox = QComboBox()
2120
self.modelComboBox.addItems(WHISPER_MODELS.keys())
2221
hbox1_layout.addWidget(self.modelComboBox)
2322

24-
# Keep the rest of the widgets
2523
computeDeviceLabel = QLabel("Device:")
2624
hbox1_layout.addWidget(computeDeviceLabel)
2725

@@ -50,13 +48,13 @@ def initUI(self):
5048

5149
self.beamSizeSlider = QSlider(Qt.Horizontal)
5250
self.beamSizeSlider.setMinimum(1)
53-
self.beamSizeSlider.setMaximum(10)
54-
self.beamSizeSlider.setValue(5)
51+
self.beamSizeSlider.setMaximum(5)
52+
self.beamSizeSlider.setValue(1)
5553
self.beamSizeSlider.setTickPosition(QSlider.TicksBelow)
5654
self.beamSizeSlider.setTickInterval(1)
5755
beam_size_layout.addWidget(self.beamSizeSlider)
5856

59-
self.beamSizeValueLabel = QLabel("5")
57+
self.beamSizeValueLabel = QLabel("1")
6058
beam_size_layout.addWidget(self.beamSizeValueLabel)
6159
self.beamSizeSlider.valueChanged.connect(lambda: self.update_slider_label(self.beamSizeSlider, self.beamSizeValueLabel))
6260

utilities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import ctranslate2
3+
import psutil
34

45
def get_compute_and_platform_info():
56
available_devices = ["cpu"]
@@ -9,10 +10,14 @@ def get_compute_and_platform_info():
910

1011
return available_devices
1112

13+
1214
def get_supported_quantizations(device_type):
1315
types = ctranslate2.get_supported_compute_types(device_type)
1416
filtered_types = [q for q in types if q != 'int16']
1517
desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 'int8']
1618
sorted_types = [q for q in desired_order if q in filtered_types]
1719
return sorted_types
1820

21+
22+
def get_physical_core_count():
23+
return psutil.cpu_count(logical=False)

whispers2t_batch_gui.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QCheckBox, QLabel, QGroupBox, QMessageBox
55
from PySide6.QtCore import Qt
66
import torch
7-
from utilities import get_compute_and_platform_info
7+
from utilities import get_compute_and_platform_info, get_supported_quantizations
88
from whispers2t_batch_transcriber import Worker
99
from metrics_bar import MetricsBar
1010
from settings import SettingsGroupBox
@@ -42,7 +42,7 @@ def initUI(self):
4242

4343
main_layout = QVBoxLayout()
4444

45-
transcriberGroupBox = QGroupBox("Batch Transcriber (ctranslate2 edition - huggingface edition coming soon)")
45+
transcriberGroupBox = QGroupBox("Batch Transcriber (ctranslate2 edition)")
4646
transcriberLayout = QVBoxLayout()
4747

4848
self.dirLabel = QLabel("No directory selected")
@@ -121,10 +121,57 @@ def calculate_files_to_process(self):
121121
total_files += len(list(directory_path.glob(pattern)))
122122
return total_files
123123

124+
def perform_checks(self):
125+
model = self.settingsGroupBox.modelComboBox.currentText()
126+
device = self.settingsGroupBox.computeDeviceComboBox.currentText()
127+
batch_size = self.settingsGroupBox.batchSizeSlider.value()
128+
beam_size = self.settingsGroupBox.beamSizeSlider.value()
129+
130+
# Check 1: CPU and non-float32 model
131+
if "float32" not in model.lower() and device.lower() == "cpu":
132+
QMessageBox.warning(self, "Invalid Configuration",
133+
"CPU only supports Float 32 computation. Please select a different Whisper model.")
134+
return False
135+
136+
# Check 2: CPU with high batch size
137+
if device.lower() == "cpu" and batch_size > 16:
138+
reply = QMessageBox.warning(self, "Performance Warning",
139+
"When using CPU it is generally recommended to use a batch size of no more than 16 "
140+
"otherwise compute time will actually be worse.\n\n"
141+
"Moreover, if you select a Beam Size greater than one, you should reduce the Batch Size accordingly.\n\n"
142+
"For example:\n"
143+
"- If you select a Beam Size of 2 (double the default value of 1) you would reduce the Batch Size (default value 16) by half.\n"
144+
"- If Beam Size is set to 3 you should reduce the Batch Size to 1/3 of the default level, and so on.\n\nClick OK to proceed.",
145+
QMessageBox.Ok | QMessageBox.Cancel,
146+
QMessageBox.Cancel)
147+
if reply == QMessageBox.Cancel:
148+
return False
149+
150+
# Check 3: GPU compatibility
151+
# Only perform this check if the device is not CPU
152+
if device.lower() != "cpu":
153+
supported_quantizations = get_supported_quantizations(device)
154+
if "float16" in model.lower() and "float16" not in supported_quantizations:
155+
QMessageBox.warning(self, "Incompatible Configuration",
156+
"Your GPU does not support the selected floating point value (float16). "
157+
"Please make another selection.")
158+
return False
159+
if "bfloat16" in model.lower() and "bfloat16" not in supported_quantizations:
160+
QMessageBox.warning(self, "Incompatible Configuration",
161+
"Your GPU does not support the selected floating point value (bfloat16). "
162+
"Please make another selection.")
163+
return False
164+
165+
return True # All checks passed
166+
124167
def processFiles(self):
125168
if hasattr(self, 'directory'):
126169
total_files = self.calculate_files_to_process()
127170

171+
# Perform checks
172+
if not self.perform_checks():
173+
return
174+
128175
reply = QMessageBox.question(
129176
self,
130177
'Confirm Process',

whispers2t_batch_transcriber.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import os
22
import sys
33
import gc
4-
from PySide6.QtCore import QThread, Signal, QElapsedTimer
54
from pathlib import Path
6-
import whisper_s2t
5+
from threading import Event
76
from queue import Queue
7+
8+
from PySide6.QtCore import QThread, Signal, QElapsedTimer
9+
import whisper_s2t
810
import torch
9-
from threading import Event
11+
1012
from constants import WHISPER_MODELS
13+
from utilities import get_physical_core_count
14+
15+
CPU_THREADS = max(4, get_physical_core_count() - 1)
1116

1217
class Worker(QThread):
1318
finished = Signal(str)
@@ -66,7 +71,7 @@ def run(self):
6671
device=self.device,
6772
compute_type=self.model_info['precision'],
6873
asr_options={'beam_size': self.beam_size},
69-
cpu_threads=max(4, os.cpu_count() - 6) if self.device == "cpu" else 4,
74+
cpu_threads=CPU_THREADS if self.device == "cpu" else 4,
7075
**model_kwargs
7176
)
7277

0 commit comments

Comments
 (0)