-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
72 lines (51 loc) · 2.07 KB
/
app.py
File metadata and controls
72 lines (51 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torchaudio
import IPython
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder: picks the top label per frame and
    collapses the result into a transcript string."""

    def __init__(self, labels, blank=0):
        """Store the label table and the blank index.

        Args:
            labels: Indexable collection mapping a label index to its string.
            blank (int): Label index that is dropped from the output
                (the CTC blank). Defaults to 0.
        """
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        # Highest-scoring label index for every frame.  Shape: [num_seq,]
        best_path = torch.argmax(emission, dim=-1)
        # Collapse runs of the same index.  This has to happen before blanks
        # are removed, otherwise repeats separated by a blank would be merged.
        best_path = torch.unique_consecutive(best_path, dim=-1)
        # Convert to plain Python ints in one go instead of iterating the
        # tensor, which would yield a 0-dim tensor per element.
        return "".join(
            self.labels[index]
            for index in best_path.tolist()
            if index != self.blank
        )
# --- Environment report and setup -------------------------------------------
print(torch.__version__)
print(torchaudio.__version__)

# Fix the RNG seed and pick the GPU when one is available.
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# --- Input audio and pretrained pipeline ------------------------------------
SPEECH_FILE = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
print("Sample Rate:", bundle.sample_rate)
print("Labels:", bundle.get_labels())

model = bundle.get_model().to(device)
print(model.__class__)

# NOTE(review): the Audio object is neither assigned nor explicitly displayed;
# presumably this line only has a visible effect as the last expression of a
# notebook cell -- confirm it is still wanted in a plain script.
IPython.display.Audio(SPEECH_FILE)

# --- Load the waveform and match the sample rate the bundle reports ---------
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)

if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

# --- Intermediate features: one image per returned feature tensor -----------
with torch.inference_mode():
    # Only the first element of the returned pair is used.
    features, _ = model.extract_features(waveform)

fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
    # Move to CPU before plotting; index 0 selects the first item of the
    # leading dimension (presumably the batch -- TODO confirm).
    ax[i].imshow(feats[0].cpu(), interpolation="nearest")
    ax[i].set_title(f"Feature from transformer layer {i+1}")
    ax[i].set_xlabel("Feature dimension")
    ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()

# --- Emissions and greedy decoding -----------------------------------------
with torch.inference_mode():
    emission, _ = model(waveform)

decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
# Emit the decoded text; without this the script's final result is discarded.
print(transcript)