@@ -8811,6 +8811,12 @@ def whisper_device_select(self, device):
88118811 :return:
88128812 '''
88138813
8814+ allowed_devices = ['cuda', 'CUDA', 'gpu', 'GPU', 'cpu', 'CPU']
8815+
8816+ # change the whisper device if it was passed as a parameter
8817+ if device is not None and device in allowed_devices:
8818+ self.whisper_device = device
8819+
88148820 # if the whisper device is set to cuda
88158821 if self.whisper_device in ['cuda', 'CUDA', 'gpu', 'GPU']:
88168822 # use CUDA if available
@@ -10110,6 +10116,10 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
1011010116
1011110117 # what is the name of the audio file
1011210118 audio_file_name = os.path.basename(audio_file_path)
10119+
10120+ whisper_device_changed = False
10121+ if 'device' in other_whisper_options and self.whisper_device != other_whisper_options['device']:
10122+ whisper_device_changed = True
1011310123
1011410124 # select the device that was passed (if any)
1011510125 if 'device' in other_whisper_options:
@@ -10120,7 +10130,8 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
1012010130 # load OpenAI Whisper model
1012110131 # and hold it loaded for future use (unless another model was passed via other_whisper_options)
1012210132 if self.whisper_model is None \
10123- or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model']):
10133+ or ('model' in other_whisper_options and self.whisper_model_name != other_whisper_options['model'])\
10134+ or whisper_device_changed:
1012410135
1012510136 # update the status of the item in the transcription log
1012610137 self.update_transcription_log(unique_id=queue_id, **{'status': 'loading model'})
@@ -10147,7 +10158,7 @@ def whisper_transcribe(self, name=None, audio_file_path=None, task=None,
1014710158
1014810159 logger.info('Loading Whisper {} model.'.format(self.whisper_model_name))
1014910160 try:
10150- self.whisper_model = whisper.load_model(self.whisper_model_name)
10161+ self.whisper_model = whisper.load_model(self.whisper_model_name, device=self.whisper_device )
1015110162 except Exception as e:
1015210163 logger.error('Error loading Whisper {} model: {}'.format(self.whisper_model_name, e))
1015310164
0 commit comments