
Commit 03da8cd

Merge pull request Samagra-Development#262 from Samagra-Development/lang_detect
Added lang detect
2 parents 22f0de3 + 6a0b346

File tree: 7 files changed, +155 −0 lines
Lines changed: 21 additions & 0 deletions
```
# Use an official Python runtime as a parent image
FROM python:3.9-slim

WORKDIR /app

# Install Python requirements
COPY requirements.txt requirements.txt
RUN pip3 install -r requirements.txt

# Update the apt package index and install FFmpeg
RUN apt-get update \
    && apt-get install -y ffmpeg \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Copy the rest of the application code to the working directory
COPY . /app/
EXPOSE 8000

# Set the entrypoint for the container
CMD ["hypercorn", "--bind", "0.0.0.0:8000", "api:app"]
```
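The image serves the Quart app with Hypercorn (pulled in as a Quart dependency). For quick debugging outside Docker, the same server can be started from Python; a minimal sketch, assuming Hypercorn's asyncio API and that it is run from the folder containing `api.py`:

```
import asyncio

from hypercorn.asyncio import serve
from hypercorn.config import Config

from api import app  # the Quart app named by "api:app" in the CMD above

config = Config()
config.bind = ["0.0.0.0:8000"]  # same bind address as the container

if __name__ == "__main__":
    asyncio.run(serve(app, config))
```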
Lines changed: 11 additions & 0 deletions
### Testing the model deployment

To test the deployment, follow these steps:

- Clone the repo.
- Go to this folder, i.e. `cd /src/asr/whisper_lang_rec/local`.
- Build the Docker image and test the API (replace `<audio_file>.wav` with your own recording):
```
docker build -t testmodel .
docker run -p 8000:8000 testmodel
curl -X POST -F "file=@<audio_file>.wav" -F "n_seconds=5" http://localhost:8000/
```
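The same request can be made from Python; a minimal sketch, assuming the container is running locally and using a placeholder audio path:

```
import requests

# Placeholder path; point this at a real .wav file.
with open("sample.wav", "rb") as f:
    response = requests.post(
        "http://localhost:8000/",
        files={"file": ("sample.wav", f)},
        data={"n_seconds": 5},  # seconds of audio used for detection
    )

# The endpoint returns the detected language code, e.g. "en" or "hi".
print(response.text)
```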
Lines changed: 2 additions & 0 deletions
```
from .request import ModelRequest
from .model import Model  # Model lives in model.py, not request.py
```
Lines changed: 45 additions & 0 deletions
```
from model import Model
from request import ModelRequest
from quart import Quart, request
from quart_cors import cors  # Import the cors function
import aiohttp
import os
import tempfile


app = Quart(__name__)
app = cors(app)  # Apply the cors function to your app to enable CORS for all routes

model = None

@app.before_serving
async def startup():
    app.client = aiohttp.ClientSession()
    global model
    model = Model(app)

@app.route('/', methods=['POST'])
async def embed():
    global model

    # Save the upload into a fresh temporary directory
    temp_dir = tempfile.mkdtemp()
    data = await request.form
    files = await request.files
    uploaded_file = files.get('file')

    file_path = os.path.join(temp_dir, uploaded_file.filename)
    await uploaded_file.save(file_path)

    n_seconds = int(data.get('n_seconds'))
    req = ModelRequest(wav_file=file_path, n_seconds=n_seconds)
    response = await model.inference(req)  # n_seconds is carried inside the request

    # Clean up the temporary file and directory
    os.remove(file_path)
    os.rmdir(temp_dir)

    return response


if __name__ == "__main__":
    app.run()
```
Lines changed: 56 additions & 0 deletions
```
import torch
import torchaudio
import whisper
from request import ModelRequest
import tempfile
import os

class Model():
    def __new__(cls, context):
        cls.context = context
        if not hasattr(cls, 'instance'):
            cls.instance = super(Model, cls).__new__(cls)

            # Load Whisper model
            cls.model = whisper.load_model("base")
            cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            cls.model.to(cls.device)
        return cls.instance

    def trim_audio(self, audio_path, n_seconds):
        audio, sr = torchaudio.load(audio_path)
        total_duration = audio.shape[1] / sr  # Total duration of the audio in seconds

        # If the audio duration is less than n_seconds, don't trim the audio
        if total_duration < n_seconds:
            print(f"The audio duration ({total_duration:.2f}s) is less than {n_seconds}s. Using the full audio.")
            return audio, sr

        num_samples = int(n_seconds * sr)
        audio = audio[:, :num_samples]
        return audio, sr

    async def inference(self, request: ModelRequest):
        # The n_seconds is now accessed from the request object
        n_seconds = request.n_seconds
        trimmed_audio, sr = self.trim_audio(request.wav_file, n_seconds)

        # Save the trimmed audio to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            torchaudio.save(temp_file.name, trimmed_audio, sr)

        # Process the audio with Whisper
        audio = whisper.load_audio(temp_file.name)
        audio = whisper.pad_or_trim(audio)

        # Clean up the temporary file
        os.unlink(temp_file.name)

        mel = whisper.log_mel_spectrogram(audio).to(self.device)
        # Detect the spoken language
        _, probs = self.model.detect_language(mel)
        detected_language = max(probs, key=probs.get)

        return detected_language
```
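The inference path follows the standard openai-whisper language-detection recipe; a minimal standalone sketch of the same steps, using a placeholder audio path:

```
import whisper

model = whisper.load_model("base")

# Placeholder path; point this at a real .wav file.
audio = whisper.load_audio("sample.wav")
audio = whisper.pad_or_trim(audio)  # pad/trim to Whisper's 30-second window

# Compute the log-Mel spectrogram and rank candidate languages
mel = whisper.log_mel_spectrogram(audio).to(model.device)
_, probs = model.detect_language(mel)
print(max(probs, key=probs.get))  # e.g. "en"
```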
Lines changed: 12 additions & 0 deletions
```
import json


class ModelRequest():
    def __init__(self, wav_file, n_seconds):
        self.wav_file = wav_file
        self.n_seconds = n_seconds

    def to_json(self):
        return json.dumps(self, default=lambda o: o.__dict__,
                          sort_keys=True, indent=4)
```
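For reference, a quick usage sketch of `ModelRequest` (the file path is a placeholder):

```
from request import ModelRequest

req = ModelRequest(wav_file="sample.wav", n_seconds=5)
print(req.to_json())
# {
#     "n_seconds": 5,
#     "wav_file": "sample.wav"
# }
```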
Lines changed: 8 additions & 0 deletions
```
torch
torchaudio
transformers
quart
aiohttp
librosa
quart-cors
openai-whisper
```
