import torch
import torchaudio
import whisper
from request import ModelRequest
import tempfile
import os

class Model:
    def __new__(cls, context):
        cls.context = context
        if not hasattr(cls, 'instance'):
            # Create the singleton instance and load Whisper only once
            cls.instance = super(Model, cls).__new__(cls)

            # Load the Whisper model and move it to the GPU when available
            cls.model = whisper.load_model("base")
            cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            cls.model.to(cls.device)
        return cls.instance

    def trim_audio(self, audio_path, n_seconds):
        audio, sr = torchaudio.load(audio_path)
        total_duration = audio.shape[1] / sr  # Total duration of the audio in seconds

        # If the audio duration is less than n_seconds, don't trim the audio
        if total_duration < n_seconds:
            print(f"The audio duration ({total_duration:.2f}s) is less than {n_seconds}s. Using the full audio.")
            return audio, sr

        # Keep only the first n_seconds worth of samples
        num_samples = int(n_seconds * sr)
        audio = audio[:, :num_samples]
        return audio, sr

    async def inference(self, request: ModelRequest):
        # n_seconds is accessed from the request object
        n_seconds = request.n_seconds
        trimmed_audio, sr = self.trim_audio(request.wav_file, n_seconds)

        # Create a temporary .wav path, closing the handle before writing so the
        # file can be reopened by name on all platforms
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_path = temp_file.name

        try:
            # Save the trimmed audio, then reload it with Whisper's own helpers
            torchaudio.save(temp_path, trimmed_audio, sr)
            audio = whisper.load_audio(temp_path)
            audio = whisper.pad_or_trim(audio)
        finally:
            # Clean up the temporary file even if loading fails
            os.unlink(temp_path)

        mel = whisper.log_mel_spectrogram(audio).to(self.device)

        # Detect the spoken language and return the most probable one
        _, probs = self.model.detect_language(mel)
        detected_language = max(probs, key=probs.get)

        return detected_language
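
For reference, a minimal sketch of how this class might be driven end to end. It assumes ModelRequest can be constructed with the wav_file and n_seconds fields that inference reads above; the module name, the context argument, and the sample path are placeholder assumptions, not part of this commit:

    import asyncio

    from model import Model          # assumed module name for this file
    from request import ModelRequest

    # Placeholder context; Model.__new__ only stores it on the class.
    model = Model(context=None)

    # Assumed constructor: wav_file and n_seconds are the fields that
    # Model.inference reads from the request.
    req = ModelRequest(wav_file="sample.wav", n_seconds=30)

    # inference is a coroutine, so run it on an event loop.
    detected = asyncio.run(model.inference(req))
    print(f"Detected language: {detected}")

Because __new__ caches the loaded model on the class, repeated Model(...) calls reuse the same Whisper instance rather than reloading weights per request.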