1+ import logging
12import os
23import sys
4+ import traceback
35from pathlib import Path
4- from PySide6 .QtWidgets import QApplication , QWidget , QVBoxLayout , QHBoxLayout , QPushButton , QFileDialog , QCheckBox , QLabel , QGroupBox , QMessageBox
6+
7+ import torch
58from PySide6 .QtCore import Qt
6- import torch
7- from utilities import get_compute_and_platform_info , get_supported_quantizations
8- from whispers2t_batch_transcriber import Worker
9+ from PySide6 .QtWidgets import (
10+ QApplication ,
11+ QCheckBox ,
12+ QFileDialog ,
13+ QGroupBox ,
14+ QHBoxLayout ,
15+ QLabel ,
16+ QMessageBox ,
17+ QPushButton ,
18+ QVBoxLayout ,
19+ QWidget ,
20+ )
21+
22+ from constants import WHISPER_MODELS
923from metrics_bar import MetricsBar
1024from settings import SettingsGroupBox
11- import logging
12- import traceback
13- from constants import WHISPER_MODELS
25+ from utilities import has_bfloat16_support
26+ from whispers2t_batch_transcriber import Worker
1427
1528def set_cuda_paths ():
1629 try :
@@ -26,18 +39,14 @@ def set_cuda_paths():
2639
2740set_cuda_paths ()
2841
29- def is_nvidia_gpu_available ():
30- return torch .cuda .is_available () and "nvidia" in torch .cuda .get_device_name (0 ).lower ()
31-
3242class MainWindow (QWidget ):
3343 def __init__ (self ):
3444 super ().__init__ ()
3545 self .initUI ()
3646
3747 def initUI (self ):
3848 self .setWindowTitle ("chintellalaw.com - for non-commercial use" )
39- initial_height = 400 if is_nvidia_gpu_available () else 370
40- self .setGeometry (100 , 100 , 680 , initial_height )
49+ self .setGeometry (100 , 100 , 680 , 400 )
4150 self .setWindowFlags (self .windowFlags () | Qt .WindowStaysOnTopHint )
4251
4352 main_layout = QVBoxLayout ()
@@ -69,7 +78,8 @@ def initUI(self):
6978 fileExtensionsGroupBox .setLayout (fileExtensionsLayout )
7079 main_layout .addWidget (fileExtensionsGroupBox )
7180
72- self .settingsGroupBox = SettingsGroupBox (get_compute_and_platform_info , self )
81+ self .settingsGroupBox = SettingsGroupBox (self .get_compute_and_platform_info , self )
82+ self .settingsGroupBox .device_changed .connect (self .on_device_changed )
7383 main_layout .addWidget (self .settingsGroupBox )
7484
7585 selectDirLayout = QHBoxLayout ()
@@ -101,6 +111,16 @@ def closeEvent(self, event):
101111 self .metricsBar .stop_metrics_collector ()
102112 super ().closeEvent (event )
103113
114+ def get_compute_and_platform_info (self ):
115+ devices = ["cpu" ]
116+ if torch .cuda .is_available ():
117+ devices .append ("cuda" )
118+ return devices
119+
120+ def on_device_changed (self , device ):
121+ # You can add any additional logic here if needed when the device is changed
122+ pass
123+
104124 def selectDirectory (self ):
105125 dirPath = QFileDialog .getExistingDirectory (self , "Select Directory" )
106126 if dirPath :
@@ -122,46 +142,20 @@ def calculate_files_to_process(self):
122142 return total_files
123143
124144 def perform_checks (self ):
125- model = self .settingsGroupBox .modelComboBox .currentText ()
126145 device = self .settingsGroupBox .computeDeviceComboBox .currentText ()
127146 batch_size = self .settingsGroupBox .batchSizeSlider .value ()
128147 beam_size = self .settingsGroupBox .beamSizeSlider .value ()
129148
130- # Check 1: CPU and non-float32 model
131- if "float32" not in model .lower () and device .lower () == "cpu" :
132- QMessageBox .warning (self , "Invalid Configuration" ,
133- "CPU only supports Float 32 computation. Please select a different Whisper model." )
134- return False
135-
136- # Check 2: CPU with high batch size
137- if device .lower () == "cpu" and batch_size > 16 :
138- reply = QMessageBox .warning (self , "Performance Warning" ,
139- "When using CPU it is generally recommended to use a batch size of no more than 16 "
140- "otherwise compute time will actually be worse.\n \n "
141- "Moreover, if you select a Beam Size greater than one, you should reduce the Batch Size accordingly.\n \n "
142- "For example:\n "
143- "- If you select a Beam Size of 2 (double the default value of 1) you would reduce the Batch Size (default value 16) by half.\n "
144- "- If Beam Size is set to 3 you should reduce the Batch Size to 1/3 of the default level, and so on.\n \n Click OK to proceed." ,
149+ # Check: CPU with high batch size
150+ if device .lower () == "cpu" and batch_size > 8 :
151+ reply = QMessageBox .warning (self , "Warning" ,
152+ "When using CPU it is generally recommended to use a batch size of no more than 8 "
153+ "otherwise compute could actually be worse. Use at your own risk." ,
145154 QMessageBox .Ok | QMessageBox .Cancel ,
146155 QMessageBox .Cancel )
147156 if reply == QMessageBox .Cancel :
148157 return False
149158
150- # Check 3: GPU compatibility
151- # Only perform this check if the device is not CPU
152- if device .lower () != "cpu" :
153- supported_quantizations = get_supported_quantizations (device )
154- if "float16" in model .lower () and "float16" not in supported_quantizations :
155- QMessageBox .warning (self , "Incompatible Configuration" ,
156- "Your GPU does not support the selected floating point value (float16). "
157- "Please make another selection." )
158- return False
159- if "bfloat16" in model .lower () and "bfloat16" not in supported_quantizations :
160- QMessageBox .warning (self , "Incompatible Configuration" ,
161- "Your GPU does not support the selected floating point value (bfloat16). "
162- "Please make another selection." )
163- return False
164-
165159 return True # All checks passed
166160
167161 def processFiles (self ):
@@ -227,4 +221,4 @@ def workerFinished(self, message):
227221 app .setStyle ("Fusion" )
228222 mainWindow = MainWindow ()
229223 mainWindow .show ()
230- sys .exit (app .exec ())
224+ sys .exit (app .exec ())
0 commit comments