Commit 3bbfc8e: "docker file"
1 parent 675b635

File tree: 4 files changed, +855 −0 lines changed
Dockerfile: 38 additions & 0 deletions

# Base image with CUDA 12.1 and cuDNN runtime (for H100/A100 GPUs);
# Python itself is installed below via apt
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04

# Set environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV TRANSFORMERS_CACHE=/root/.cache/huggingface/transformers

# Install system packages, then install cuDNN 9.10.1 manually (for CUDA 12)
RUN apt-get update && apt-get install -y \
    python3 python3-pip python3-venv python3-dev \
    ffmpeg git curl wget ca-certificates gnupg libsndfile1 && \
    mkdir -p /tmp/cudnn && cd /tmp/cudnn && \
    wget https://developer.download.nvidia.com/compute/cudnn/9.10.1/local_installers/cudnn-local-repo-ubuntu2204-9.10.1_1.0-1_amd64.deb && \
    dpkg -i cudnn-local-repo-ubuntu2204-9.10.1_1.0-1_amd64.deb && \
    cp /var/cudnn-local-repo-ubuntu2204-9.10.1/cudnn-*-keyring.gpg /usr/share/keyrings/ && \
    apt-get update && \
    apt-get -y install cudnn-cuda-12 && \
    rm -rf /var/lib/apt/lists/* /tmp/cudnn

# Upgrade pip and install torch separately for CUDA 12.1
RUN python3 -m pip install --upgrade pip && \
    pip install torch==2.2.2+cu121 torchaudio==2.2.2+cu121 -f https://download.pytorch.org/whl/torch_stable.html

# Copy requirements and install them
COPY requirements.txt /app/requirements.txt
WORKDIR /app
RUN pip install -r requirements.txt

# Copy application code
COPY whisper_code.py whisper_code.py
COPY whisper_api_server.py whisper_api_server.py

# Expose port
EXPOSE 8000

# Run API
CMD ["uvicorn", "whisper_api_server:app", "--host", "0.0.0.0", "--port", "8000"]
requirements.txt: 37 additions & 0 deletions
# Core frameworks
fastapi==0.115.12
uvicorn==0.34.2

# Whisper + speech processing
faster-whisper==1.1.1
librosa==0.11.0
pydub==0.25.1
soundfile==0.13.1
noisereduce==3.0.3
demucs==4.0.1
ffmpeg-python==0.2.0

# Diarization (PyAnnote)
pyannote.audio==3.3.2
pyannote.core==5.0.0
pyannote.database==5.1.3
pyannote.metrics==3.2.1
pyannote.pipeline==3.0.1

# Transformers + Summarization
transformers==4.51.3
huggingface-hub==0.31.4
sentencepiece==0.2.0

# Evaluation
jiwer==3.1.0

# Utilities
numpy==1.26.4
scikit-learn==1.6.1
requests==2.32.3
tqdm==4.67.1
typing_extensions==4.13.2
pydantic==2.11.4
python-multipart==0.0.20
accelerate==1.7.0
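
Note that torch is pinned in the Dockerfile rather than here. A quick sanity check inside the built container can confirm that the CUDA wheel and the pinned audio stack resolve together (an assumed verification step, not part of the commit):

import torch                                   # installed by the Dockerfile, not requirements.txt
from faster_whisper import WhisperModel        # pinned above as faster-whisper==1.1.1

print(torch.__version__, torch.version.cuda)   # expect 2.2.2+cu121 and 12.1
print(torch.cuda.is_available())               # True on a GPU host with working drivers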
whisper_api_server.py: 129 additions & 0 deletions
import os
import tempfile
import json
import threading
import glob
from queue import Queue
from threading import Lock
from typing import Optional
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from whisper_code import transcribe_and_summarize, load_whisper_model_faster

app = FastAPI()

# Global model cache and lock
model_cache = {}
model_lock = Lock()

@app.post("/transcribe")
async def transcribe_audio_api(
    audio_file: UploadFile = File(...),
    model: str = Form("base"),
    summarized_model: str = Form("mistralai/Mistral-7B-Instruct-v0.1"),
    denoise: bool = Form(False),
    prop_decrease: float = Form(0.7),
    summary: bool = Form(True),
    speaker: bool = Form(False),
    hf_token: Optional[str] = Form(None),
    max_speakers: Optional[int] = Form(None),
    streaming: bool = Form(False)
):
    # Persist the upload to a temporary file (mkstemp avoids the race
    # condition of the deprecated tempfile.mktemp)
    fd, temp_audio_path = tempfile.mkstemp(suffix=f"_{audio_file.filename}")
    with os.fdopen(fd, "wb") as f:
        f.write(await audio_file.read())

    output_dir = tempfile.mkdtemp()

    # Ensure the model is loaded once; take the lock so concurrent
    # requests cannot load the same model twice
    with model_lock:
        if model not in model_cache:
            model_cache[model] = load_whisper_model_faster(model)

    whisper_model = model_cache[model]

    if streaming:
        def generator():
            q = Queue()

            def api_callback(result):
                q.put(json.dumps(result) + "\n")

            def run_pipeline():
                try:
                    with model_lock:
                        transcribe_and_summarize(
                            path=temp_audio_path,
                            model_name=model,
                            output_dir=output_dir,
                            summarized_model_id=summarized_model,
                            denoise=denoise,
                            prop_decrease=prop_decrease,
                            summary=summary,
                            speaker=speaker,
                            hf_token=hf_token,
                            max_speakers=max_speakers,
                            streaming=True,
                            api_callback=api_callback,
                            model_instance=whisper_model
                        )
                except Exception as e:
                    import traceback
                    traceback.print_exc()
                    q.put(json.dumps({"error": f"Streaming transcription failed: {str(e)}"}))
                finally:
                    q.put(None)  # sentinel: tells the generator to stop

            threading.Thread(target=run_pipeline).start()

            # Drain the queue, yielding newline-delimited JSON chunks
            while True:
                chunk = q.get()
                if chunk is None:
                    break
                yield chunk

        return StreamingResponse(generator(), media_type="application/json")

    else:
        try:
            with model_lock:
                transcribe_and_summarize(
                    path=temp_audio_path,
                    model_name=model,
                    output_dir=output_dir,
                    summarized_model_id=summarized_model,
                    denoise=denoise,
                    prop_decrease=prop_decrease,
                    summary=summary,
                    speaker=speaker,
                    hf_token=hf_token,
                    max_speakers=max_speakers,
                    streaming=False,
                    api_callback=None,
                    model_instance=whisper_model
                )
        except Exception as e:
            import traceback
            traceback.print_exc()
            return JSONResponse(
                content={"error": f"Transcription failed: {str(e)}"},
                status_code=500
            )

    # Return the most recently written JSON result from the output directory
    try:
        json_files = sorted(
            glob.glob(os.path.join(output_dir, "*.json")),
            key=os.path.getmtime,
            reverse=True
        )
        if not json_files:
            raise FileNotFoundError("No output JSON file found.")

        with open(json_files[0]) as f:
            return JSONResponse(content=json.load(f))

    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(
            content={"error": f"Failed to read output JSON: {str(e)}"},
            status_code=500
        )
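
A client can exercise the endpoint along these lines (a sketch: localhost:8000 follows the EXPOSE/CMD in the Dockerfile, while meeting.wav and the form values are placeholders). With streaming=true the server emits newline-delimited JSON, which requests can consume via iter_lines:

import json
import requests

URL = "http://localhost:8000/transcribe"  # host/port from the Dockerfile CMD

# Blocking request: a single JSON document with the full result
with open("meeting.wav", "rb") as audio:  # placeholder file name
    resp = requests.post(
        URL,
        files={"audio_file": audio},
        data={"model": "base", "summary": "true", "streaming": "false"},
    )
print(resp.json())

# Streaming request: newline-delimited JSON chunks as they are produced
with open("meeting.wav", "rb") as audio:
    resp = requests.post(
        URL,
        files={"audio_file": audio},
        data={"model": "base", "streaming": "true"},
        stream=True,
    )
    for line in resp.iter_lines():
        if line:
            print(json.loads(line))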
