-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathpipeline.py
More file actions
297 lines (246 loc) · 13.4 KB
/
pipeline.py
File metadata and controls
297 lines (246 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import os
import re
from typing import Dict, Optional
import gradio as gr
import torch
from librosa import load as libr_load
from soundfile import write as sf_write
from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
def max_ones_window(tensor: torch.Tensor, window_size: int = 30):
    """Find the contiguous window of length ``window_size`` with the largest sum.

    Intended for binary activity masks (find the densest run of ones), but
    works for any 1-D numeric tensor.

    Args:
        tensor: 1-D tensor to scan (any numeric/bool dtype, any device).
        window_size: window length in elements; must be <= ``tensor.numel()``.

    Returns:
        Tuple ``(max_start, max_sum, max_slice)``: the start index of the best
        window, its sum as a Python float, and the corresponding slice of the
        *original* tensor (original dtype preserved).
    """
    # conv1d requires floating-point inputs whose dtypes match the kernel, so
    # run the rolling sum in float32 regardless of the mask's original dtype.
    # (The original implementation crashed for bool/int/float16 inputs.)
    values = tensor.reshape(1, 1, -1).to(torch.float32)
    kernel = torch.ones(1, 1, window_size, dtype=torch.float32, device=tensor.device)
    # 'valid'-style convolution: entry k is sum(tensor[k:k + window_size]).
    window_sums = torch.nn.functional.conv1d(values, kernel).flatten()
    # argmax returns the first index on ties.
    max_start = int(torch.argmax(window_sums).item())
    max_sum = window_sums[max_start].item()
    # Slice the original tensor so callers get back the native dtype.
    max_slice = tensor[max_start:max_start + window_size]
    return max_start, max_sum, max_slice
class DiCoWPipeline(AutomaticSpeechRecognitionPipeline):
    """Diarization-Conditioned Whisper (DiCoW) ASR pipeline.

    Extends the Hugging Face seq2seq ASR pipeline: the input features are
    replicated once per diarized speaker and each copy is paired with a
    4-channel STNO mask (Silence / Target / Non-target / Overlap) so the model
    transcribes one speaker at a time. Optionally attaches per-speaker
    self-enrollment segments when the model config enables them.
    """

    def __init__(self, *args, diarization_pipeline, **kwargs):
        # `diarization_pipeline` is called with an audio path and must expose
        # `.labels()` and `.label_timeline()` on its result (pyannote-style
        # API, presumably — confirm against the object the caller passes in).
        super().__init__(*args, **kwargs)
        self.diarization_pipeline = diarization_pipeline
        # Force the parent pipeline down its seq2seq-Whisper code paths.
        self.type = "seq2seq_whisper"

    def get_diarization_mask(self, per_speaker_samples, audio_length):
        """Build a binary (num_speakers, audio_length) speaker-activity mask.

        ``per_speaker_samples[i]`` is an iterable of ``(start, end)`` segments
        for speaker i. Times are multiplied by 50 to become mask indices —
        assumes segment times in seconds at 50 mask frames/s (TODO confirm).
        """
        diarization_mask = torch.zeros(len(per_speaker_samples), audio_length)
        for i, speaker_samples in enumerate(per_speaker_samples):
            for start, end in speaker_samples:
                diarization_mask[i, round(start * 50):round(end * 50)] = 1
        return diarization_mask

    @staticmethod
    def get_stno_mask(diar_mask, s_index):
        """Derive the 4-channel STNO mask for speaker ``s_index``.

        Channels (stacked in this order):
          0: silence     — no speaker active
          1: target only — s_index active, nobody else
          2: non-target  — s_index silent while someone else speaks
          3: overlap     — s_index active together with someone else
        """
        non_target_mask = torch.ones((diar_mask.shape[0],), dtype=torch.bool)
        non_target_mask[s_index] = False
        # Product over speakers of (1 - activity) is 1 only where every row is 0.
        sil_frames = (1 - diar_mask).prod(axis=0)
        # 1 where none of the *other* speakers is active.
        anyone_else = (1 - diar_mask[non_target_mask]).prod(axis=0)
        target_spk = diar_mask[s_index] * anyone_else
        non_target_spk = (1 - diar_mask[s_index]) * (1 - anyone_else)
        # Target active but not alone: active frames minus target-only frames.
        overlapping_speech = diar_mask[s_index] - target_spk
        stno_mask = torch.stack([sil_frames, target_spk, non_target_spk, overlapping_speech], axis=0)
        return stno_mask

    def _process_enrollment_sample(self, samples, idx, stno_mask, original_stno_length):
        """Process enrollment sample with padding to match original size.

        Picks the 30 s window where the target-only channel (``stno_mask[1]``)
        has the most active frames and slices features / attention / STNO to it.
        NOTE(review): ``original_stno_length`` is currently unused here.
        """
        # Find best 30s enrollment window (30 s at 50 mask frames/s).
        enrollment_length = 30 * 50
        best_start, best_sum, _ = max_ones_window(stno_mask[1], window_size=30 * 50)
        # Extract enrollment features. Features/attention are indexed at 2x the
        # mask resolution (presumably 100 feature frames/s vs 50 mask
        # frames/s — confirm against the feature extractor).
        enrollment_features = samples['input_features'][idx][:, best_start * 2:best_start * 2 + enrollment_length * 2]
        enrollment_attention = samples['attention_mask'][idx][best_start * 2:best_start * 2 + enrollment_length * 2]
        enrollment_stno = stno_mask[:, best_start:best_start + enrollment_length]
        return enrollment_features, enrollment_attention, enrollment_stno

    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
        """Resample the input file, run diarization, and yield one batch with
        per-speaker replicated features, STNO masks and (when the model config
        has ``use_enrollments``) per-speaker self-enrollment segments.
        """
        if not isinstance(inputs, str):
            raise ValueError("For now input must be a string representing a path to an audio file")
        # Resample to 16 kHz mono and write the copy next to the input file;
        # both the parent preprocess and the diarizer consume this copy.
        input_dirname = os.path.dirname(inputs)
        resampled_path = f'{input_dirname}/resampled.wav'
        inp_aud, sr = libr_load(inputs, sr=16_000, mono=True)
        sf_write(resampled_path, inp_aud, sr, format='wav')
        inputs = resampled_path
        generator = super().preprocess(inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s)
        # NOTE(review): only the first item from the parent generator is used —
        # assumes un-chunked (single-sample) preprocessing; confirm.
        samples = next(generator)
        diarization_output = self.diarization_pipeline(inputs)
        per_speaker_samples = []
        for speaker in diarization_output.labels():
            per_speaker_samples.append(diarization_output.label_timeline(speaker))
        # Mask length is half the feature length (2 feature frames per mask frame).
        diarization_mask = self.get_diarization_mask(per_speaker_samples, samples['input_features'].shape[-1] // 2)
        stno_masks = []
        for i, speaker_samples in enumerate(per_speaker_samples):
            stno_mask = self.get_stno_mask(diarization_mask, i)
            stno_masks.append(stno_mask)
        samples['stno_mask'] = torch.stack(stno_masks, axis=0).to(samples['input_features'].device,
                                                                  dtype=samples['input_features'].dtype)
        # One copy of the features per speaker; attention mask is all-ones
        # (the whole resampled clip is treated as valid).
        samples['input_features'] = samples['input_features'].repeat(len(per_speaker_samples), 1, 1)
        samples['attention_mask'] = torch.ones(samples['input_features'].shape[0], samples['input_features'].shape[2],
                                               dtype=torch.bool, device=samples['input_features'].device)
        if "num_frames" in samples:
            del samples["num_frames"]
        if hasattr(self.model.config, "use_enrollments") and self.model.config.use_enrollments:
            if len(inp_aud) / sr <= 30.0:
                # Shortform regime (<= 30 s). NOTE(review): a previous comment
                # said "deactivate enrollments", but nothing is skipped here —
                # enrollments are still built below; only a warning is shown.
                gr.Info(
                    "If you are experiencing suboptimal performance, consider using a non–self-enrollment conditioned model (e.g., `BUT-FIT/DiCoW_v3_3`) for inputs shorter than 30s.")
            # Collect all samples (original + enrollment)
            all_input_features = []
            all_attention_masks = []
            all_stno_masks = []
            enroll_input_features = []
            enroll_attention_masks = []
            enroll_stno_masks = []
            original_stno_length = samples['stno_mask'].shape[-1]
            for idx, stno_mask in enumerate(samples['stno_mask']):
                # Add original sample
                all_input_features.append(samples['input_features'][idx])
                all_attention_masks.append(samples['attention_mask'][idx])
                all_stno_masks.append(stno_mask)
                # Add enrollment sample (best 30 s target-only window)
                enrollment_features, enrollment_attention, enrollment_stno = self._process_enrollment_sample(
                    samples, idx, stno_mask, original_stno_length
                )
                enroll_input_features.append(enrollment_features)
                enroll_attention_masks.append(enrollment_attention)
                enroll_stno_masks.append(enrollment_stno)
            # Stack all samples; enrollments ride along under their own key,
            # batched per speaker in the same order as the main batch.
            samples['input_features'] = torch.stack(all_input_features, dim=0)
            samples['attention_mask'] = torch.stack(all_attention_masks, dim=0)
            samples['stno_mask'] = torch.stack(all_stno_masks, dim=0)
            samples["enrollments"] = {
                "input_features": torch.stack(enroll_input_features, dim=0),
                "attention_mask": torch.stack(enroll_attention_masks, dim=0),
                "stno_mask": torch.stack(enroll_stno_masks, dim=0),
            }
        yield samples

    def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
        """Run generation; any key left in ``model_inputs`` after the pops
        below (e.g. ``stno_mask``, ``enrollments``) is forwarded straight to
        ``model.generate`` and echoed back in the output dict.
        """
        attention_mask = model_inputs.pop("attention_mask", None)
        stride = model_inputs.pop("stride", None)
        segment_size = model_inputs.pop("segment_size", None)
        is_last = model_inputs.pop("is_last")
        if stride is not None and segment_size is not None:
            raise ValueError("segment_size must be used only when stride is None")
        # Consume values so we can let extra information flow freely through
        # the pipeline (important for `partial` in microphone)
        if "input_features" in model_inputs:
            inputs = model_inputs.pop("input_features")
        elif "input_values" in model_inputs:
            inputs = model_inputs.pop("input_values")
        else:
            raise ValueError(
                "Seq2Seq speech recognition model requires either a "
                f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
            )
        # custom processing for Whisper timestamps and word-level timestamps
        if return_timestamps and self.type == "seq2seq_whisper":
            generate_kwargs["return_timestamps"] = return_timestamps
            if return_timestamps == "word":
                generate_kwargs["return_token_timestamps"] = True
                generate_kwargs["return_segments"] = True
        generate_kwargs["input_features"] = inputs
        tokens = self.model.generate(
            attention_mask=attention_mask,
            **generate_kwargs,
            **model_inputs,
        )
        # whisper longform generation stores timestamps in "segments"
        if return_timestamps == "word" and self.type == "seq2seq_whisper":
            if "segments" not in tokens:
                out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
            else:
                # Longform: concatenate per-segment timestamps per batch item.
                token_timestamps = [
                    torch.cat([segment["token_timestamps"] for segment in segment_list])
                    for segment_list in tokens["segments"]
                ]
                out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
        else:
            out = {"tokens": tokens}
        if self.type == "seq2seq_whisper":
            if stride is not None:
                out["stride"] = stride
        # Leftover keys flow through unchanged.
        extra = model_inputs
        return {"is_last": is_last, **out, **extra}

    @staticmethod
    def postprocess_text(input_string):
        """Clean up runs of adjacent ``<|t.tt|>`` timestamp tags in decoded text.

        Scans for chains of back-to-back tags (no text between them) and
        either collapses the whole chain (all tags equal) or keeps only the
        boundary tags (a segment-end followed by a segment-start with junk
        in-between). Strings with two or fewer tags are returned unchanged.
        """
        pattern = r"<\|([\d.]+)\|>"
        matches = re.finditer(pattern, input_string)
        # (timestamp value, match start offset, match end offset) per tag.
        timestamps = [(float(match.group(1)), match.start(), match.end()) for match in matches]
        if not timestamps or len(timestamps) <= 2:
            return input_string
        # The whole algorithm boils down to either removing the entire chain of timestamps - the case where all of them are the same (i.e. ...<a><a><a>... -> ......)
        # or removing all but the corner ones (i.e. <a><b><c><c><d> -> <a><d>) - the case where we have end and start timestamps and some rubbish in-between.
        # Sentinel ts values used below: -1 => replace tag with a space,
        # -2 => drop the tag entirely.
        processed_timestamps = []
        i = 0
        while i < len(timestamps):
            ts, st, et = timestamps[i]
            # Keep this tag unless it is the final one and directly adjacent
            # to the previously kept tag.
            if i < len(timestamps) - 1 or processed_timestamps[-1][-1] != st:
                processed_timestamps.append((ts, st, et))
            if i == len(timestamps) - 1:
                break
            j = i + 1
            nts, nst, net = timestamps[j]
            all_equal_ts = nts == ts
            prev_et = et
            # Walk the chain of directly adjacent tags starting at i.
            while nst - prev_et == 0:
                # Skip all but the last timestamp. If the last in the chain has the same TS as the processed_timestamps tail, pop processed_timestamps.
                # If not, append it while skipping all the previous ones.
                # In other words, keep appending (-2, X, X) as long as the next one is in the chain and then decide what to do with the last one if the next one is not in the chain.
                if j == len(timestamps) - 1:
                    # Chain reaches the final tag of the string.
                    if net == len(input_string) and prev_et != nst:
                        processed_timestamps.append((nts, nst, net))
                    j += 1
                    break
                else:
                    if timestamps[j + 1][1] - net == 0:
                        # Next tag continues the chain: drop this one.
                        processed_timestamps.append((-2, nst, net))
                    else:
                        # Chain ends at j: decide what survives.
                        if all_equal_ts:
                            # If there's a chain of eq timestamps at the beginning, we need to keep at least one.
                            if i != 0:
                                processed_timestamps[i] = (-1, st, et)
                            processed_timestamps.append((-2, nst, net))
                        else:
                            # If there's a chain of tags at the beginning with all ts not being equal, we need to keep the last one.
                            if i == 0:
                                processed_timestamps[i] = (-2, st, et)
                            processed_timestamps.append((nts, nst, net))
                        j += 1
                        break
                    j += 1
                    prev_et = net
                    nts, nst, net = timestamps[j]
                    all_equal_ts = all_equal_ts and nts == ts
            i = j
        # Re-emit the string, honoring the -1 / -2 sentinels set above.
        result = []
        prev_end = 0
        for i, (ts, st, et) in enumerate(processed_timestamps):
            result.append(f'{input_string[prev_end:st]}')
            if ts == -1:
                result.append(' ')
            elif ts == -2:
                # Empty string, so no need to append anything
                pass
            else:
                result.append(f'<|{ts:.2f}|>')
            prev_end = et
        return "".join(result)

    def postprocess(
        self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
    ):
        """Decode per-speaker token sequences and format a readable transcript.

        Only ``model_outputs[0]`` is consumed (one forward batch, one row per
        speaker). Returns the combined formatted text plus the raw decoded
        per-speaker strings.
        """
        per_spk_outputs = self.tokenizer.batch_decode(
            model_outputs[0]['tokens'], decode_with_timestamps=True, skip_special_tokens=True
        )
        formatted_lines = []
        for spk, text in enumerate(per_spk_outputs):
            processed_text = self.postprocess_text(text)
            # Split on each timestamp pair
            # This regex finds "<|start|>...<|end|>" pairs with everything inside
            segments = re.findall(r"(<\|\d+\.\d+\|>.*?<\|\d+\.\d+\|>)", processed_text)
            # Build the output for this speaker
            speaker_header = f"🗣️ Speaker {spk}:\n"
            speaker_body = "\n".join(segments)
            formatted_lines.append(f"{speaker_header}{speaker_body}")
        full_text = "\n\n".join(formatted_lines)
        return {"text": full_text, "per_spk_outputs": per_spk_outputs}