Skip to content

Commit d1905bd

Browse files
authored
change combobox logic
1 parent 95d5db8 commit d1905bd

File tree

4 files changed

+85
-59
lines changed

4 files changed

+85
-59
lines changed

settings.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from PySide6.QtWidgets import (QGroupBox, QVBoxLayout, QHBoxLayout, QLabel, QComboBox, QSlider)
2-
from PySide6.QtCore import Qt
2+
from PySide6.QtCore import Qt, Signal
33
from constants import WHISPER_MODELS
4+
import torch
5+
6+
from utilities import has_bfloat16_support
47

58
class SettingsGroupBox(QGroupBox):
9+
device_changed = Signal(str)
10+
611
def __init__(self, get_compute_and_platform_info_callback, parent=None):
712
super().__init__("Settings", parent)
813
self.get_compute_and_platform_info = get_compute_and_platform_info_callback
@@ -17,14 +22,14 @@ def initUI(self):
1722
hbox1_layout.addWidget(modelLabel)
1823

1924
self.modelComboBox = QComboBox()
20-
self.modelComboBox.addItems(WHISPER_MODELS.keys())
2125
hbox1_layout.addWidget(self.modelComboBox)
2226

2327
computeDeviceLabel = QLabel("Device:")
2428
hbox1_layout.addWidget(computeDeviceLabel)
2529

2630
self.computeDeviceComboBox = QComboBox()
2731
hbox1_layout.addWidget(self.computeDeviceComboBox)
32+
self.computeDeviceComboBox.currentTextChanged.connect(self.on_device_changed)
2833

2934
formatLabel = QLabel("Output:")
3035
hbox1_layout.addWidget(formatLabel)
@@ -67,7 +72,7 @@ def initUI(self):
6772
self.batchSizeSlider = QSlider(Qt.Horizontal)
6873
self.batchSizeSlider.setMinimum(1)
6974
self.batchSizeSlider.setMaximum(200)
70-
self.batchSizeSlider.setValue(16)
75+
self.batchSizeSlider.setValue(8)
7176
self.batchSizeSlider.setTickPosition(QSlider.TicksBelow)
7277
self.batchSizeSlider.setTickInterval(10)
7378
batch_size_layout.addWidget(self.batchSizeSlider)
@@ -86,4 +91,25 @@ def update_slider_label(self, slider, label):
8691
def populateComputeDeviceComboBox(self):
    """Fill the device combo box from the platform callback.

    Prefers "cuda" when the callback reports it, otherwise falls back
    to "cpu", then refreshes the model list for the chosen device.
    """
    available_devices = self.get_compute_and_platform_info()
    self.computeDeviceComboBox.addItems(available_devices)
    # Default to the GPU when one is available; otherwise the CPU entry.
    preferred = "cuda" if "cuda" in available_devices else "cpu"
    self.computeDeviceComboBox.setCurrentIndex(
        self.computeDeviceComboBox.findText(preferred)
    )
    self.update_model_combobox()
99+
100+
def on_device_changed(self, device):
    """Relay a device selection change, then refresh the model list.

    The window-level ``device_changed`` signal is emitted first so
    outside listeners observe the new device before the model combobox
    is repopulated for it.
    """
    self.device_changed.emit(device)
    self.update_model_combobox()
103+
104+
def update_model_combobox(self):
    """Repopulate the model combo box with models usable on the device.

    CPU only runs float32 models; CUDA runs float32 and float16, plus
    bfloat16 when the GPU supports it.
    """
    device = self.computeDeviceComboBox.currentText()
    self.modelComboBox.clear()

    for model_name, model_info in WHISPER_MODELS.items():
        precision = model_info['precision']
        if device == "cpu":
            if precision == 'float32':
                self.modelComboBox.addItem(model_name)
        elif device == "cuda":
            # has_bfloat16_support() is only consulted for bfloat16
            # models, mirroring the short-circuit of the original.
            if precision in ('float32', 'float16') or (
                    precision == 'bfloat16' and has_bfloat16_support()):
                self.modelComboBox.addItem(model_name)

utilities.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,19 @@ def get_compute_and_platform_info():
1111
return available_devices
1212

1313

14-
def get_supported_quantizations(device_type):
15-
types = ctranslate2.get_supported_compute_types(device_type)
16-
filtered_types = [q for q in types if q != 'int16']
17-
desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 'int8']
18-
sorted_types = [q for q in desired_order if q in filtered_types]
19-
return sorted_types
14+
# def get_supported_quantizations(device_type):
15+
# types = ctranslate2.get_supported_compute_types(device_type)
16+
# filtered_types = [q for q in types if q != 'int16']
17+
# desired_order = ['float32', 'float16', 'bfloat16', 'int8_float32', 'int8_float16', 'int8_bfloat16', 'int8']
18+
# sorted_types = [q for q in desired_order if q in filtered_types]
19+
# return sorted_types
2020

21+
def get_logical_core_count():
    """Return the number of logical CPU cores (SMT threads included).

    ``psutil.cpu_count()`` is documented to return ``None`` when the
    count cannot be determined; fall back to 1 in that case so callers
    doing arithmetic on the result (e.g. ``get_logical_core_count() - 8``
    in the transcriber) never see ``None``.
    """
    return psutil.cpu_count(logical=True) or 1
2123

22-
def get_physical_core_count():
23-
return psutil.cpu_count(logical=False)
24+
def has_bfloat16_support():
    """Return True when the active CUDA device can run bfloat16 kernels.

    bfloat16 is available on NVIDIA GPUs of compute capability 8.0
    (Ampere) and newer. The previous ``>= (8, 6)`` cutoff wrongly
    excluded sm_80 cards such as the A100, which do support bfloat16.
    """
    if not torch.cuda.is_available():
        return False

    # (major, minor) of the current CUDA device; tuple compare gives
    # "8.0 or newer".
    capability = torch.cuda.get_device_capability()
    return capability >= (8, 0)

whispers2t_batch_gui.py

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,29 @@
1+
import logging
12
import os
23
import sys
4+
import traceback
35
from pathlib import Path
4-
from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QCheckBox, QLabel, QGroupBox, QMessageBox
6+
7+
import torch
58
from PySide6.QtCore import Qt
6-
import torch
7-
from utilities import get_compute_and_platform_info, get_supported_quantizations
8-
from whispers2t_batch_transcriber import Worker
9+
from PySide6.QtWidgets import (
10+
QApplication,
11+
QCheckBox,
12+
QFileDialog,
13+
QGroupBox,
14+
QHBoxLayout,
15+
QLabel,
16+
QMessageBox,
17+
QPushButton,
18+
QVBoxLayout,
19+
QWidget,
20+
)
21+
22+
from constants import WHISPER_MODELS
923
from metrics_bar import MetricsBar
1024
from settings import SettingsGroupBox
11-
import logging
12-
import traceback
13-
from constants import WHISPER_MODELS
25+
from utilities import has_bfloat16_support
26+
from whispers2t_batch_transcriber import Worker
1427

1528
def set_cuda_paths():
1629
try:
@@ -26,18 +39,14 @@ def set_cuda_paths():
2639

2740
set_cuda_paths()
2841

29-
def is_nvidia_gpu_available():
30-
return torch.cuda.is_available() and "nvidia" in torch.cuda.get_device_name(0).lower()
31-
3242
class MainWindow(QWidget):
3343
def __init__(self):
3444
super().__init__()
3545
self.initUI()
3646

3747
def initUI(self):
3848
self.setWindowTitle("chintellalaw.com - for non-commercial use")
39-
initial_height = 400 if is_nvidia_gpu_available() else 370
40-
self.setGeometry(100, 100, 680, initial_height)
49+
self.setGeometry(100, 100, 680, 400)
4150
self.setWindowFlags(self.windowFlags() | Qt.WindowStaysOnTopHint)
4251

4352
main_layout = QVBoxLayout()
@@ -69,7 +78,8 @@ def initUI(self):
6978
fileExtensionsGroupBox.setLayout(fileExtensionsLayout)
7079
main_layout.addWidget(fileExtensionsGroupBox)
7180

72-
self.settingsGroupBox = SettingsGroupBox(get_compute_and_platform_info, self)
81+
self.settingsGroupBox = SettingsGroupBox(self.get_compute_and_platform_info, self)
82+
self.settingsGroupBox.device_changed.connect(self.on_device_changed)
7383
main_layout.addWidget(self.settingsGroupBox)
7484

7585
selectDirLayout = QHBoxLayout()
@@ -101,6 +111,16 @@ def closeEvent(self, event):
101111
self.metricsBar.stop_metrics_collector()
102112
super().closeEvent(event)
103113

114+
def get_compute_and_platform_info(self):
    """Return the compute devices available on this machine.

    Always includes "cpu"; appends "cuda" when PyTorch reports a
    usable CUDA device.
    """
    return ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
119+
120+
def on_device_changed(self, device):
    """Hook invoked when the settings panel's device selection changes.

    Intentionally a no-op: SettingsGroupBox already refreshes its own
    model list. Extend here if the main window ever needs to react.
    """
    pass
123+
104124
def selectDirectory(self):
105125
dirPath = QFileDialog.getExistingDirectory(self, "Select Directory")
106126
if dirPath:
@@ -122,46 +142,20 @@ def calculate_files_to_process(self):
122142
return total_files
123143

124144
def perform_checks(self):
    """Validate the current UI settings before transcription starts.

    Returns:
        bool: True when processing may proceed; False when the user
        cancels after the CPU batch-size warning.
    """
    device = self.settingsGroupBox.computeDeviceComboBox.currentText()
    batch_size = self.settingsGroupBox.batchSizeSlider.value()
    # NOTE: beam_size was read here but never used after the old GPU
    # compatibility checks were removed, so the dead read is dropped.

    # Check: CPU with high batch size — warn but allow the user to
    # proceed, since large CPU batches tend to slow inference down.
    if device.lower() == "cpu" and batch_size > 8:
        reply = QMessageBox.warning(self, "Warning",
                                    "When using CPU it is generally recommended to use a batch size of no more than 8 "
                                    "otherwise compute could actually be worse. Use at your own risk.",
                                    QMessageBox.Ok | QMessageBox.Cancel,
                                    QMessageBox.Cancel)
        if reply == QMessageBox.Cancel:
            return False

    return True  # All checks passed
166160

167161
def processFiles(self):
@@ -227,4 +221,4 @@ def workerFinished(self, message):
227221
app.setStyle("Fusion")
228222
mainWindow = MainWindow()
229223
mainWindow.show()
230-
sys.exit(app.exec())
224+
sys.exit(app.exec())

whispers2t_batch_transcriber.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import torch
1111

1212
from constants import WHISPER_MODELS
13-
from utilities import get_physical_core_count
13+
from utilities import get_logical_core_count
1414

15-
CPU_THREADS = max(4, get_physical_core_count() - 1)
15+
CPU_THREADS = max(4, get_logical_core_count() - 8)
1616

1717
class Worker(QThread):
1818
finished = Signal(str)

0 commit comments

Comments (0)