-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: diarization 2.py
More file actions
343 lines (276 loc) · 12.5 KB
/
diarization 2.py
File metadata and controls
343 lines (276 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import json
import multiprocessing
import os
import re
import shutil
import subprocess
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
import torch
from langdetect import DetectorFactory
from langdetect import detect  # Language detection library
from pyannote.audio import Pipeline
from pydub import AudioSegment
# Function to detect the language of the text
def detect_language(text):
    """Return the language code langdetect assigns to *text* (e.g. ``"en"``).

    Returns ``"unknown"`` whenever langdetect raises, which it does for text
    it cannot classify (for example text with no usable features).
    """
    # langdetect is non-deterministic unless seeded: the same short or
    # ambiguous text can be classified differently from call to call. Pin a
    # seed once (without overriding one that was already chosen) so the
    # English / non-English masking downstream is reproducible.
    if DetectorFactory.seed is None:
        DetectorFactory.seed = 0
    try:
        return detect(text)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # are no longer swallowed here.
        return "unknown"
# Function to download audio from URL
def download_audio(url, download_path, timeout=(10, 60)):
    """Stream the resource at *url* into the local file *download_path*.

    Args:
        url: HTTP(S) URL of the audio file.
        download_path: Local file to write; overwritten if it already exists.
        timeout: ``requests`` timeout, either a ``(connect, read)`` tuple or a
            single number of seconds. The read timeout bounds the wait between
            received bytes, not the whole transfer, so long files still finish.
            Without a timeout a stalled server would block the worker forever.

    Returns:
        *download_path* on success, ``None`` on any failure. The error is
        printed, and a partially written file is removed.
    """
    file_opened = False
    try:
        # ``with`` closes the streamed response and releases its connection,
        # even when writing to disk fails part-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(download_path, "wb") as f:
                file_opened = True
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return download_path
    except Exception as e:
        print(f"Error downloading audio file: {e}")
        if file_opened:
            # Don't leave a truncated file behind for a later step to pick up.
            try:
                os.remove(download_path)
            except OSError:
                pass
        return None
# Function to transcribe audio file using Whisper
def transcribe_audio(file_path, model='medium', word_timestamps=True, output_dir=None):
    """Run the ``whisper`` CLI on *file_path* and return ``(stdout, error)``.

    The caller parses the segment lines (``[MM:SS.mmm --> MM:SS.mmm] text``)
    out of whisper's console output. The JSON file requested via
    ``--output_format`` is only a by-product and is not read here.

    Args:
        file_path: Audio file to transcribe.
        model: Whisper model name passed to ``--model``.
        word_timestamps: Passed to ``--word_timestamps`` as ``"True"``/``"False"``.
        output_dir: Directory where whisper writes its JSON output. Whisper's
            own default is the current working directory, where those files
            piled up and were never cleaned. This defaults to the directory
            containing *file_path*, so the JSON is removed with the segment.

    Returns:
        ``(stdout, "")`` on success. On failure, ``("", message)``, where
        *message* is never empty: either whisper exited non-zero, or it could
        not be started at all.
    """
    if output_dir is None:
        output_dir = os.path.dirname(os.path.abspath(str(file_path)))
    command = [
        "whisper", str(file_path),
        "--model", model,
        "--word_timestamps", str(word_timestamps),
        "--output_format", "json",
        "--output_dir", str(output_dir),
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as e:
        # E.g. the whisper executable is not installed or not on PATH.
        # Reported through the normal error channel instead of raising.
        return "", f"Could not run whisper: {e}"
    if result.returncode != 0:
        # The caller treats an empty error string as success, so never return one here.
        return "", result.stderr or f"whisper exited with code {result.returncode}"
    return result.stdout, ""
# Function to parse the transcription text into a list of dictionaries
# Matches whisper's console segment lines, e.g. "[00:01.000 --> 00:04.500]  Hello".
# Whisper's timestamp formatter adds a leading hours field ("01:02:03.000")
# once a timestamp passes one hour, so that form is accepted too.
_SEGMENT_LINE_RE = re.compile(
    r"\[((?:\d+:)?\d{2}:\d{2}\.\d{3}) --> ((?:\d+:)?\d{2}:\d{2}\.\d{3})\] +(.+)"
)


def _fold_hours_into_minutes(timestamp):
    """Rewrite "HH:MM:SS.mmm" as "MM:SS.mmm" (minutes may exceed 59); pass "MM:SS.mmm" through."""
    parts = timestamp.split(":")
    if len(parts) == 3:
        hours, minutes, seconds = parts
        return f"{int(hours) * 60 + int(minutes):02}:{seconds}"
    return timestamp


def parse_transcription(transcription):
    """Extract timed segments from whisper's console output.

    Args:
        transcription: Raw stdout of the whisper CLI. Lines that are not
            segment lines (progress, language detection, ...) are ignored.

    Returns:
        A list of ``{"start_time", "end_time", "text"}`` dicts in output order.
        Times are always two-field ``"MM:SS.mmm"`` strings (hours are folded
        into minutes). Text has surrounding whitespace stripped, and segments
        whose text is blank are skipped.
    """
    result = []
    for start_time, end_time, text in _SEGMENT_LINE_RE.findall(transcription):
        text = text.strip()  # whisper prefixes each segment's text with a space
        if not text:
            continue
        result.append({
            "start_time": _fold_hours_into_minutes(start_time),
            "end_time": _fold_hours_into_minutes(end_time),
            "text": text,
        })
    return result
# Function to perform speaker diarization and save speaker segments
def perform_diarization(audio_file_path, output_dir, file_id, auth_token=None):
    """Split an audio file into one WAV clip per speaker turn using pyannote.

    Args:
        audio_file_path: Input audio. Non-WAV input is first converted to a WAV
            file next to it, and diarization runs on that WAV.
        output_dir: Directory the per-turn clips are written to.
        file_id: Prefix for the clip file names.
        auth_token: Hugging Face token passed to ``Pipeline.from_pretrained``.
            Defaults to the ``HF_AUTH_TOKEN`` environment variable, falling
            back to the previous placeholder string so nothing is hard-coded.

    Returns:
        ``(segments, audio, wav_file_path)``:

        - *segments* is a list of ``(clip_path, start_seconds, end_seconds)``.
        - *audio* is the loaded ``AudioSegment`` that was diarized.
        - *wav_file_path* is the converted WAV, or ``None`` if the input was
          already WAV.

        On any failure the error is printed and ``(None, None, None)`` is
        returned. It is always a 3-tuple because callers unpack three values;
        the old 2-tuple made that unpacking itself raise.
    """
    if auth_token is None:
        auth_token = os.environ.get("HF_AUTH_TOKEN", "YOUR_HF_AUTH_TOKEN")
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=auth_token)
    except Exception as e:
        print(f"Error loading pyannote pipeline: {e}")
        return None, None, None
    if pipeline is None:
        # NOTE(review): some pyannote versions return None instead of raising
        # when the model cannot be fetched (e.g. missing/invalid token).
        print("Error loading pyannote pipeline: from_pretrained returned None")
        return None, None, None
    wav_file_path = None
    try:
        if not audio_file_path.lower().endswith('.wav'):
            audio = AudioSegment.from_file(audio_file_path)
            # splitext only touches the final extension. rsplit('.') would
            # cut at a dot in a directory name when the file has no extension.
            wav_file_path = os.path.splitext(audio_file_path)[0] + '.wav'
            audio.export(wav_file_path, format='wav')
            audio_file_path = wav_file_path
        diarization = pipeline(audio_file_path)
        audio = AudioSegment.from_file(audio_file_path)
        segments = []
        for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
            # Turn boundaries are in seconds; pydub slices in milliseconds.
            clip = audio[turn.start * 1000: turn.end * 1000]
            clip_path = os.path.join(output_dir, f"{file_id}_speaker_{speaker}_{i}_{turn.start:.2f}_{turn.end:.2f}.wav")
            clip.export(clip_path, format="wav")
            segments.append((clip_path, turn.start, turn.end))
        return segments, audio, wav_file_path
    except Exception as e:
        print(f"An error occurred: {e}")
        return None, None, None
# Function to adjust transcription timestamps relative to the full audio file
def _timestamp_to_seconds(ts):
    """Parse "MM:SS.mmm" (or "HH:MM:SS.mmm") into seconds as a float."""
    total = 0.0
    for part in ts.split(":"):
        total = total * 60 + float(part)
    return total


def _seconds_to_timestamp(seconds):
    """Format seconds as zero-padded "MM:SS.mmm"; minutes may exceed 59."""
    # Round to whole milliseconds before splitting, so e.g. 59.9996 s becomes
    # "01:00.000" instead of "00:60.000".
    total_ms = int(round(seconds * 1000))
    minutes, ms = divmod(total_ms, 60_000)
    return f"{minutes:02}:{ms // 1000:02}.{ms % 1000:03}"


def adjust_timestamps(transcriptions, segment_start, segment_end, full_audio_length):
    """Shift clip-relative segment times onto the timeline of the full recording.

    Args:
        transcriptions: Dicts with "start_time", "end_time" ("MM:SS.mmm",
            relative to the clip start) and "text".
        segment_start: Offset of the clip within the full audio, in seconds.
        segment_end: End of the clip within the full audio, in seconds.
            Adjusted times are clamped to it, because whisper can report
            times past the end of the audio it was given.
        full_audio_length: Length of the full audio in seconds; also an upper
            bound for adjusted times.

    Returns:
        A new list of dicts with the same keys. Times are formatted as
        zero-padded "MM:SS.mmm" (e.g. "00:05.000", the same shape whisper
        prints). The old ``:05.3f`` format produced "00:5.000" here.
    """
    upper_bound = min(segment_end, full_audio_length)
    adjusted = []
    for entry in transcriptions:
        start = min(segment_start + _timestamp_to_seconds(entry['start_time']), upper_bound)
        end = min(segment_start + _timestamp_to_seconds(entry['end_time']), upper_bound)
        adjusted.append({
            "start_time": _seconds_to_timestamp(start),
            "end_time": _seconds_to_timestamp(end),
            "text": entry['text'],
        })
    return adjusted
# Function to merge parsed transcriptions
def merge_transcriptions(transcriptions, keep_language='en', mask_char='#'):
    """Mask the text of every entry not detected as *keep_language*.

    Despite its name, this function does not merge entries: the output has
    exactly one entry per input entry, in the same order.

    Args:
        transcriptions: Dicts with "start_time", "end_time" and "text" keys.
        keep_language: Language code (as returned by ``detect_language``)
            whose text is kept verbatim. Defaults to English.
        mask_char: Character repeated ``len(text)`` times to replace the text
            of every other entry, including ones detected as "unknown".

    Returns:
        A new list. Kept entries are copies of the input dicts. Masked entries
        keep only the timing keys plus the masked text. The input is not
        modified.
    """
    merged = []
    for entry in transcriptions:
        text = entry['text']
        if detect_language(text) == keep_language:
            merged.append(dict(entry))
        else:
            # Other-language text is hidden, but its timing is preserved.
            merged.append({
                "start_time": entry['start_time'],
                "end_time": entry['end_time'],
                "text": mask_char * len(text),
            })
    return merged
# Function to fetch mp3 url from the api
def fetch_mp3_url(api_url, timeout=30):
    """Ask the queue API for the next MP3 URL to process.

    Args:
        api_url: Endpoint whose JSON response carries the URL under "mp3url".
        timeout: Seconds ``requests`` may wait to connect / between reads.
            ``requests`` has no timeout by default, so a stalled API would
            otherwise block the worker thread indefinitely.

    Returns:
        The "mp3url" value (``None`` if the key is absent). Any request or
        decoding error is printed and also yields ``None``.
    """
    try:
        response = requests.get(api_url, timeout=timeout)
        response.raise_for_status()
        return response.json().get('mp3url')
    except Exception as e:
        print(f"Error fetching audio URL from API: {e}")
        return None
# Function to send transcription result to API
def send_transcription(api_url, mp3_url, transcription, timeout=30):
    """PATCH *transcription* for *mp3_url* to the results API.

    The transcription is JSON-encoded and sent, together with the URL, as
    query-string parameters. That is the contract this endpoint is called with.

    NOTE(review): long transcriptions produce very long URLs, which servers and
    proxies commonly reject (HTTP 414). Moving the payload into the request
    body would need a matching server-side change.

    Args:
        api_url: Results endpoint.
        mp3_url: The audio URL the transcription belongs to.
        transcription: JSON-serializable transcription (list of segment dicts).
        timeout: Seconds ``requests`` may wait to connect / between reads.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: if the API answers with anything other than 200.
            It is an ``Exception`` subclass, so existing handlers still match.
        requests.RequestException: on connection failure or timeout.
    """
    payload = {
        'mp3_url': mp3_url,
        'speech_to_text': json.dumps(transcription),
    }
    response = requests.patch(api_url, params=payload, timeout=timeout)
    if response.status_code != 200:
        raise requests.HTTPError(f"Failed to send transcription: {response.status_code}", response=response)
    return response.json()
# Function to handle the processing of each MP3 URL
def process_mp3_url(mp3_url, set_transcription_api):
    """Download, diarize, transcribe and upload the transcription for one MP3.

    All intermediate files go into a fresh per-call directory under
    "temp_4". That directory is removed afterwards, whether or not
    processing succeeded. Errors are printed, not raised.

    NOTE(review): on any early return or error, the server keeps the
    "processing" placeholder that initiate_s2t set for this URL. Consider
    reporting a failure status instead.
    """
    base_directory = "temp_4"  # Replace with your desired base directory
    unique_subdir = None  # Stays None until created; the cleanup below checks it.
    try:
        # exist_ok: several worker threads run this concurrently, and a separate
        # exists()/makedirs() check lets two of them race to create the directory.
        os.makedirs(base_directory, exist_ok=True)
        # Create a unique subdirectory for each audio file
        file_id = str(uuid.uuid4())[:8]
        unique_subdir = os.path.join(base_directory, file_id)
        os.makedirs(unique_subdir)
        audio_file_path = os.path.join(unique_subdir, "downloaded_audio.mp3")

        download_path = download_audio(mp3_url, audio_file_path)
        if not download_path:
            return

        # Only the first two items (segments, full audio) are needed here.
        diarization_result = perform_diarization(download_path, unique_subdir, file_id)
        segments, full_audio = diarization_result[0], diarization_result[1]
        if segments is None:
            return
        full_audio_length = len(full_audio) / 1000.0  # pydub lengths are in ms

        combined_transcription = []
        for segment_path, segment_start, segment_end in segments:
            transcription, error = transcribe_audio(segment_path)
            if error:
                print(f"Transcription error for {segment_path}: {error}")
                continue
            parsed_transcription = parse_transcription(transcription)
            combined_transcription.extend(
                adjust_timestamps(parsed_transcription, segment_start, segment_end, full_audio_length)
            )

        merged_transcription = merge_transcriptions(combined_transcription)
        print(f"Merged Transcription for {mp3_url}: {merged_transcription}")
        response = send_transcription(set_transcription_api, mp3_url, merged_transcription)
        print(f"API Response: {response}")
    except Exception as e:
        print(f"An error occurred while processing {mp3_url}: {e}")
    finally:
        # Previously this referenced unique_subdir even when it was never
        # assigned, so the cleanup handler itself raised NameError.
        if unique_subdir is not None:
            try:
                # rmtree also copes with nested entries, unlike listdir()+remove().
                shutil.rmtree(unique_subdir)
            except OSError as e:
                print(f"Error cleaning up directory {unique_subdir}: {e}")
# function to initiate the process
# Serializes the "fetch next URL, then mark it as processing" claim step
# across the worker threads of this process.
_CLAIM_LOCK = threading.Lock()


def initiate_s2t(get_url_api, set_transcription_api, idle_sleep=5):
    """Claim the next pending MP3 from the API and process it.

    Claiming means fetching a URL from *get_url_api* and then immediately
    PATCHing a placeholder "processing" transcription for it. Both calls run
    under a process-wide lock because ``main`` runs several of these at once.
    Without the lock, two threads can fetch the same URL before either has
    marked it.

    NOTE(review): this presumes the API stops handing out a URL once its text
    is "processing". Confirm that server-side. The lock does not protect
    against other processes.

    If nothing is pending, prints a message and sleeps *idle_sleep* seconds.
    Exceptions from marking the URL propagate to the caller.
    """
    with _CLAIM_LOCK:
        mp3_url = fetch_mp3_url(get_url_api)
        if mp3_url:
            processing_placeholder = [{'start_time': '00:00', 'end_time': '00:00', 'text': 'processing'}]
            send_transcription(set_transcription_api, mp3_url, processing_placeholder)
    if not mp3_url:
        # Sleep outside the lock so other workers can keep polling.
        print(f"No new MP3 URL found. Sleeping for {idle_sleep} seconds.")
        time.sleep(idle_sleep)
        return
    process_mp3_url(mp3_url, set_transcription_api)
# Main function to handle downloading, diarizing, transcribing, and sending the transcription in parallel
def main(num_threads=2):
    """Poll the API forever, processing up to *num_threads* MP3s concurrently.

    Each round submits *num_threads* ``initiate_s2t`` calls and waits for all
    of them before starting the next round. Worker exceptions are printed,
    followed by a 5 s back-off, and never stop the loop.

    Args:
        num_threads: Concurrent workers per round (at least 1). The default of
            2 matches the previous hard-coded override of the CPU count;
            presumably GPU/host memory rather than cores is the limit, since
            every job loads the pyannote and whisper models.
    """
    get_url_api = "https://tabsons-fastapi-g55rbik64q-el.a.run.app/get_first_mp3_url/"
    set_transcription_api = "https://tabsons-fastapi-g55rbik64q-el.a.run.app/set_speech_to_text"
    print("\nCPU Count:" + str(multiprocessing.cpu_count()))
    num_threads = max(1, int(num_threads))
    while True:
        try:
            # Running the tasks in parallel
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                futures = [
                    executor.submit(initiate_s2t, get_url_api, set_transcription_api)
                    for _ in range(num_threads)
                ]
                for future in as_completed(futures):
                    try:
                        future.result()
                    except torch.cuda.OutOfMemoryError as e:
                        print(f"CUDA Out of Memory: {e}")
                        torch.cuda.empty_cache()  # Clear cache if OOM error occurs
                        time.sleep(5)
                    except Exception as e:
                        # Handled per future, so one failure doesn't skip checking the rest.
                        print(f"An error occurred: {e}")
                        time.sleep(5)
                    else:
                        print("\nProcessing completed.")
        except Exception as e:
            # Failures outside the workers, e.g. creating the pool or submitting tasks.
            print(f"An error occurred: {e}")
            time.sleep(5)


if __name__ == "__main__":
    main()