Skip to content

Commit a56c1f1

Browse files
committed
Added STT model for audio context.
1 parent d59b55d commit a56c1f1

File tree

11 files changed

+222
-21
lines changed

11 files changed

+222
-21
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ faiss_index.bin
66
input_video.mp4
77
metadata.pkl
88
.env
9+
data

backend/app/api/routes.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from app.core.memory import ChatMemory
55
from app.services.extract_frames import extract_frames
66
from app.services.faiss import faiss_process
7+
from app.services.extract_audio import AudioExtractor
8+
from app.services.transcription import get_vosk_model
9+
from app.services.transcribe import Transcriber
710

811
router = APIRouter()
912
memory = ChatMemory()
@@ -28,8 +31,17 @@ async def clip_embed_video(file: UploadFile = File(...)):
2831
frames = extract_frames(temp_path, "video_frames")
2932
(preprocess, model) = clip_model()
3033

31-
faiss_process(preprocess, model, frames)
32-
34+
extractor = AudioExtractor()
35+
audio_path = extractor.extract_audio(temp_path)
36+
37+
#transcriptions handle
38+
model_path = get_vosk_model()
39+
transciber = Transcriber(model_path)
40+
print("Using model:", model_path)
41+
transcriptions = transciber.transcribe(audio_path)["transcription"]
42+
43+
print("Creating index.bin and metadata.pkl")
44+
faiss_process(preprocess, model, frames, transcriptions)
3345

3446
return {"ready": True}
3547
except Exception as e:

backend/app/services/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.env
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
import subprocess
3+
4+
class AudioExtractor:
    """Extract the audio track of a video file by shelling out to ffmpeg."""

    def __init__(self, ffmpeg_path="ffmpeg"):
        """Remember which ffmpeg executable to invoke.

        Args:
            ffmpeg_path: Name or path of the ffmpeg binary. The default
                relies on the executable being discoverable on PATH.
        """
        self.ffmpeg_path = ffmpeg_path

    def extract_audio(self, video_path, output_path=None, format="mp3"):
        """Write the audio of ``video_path`` as mono 16 kHz and return its path.

        Args:
            video_path: Path of the input video; must exist.
            output_path: Destination file. When ``None`` the video's path
                is reused with its extension replaced by ``format``.
            format: Container/codec name passed to ffmpeg's ``-f`` option and
                used as the default output extension.

        Returns:
            The path the audio was written to.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            RuntimeError: If ffmpeg exits with a non-zero status; the message
                carries ffmpeg's stderr and the original
                ``CalledProcessError`` is chained as ``__cause__``.
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        # Default output path: same location as the video, new extension.
        if output_path is None:
            base, _ = os.path.splitext(video_path)
            output_path = f"{base}.{format}"

        command = [
            self.ffmpeg_path,
            "-y",               # overwrite if file exists
            "-i", video_path,   # input file
            "-vn",              # no video
            "-ac", "1",         # mono
            "-ar", "16000",     # 16 kHz sample rate (good for ASR)
            "-f", format,       # output format
            output_path,
        ]

        try:
            subprocess.run(
                command,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError as e:
            # ffmpeg may echo file names or metadata in arbitrary encodings;
            # decoding leniently keeps a UnicodeDecodeError from hiding the
            # real failure.
            stderr_text = (e.stderr or b"").decode(errors="replace")
            raise RuntimeError(f"ffmpeg failed: {stderr_text}") from e

        return output_path

backend/app/services/faiss.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
import faiss
66
from tqdm import tqdm
77

8-
def faiss_process(preprocessor, model, frames, device="cpu"):
8+
def faiss_process(preprocessor, model, frames, transcriptions, temp_path=".", device="cpu"):
99
model.to(device)
1010
model.eval()
1111

1212
frame_embeddings = []
13+
metadata = []
1314

14-
for frame_path in tqdm(frames, desc="Processing frames"):
15+
# Build a list of transcriptions for easy lookup
16+
transcription_segments = list(transcriptions.values())
17+
18+
for sec_index, frame_path in enumerate(tqdm(frames, desc="Processing frames")):
1519
image = Image.open(frame_path).convert("RGB")
1620
inputs = preprocessor(images=image, return_tensors="pt").to(device)
1721

@@ -20,23 +24,40 @@ def faiss_process(preprocessor, model, frames, device="cpu"):
2024
image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
2125
frame_embeddings.append(image_features.cpu().numpy())
2226

23-
frame_embeddings = np.vstack(frame_embeddings)
27+
# assume frame index = second in video
28+
frame_time = sec_index
29+
30+
# find matching transcription segment
31+
matched_text = ""
32+
for seg in transcription_segments:
33+
if seg["start_sec"] <= frame_time < seg["end_sec"]:
34+
matched_text = seg["text"]
35+
break
36+
37+
metadata.append({
38+
"frame_path": frame_path,
39+
"frame_time": frame_time,
40+
"transcription": matched_text
41+
})
2442

25-
# Dummy text per frame (for now)
26-
transcriptions = [f"Transcription for {fp}" for fp in frames]
43+
print("Metadata done!")
44+
45+
# Stack embeddings
46+
frame_embeddings = np.vstack(frame_embeddings).astype("float32")
2747

2848
# Initialize FAISS index
2949
embedding_dim = frame_embeddings.shape[1]
3050
index = faiss.IndexFlatL2(embedding_dim)
3151
index.add(frame_embeddings)
52+
print("FAISS initialized!")
3253

33-
metadata = [
34-
{"frame_path": frames[i], "transcription": transcriptions[i]} for i in range(len(frames))
35-
]
54+
# Save FAISS index
55+
faiss_path = f"{temp_path}/faiss_index.bin"
56+
faiss.write_index(index, faiss_path)
57+
print(f"FAISS index written to {faiss_path}")
3658

37-
with open("metadata.pkl", "wb") as f:
59+
# Save metadata
60+
meta_path = f"{temp_path}/metadata.pkl"
61+
with open(meta_path, "wb") as f:
3862
pickle.dump(metadata, f)
39-
print("metadata.pkl written")
40-
41-
faiss.write_index(index, "faiss_index.bin")
42-
print("faiss_index.bin written")
63+
print(f"Metadata written to {meta_path}")

backend/app/services/llava_api.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@
88
import base64
99
from dotenv import load_dotenv
1010

11-
12-
1311
load_dotenv()
1412

1513
url = os.getenv("URL")
1614
api_key = os.getenv("API_KEY")
1715

18-
19-
2016
# Use this function to convert an image file from the filesystem to base64
2117
def image_file_to_base64(image_path):
2218
with open(image_path, 'rb') as f:
@@ -28,7 +24,7 @@ def query_llava(context: dict, question: str) -> str:
2824
try:
2925
data = {
3026
"images": image_file_to_base64(context["frame_path"]),
31-
"prompt": question
27+
"prompt": f"Question:{question}, Transcription for the frame: {context['transcription']}"
3228
}
3329

3430
headers = {'x-api-key': api_key}

backend/app/services/transcribe.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import subprocess, sys, os, json
2+
from datetime import datetime
3+
from vosk import Model, KaldiRecognizer
4+
5+
SAMPLE_RATE = 16000
6+
BYTES_PER_SECOND = SAMPLE_RATE * 2
7+
8+
class Transcriber():
    """Sliding-window speech-to-text over an audio file using a Vosk model.

    The audio is decoded by ffmpeg to raw 16-bit mono PCM at ``SAMPLE_RATE``
    and every window is recognised independently with a fresh
    ``KaldiRecognizer``.
    """

    def __init__(self, model_path, window_size_sec=5, stride_sec=1):
        """
        model_path: directory of the Vosk model to load
        window_size_sec: context window size (e.g., 5 seconds)
        stride_sec: step size (e.g., 1 second)

        Raises ValueError when the window size or stride is not positive.
        """
        if window_size_sec <= 0 or stride_sec <= 0:
            raise ValueError("window_size_sec and stride_sec must be positive")
        self.model = Model(model_path)
        self.window_size = window_size_sec
        self.stride = stride_sec

    def transcribe(self, filename):
        """Transcribe ``filename`` window by window.

        Returns a dict with timing information (``start_time``, ``end_time``,
        ``elapsed_time``), the ``window_size`` and ``stride`` used, and
        ``transcription``: a dict mapping a running index to
        ``{"start_sec", "end_sec", "text"}``. Audio that does not line up
        with the stride (including audio shorter than one window) gets one
        extra trailing segment that ends at the end of the audio.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
            RuntimeError: If ffmpeg fails and produces no audio at all.
        """
        if not os.path.exists(filename):
            raise FileNotFoundError(filename)

        start_time = datetime.now()
        audio = self._decode_to_pcm(filename)

        window_bytes = self._seconds_to_bytes(self.window_size)
        stride_bytes = self._seconds_to_bytes(self.stride)

        transcription = {}
        spans = self._window_spans(len(audio), window_bytes, stride_bytes)
        for index, (start, end) in enumerate(spans):
            transcription[index] = {
                "start_sec": start // BYTES_PER_SECOND,
                # Round the end up so a trailing partial second still belongs
                # to a segment; full windows on whole seconds are unchanged.
                "end_sec": -(-end // BYTES_PER_SECOND),
                "text": self._recognize(audio[start:end]),
            }

        end_time = datetime.now()
        time_elapsed = end_time - start_time

        return {
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "elapsed_time": str(time_elapsed),
            "window_size": self.window_size,
            "stride": self.stride,
            "transcription": transcription  # dict of {index: {start_sec, end_sec, text}}
        }

    @staticmethod
    def _seconds_to_bytes(seconds):
        """Convert a duration to a PCM byte count aligned to whole samples."""
        count = int(seconds * BYTES_PER_SECOND)
        # 16-bit samples: never split a sample, never return an empty window.
        return max(2, count - count % 2)

    @staticmethod
    def _window_spans(total_len, window_bytes, stride_bytes):
        """Return the (start, end) byte offsets of every window to recognise.

        Full windows start every ``stride_bytes``. If they do not reach the
        end of the audio, one final span covering the last ``window_bytes``
        (clamped to the start of the audio) is appended.
        """
        if total_len <= 0 or window_bytes <= 0 or stride_bytes <= 0:
            return []

        spans = [
            (start, start + window_bytes)
            for start in range(0, total_len - window_bytes + 1, stride_bytes)
        ]
        covered = spans[-1][1] if spans else 0
        if covered < total_len:
            spans.append((max(0, total_len - window_bytes), total_len))
        return spans

    @staticmethod
    def _decode_to_pcm(filename):
        """Decode ``filename`` with ffmpeg into raw s16le mono PCM bytes."""
        ffmpeg_command = [
            "ffmpeg",
            "-nostdin",
            "-loglevel", "quiet",
            "-i", filename,
            "-ar", str(SAMPLE_RATE),
            "-ac", "1",
            "-f", "s16le",
            "-"
        ]
        completed = subprocess.run(ffmpeg_command, stdout=subprocess.PIPE)
        # "-loglevel quiet" hides ffmpeg's own errors, so a failed decode
        # would otherwise look exactly like a silent, empty recording.
        if completed.returncode != 0 and not completed.stdout:
            raise RuntimeError(
                f"ffmpeg failed to decode {filename} "
                f"(exit code {completed.returncode})"
            )
        return completed.stdout

    def _recognize(self, pcm):
        """Run one independent recognition pass over ``pcm`` and return its text."""
        rec = KaldiRecognizer(self.model, SAMPLE_RATE)
        rec.SetWords(True)
        rec.AcceptWaveform(pcm)
        # FinalResult flushes the decoder, so the text is the finished
        # hypothesis for the whole chunk rather than a provisional partial.
        return json.loads(rec.FinalResult()).get("text", "")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from app.services.transcribe import Transcriber
2+
import os
3+
import urllib.request
4+
import zipfile
5+
6+
def get_vosk_model(model_name="vosk-model-small-en-us-0.15", target_dir="models"):
    """Return the local directory of a Vosk model, downloading it if needed.

    The model archive is fetched from alphacephei.com and unpacked into a
    staging directory inside ``target_dir``. It is moved to its final
    location only after extraction has finished, so an interrupted run never
    leaves a half-extracted model that later calls would mistake for a
    complete one.

    Args:
        model_name: Name of the model; also the name of the zip's top-level
            directory and of the resulting directory under ``target_dir``.
        target_dir: Directory that holds downloaded models; created if missing.

    Returns:
        Path of the model directory (``target_dir/model_name``).

    Raises:
        RuntimeError: If the archive does not contain a ``model_name``
            directory.
    """
    import shutil
    import tempfile

    # Where the final unzipped model will live
    model_path = os.path.join(target_dir, model_name)

    os.makedirs(target_dir, exist_ok=True)

    if os.path.exists(model_path):
        print(f"Model already exists at {model_path}")
        return model_path

    url = f"https://alphacephei.com/vosk/models/{model_name}.zip"
    zip_path = os.path.join(target_dir, f"{model_name}.zip")
    # Staging lives inside target_dir so the final move stays on one
    # filesystem and is a plain rename.
    staging_dir = tempfile.mkdtemp(prefix=f"{model_name}.", dir=target_dir)

    try:
        print(f"Downloading {url} ...")
        urllib.request.urlretrieve(url, zip_path)

        print(f"Extracting {zip_path} ...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(staging_dir)

        extracted = os.path.join(staging_dir, model_name)
        if not os.path.isdir(extracted):
            raise RuntimeError(
                f"Archive {url} does not contain a '{model_name}' directory"
            )
        os.replace(extracted, model_path)
    finally:
        # Clean up on success and on failure alike.
        if os.path.exists(zip_path):
            os.remove(zip_path)
        shutil.rmtree(staging_dir, ignore_errors=True)

    print(f"Model ready at {model_path}")
    return model_path

backend/input_video.mp3

624 KB
Binary file not shown.

backend/package-lock.json

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)