Skip to content

Commit 4081e34

Browse files
authored
Fixed CPU unusable on CUDA machines
1 parent 4885e4b commit 4081e34

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

app.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8811,6 +8811,12 @@ def whisper_device_select(self, device):
88118811
:return:
88128812
'''
88138813

8814+
allowed_devices = ['cuda', 'CUDA', 'gpu', 'GPU', 'cpu', 'CPU']
8815+
8816+
# change the whisper device if it was passed as a parameter
8817+
if device is not None and device in allowed_devices:
8818+
self.whisper_device = device
8819+
88148820
# if the whisper device is set to cuda
88158821
if self.whisper_device in ['cuda', 'CUDA', 'gpu', 'GPU']:
88168822
# use CUDA if available
@@ -10110,6 +10116,10 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
1011010116

1011110117
# what is the name of the audio file
1011210118
audio_file_name = os.path.basename(audio_file_path)
10119+
10120+
whisper_device_changed = False
10121+
if 'device' in other_whisper_options and self.whisper_device != other_whisper_options['device']:
10122+
whisper_device_changed = True
1011310123

1011410124
# select the device that was passed (if any)
1011510125
if 'device' in other_whisper_options:
@@ -10120,7 +10130,8 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
1012010130
# load OpenAI Whisper model
1012110131
# and hold it loaded for future use (unless another model was passed via other_whisper_options)
1012210132
if self.whisper_model is None \
10123-
or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model']):
10133+
or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model'])\
10134+
or whisper_device_changed:
1012410135

1012510136
# update the status of the item in the transcription log
1012610137
self.update_transcription_log(unique_id=queue_id, **{'status': 'loading model'})
@@ -10147,7 +10158,7 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
1014710158

1014810159
logger.info('Loading Whisper {} model.'.format(self.whisper_model_name))
1014910160
try:
10150-
self.whisper_model = whisper.load_model(self.whisper_model_name)
10161+
self.whisper_model = whisper.load_model(self.whisper_model_name, device=self.whisper_device)
1015110162
except Exception as e:
1015210163
logger.error('Error loading Whisper {} model: {}'.format(self.whisper_model_name, e))
1015310164

0 commit comments

Comments
 (0)