|
4 | 4 | from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QCheckBox, QLabel, QGroupBox, QMessageBox |
5 | 5 | from PySide6.QtCore import Qt |
6 | 6 | import torch |
7 | | -from utilities import get_compute_and_platform_info |
| 7 | +from utilities import get_compute_and_platform_info, get_supported_quantizations |
8 | 8 | from whispers2t_batch_transcriber import Worker |
9 | 9 | from metrics_bar import MetricsBar |
10 | 10 | from settings import SettingsGroupBox |
@@ -42,7 +42,7 @@ def initUI(self): |
42 | 42 |
|
43 | 43 | main_layout = QVBoxLayout() |
44 | 44 |
|
45 | | - transcriberGroupBox = QGroupBox("Batch Transcriber (ctranslate2 edition - huggingface edition coming soon)") |
| 45 | + transcriberGroupBox = QGroupBox("Batch Transcriber (ctranslate2 edition)") |
46 | 46 | transcriberLayout = QVBoxLayout() |
47 | 47 |
|
48 | 48 | self.dirLabel = QLabel("No directory selected") |
@@ -121,10 +121,57 @@ def calculate_files_to_process(self): |
121 | 121 | total_files += len(list(directory_path.glob(pattern))) |
122 | 122 | return total_files |
123 | 123 |
|
| 124 | + def perform_checks(self): |
| 125 | + model = self.settingsGroupBox.modelComboBox.currentText() |
| 126 | + device = self.settingsGroupBox.computeDeviceComboBox.currentText() |
| 127 | + batch_size = self.settingsGroupBox.batchSizeSlider.value() |
| 128 | + beam_size = self.settingsGroupBox.beamSizeSlider.value() |
| 129 | + |
| 130 | + # Check 1: CPU and non-float32 model |
| 131 | + if "float32" not in model.lower() and device.lower() == "cpu": |
| 132 | + QMessageBox.warning(self, "Invalid Configuration", |
| 133 | + "CPU only supports Float 32 computation. Please select a different Whisper model.") |
| 134 | + return False |
| 135 | + |
| 136 | + # Check 2: CPU with high batch size |
| 137 | + if device.lower() == "cpu" and batch_size > 16: |
| 138 | + reply = QMessageBox.warning(self, "Performance Warning", |
| 139 | + "When using CPU it is generally recommended to use a batch size of no more than 16 " |
| 140 | + "otherwise compute time will actually be worse.\n\n" |
| 141 | + "Moreover, if you select a Beam Size greater than one, you should reduce the Batch Size accordingly.\n\n" |
| 142 | + "For example:\n" |
| 143 | + "- If you select a Beam Size of 2 (double the default value of 1) you would reduce the Batch Size (default value 16) by half.\n" |
| 144 | + "- If Beam Size is set to 3 you should reduce the Batch Size to 1/3 of the default level, and so on.\n\nClick OK to proceed.", |
| 145 | + QMessageBox.Ok | QMessageBox.Cancel, |
| 146 | + QMessageBox.Cancel) |
| 147 | + if reply == QMessageBox.Cancel: |
| 148 | + return False |
| 149 | + |
| 150 | + # Check 3: GPU compatibility |
| 151 | + # Only perform this check if the device is not CPU |
| 152 | + if device.lower() != "cpu": |
| 153 | + supported_quantizations = get_supported_quantizations(device) |
| 154 | + if "float16" in model.lower() and "float16" not in supported_quantizations: |
| 155 | + QMessageBox.warning(self, "Incompatible Configuration", |
| 156 | + "Your GPU does not support the selected floating point value (float16). " |
| 157 | + "Please make another selection.") |
| 158 | + return False |
| 159 | + if "bfloat16" in model.lower() and "bfloat16" not in supported_quantizations: |
| 160 | + QMessageBox.warning(self, "Incompatible Configuration", |
| 161 | + "Your GPU does not support the selected floating point value (bfloat16). " |
| 162 | + "Please make another selection.") |
| 163 | + return False |
| 164 | + |
| 165 | + return True # All checks passed |
| 166 | + |
124 | 167 | def processFiles(self): |
125 | 168 | if hasattr(self, 'directory'): |
126 | 169 | total_files = self.calculate_files_to_process() |
127 | 170 |
|
| 171 | + # Perform checks |
| 172 | + if not self.perform_checks(): |
| 173 | + return |
| 174 | + |
128 | 175 | reply = QMessageBox.question( |
129 | 176 | self, |
130 | 177 | 'Confirm Process', |
|
0 commit comments