-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathhailo-compatibility-changes.patch
More file actions
174 lines (157 loc) · 6.79 KB
/
hailo-compatibility-changes.patch
File metadata and controls
174 lines (157 loc) · 6.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
diff --git a/whisper/__init__.py b/whisper/__init__.py
index e210718..f01fe07 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -152,7 +152,14 @@ def load_model(
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
- model.load_state_dict(checkpoint["model_state_dict"])
+ convs = [x for x in checkpoint["model_state_dict"].keys() if 'conv' in x and 'weight' in x]
+ for conv in convs:
+ layer = checkpoint["model_state_dict"][conv]
+ checkpoint["model_state_dict"][conv] = torch.unsqueeze(layer, dim=2)
+
+ #import pdb;pdb.set_trace()
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
+
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
diff --git a/whisper/audio.py b/whisper/audio.py
index 826250f..0e153a1 100644
--- a/whisper/audio.py
+++ b/whisper/audio.py
@@ -13,7 +13,7 @@ from .utils import exact_div
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
-CHUNK_LENGTH = 30
+CHUNK_LENGTH = 10
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
diff --git a/whisper/model.py b/whisper/model.py
index e537447..fa0f112 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -20,7 +20,7 @@ try:
except (ImportError, RuntimeError, OSError):
scaled_dot_product_attention = None
SDPA_AVAILABLE = False
-
+SDPA_AVAILABLE = False # forcing it to avoid issues
@dataclass
class ModelDimensions:
@@ -50,7 +50,7 @@ class Linear(nn.Linear):
)
-class Conv1d(nn.Conv1d):
+class Conv2d(nn.Conv2d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
@@ -118,7 +118,7 @@ class MultiHeadAttention(nn.Module):
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
- v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * 1.0
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
a = scaled_dot_product_attention(
@@ -176,9 +176,9 @@ class AudioEncoder(nn.Module):
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
- self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
- self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
- self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+ self.conv1 = Conv2d(n_mels, n_state, kernel_size=(1, 3), padding=(0, 1))
+ self.conv2 = Conv2d(n_state, n_state, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))
+ self.positional_embedding = sinusoids(n_ctx, n_state)
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
@@ -192,8 +192,10 @@ class AudioEncoder(nn.Module):
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
- x = x.permute(0, 2, 1)
+ x = x.flatten(2).permute(0, 2, 1)
+ # print(x.shape[1:])
+ # print(self.positional_embedding.shape)
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
@@ -211,7 +213,8 @@ class TextDecoder(nn.Module):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
- self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+ # self.positional_embedding = nn.Parameter(torch.empty(448, n_state)) # TODO: 448 is the whisper tiny context window size
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) # TODO: 448 is the whisper tiny context window size
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
@@ -220,10 +223,27 @@ class TextDecoder(nn.Module):
]
)
self.ln = LayerNorm(n_state)
-
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
+ def split_conv2d_method(self, x): # method to split the final Matmul into 4 smaller Matmuls
+ vocab_size = self.token_embedding.weight.shape[0]
+ chunk_size = vocab_size // 4
+ logit_chunks = []
+
+ W = self.token_embedding.weight.to(x.dtype)
+
+ for i in range(4):
+ start = i * chunk_size
+ end = (i + 1) * chunk_size if i < 3 else vocab_size # handle remainder
+ W_chunk = W[start:end] # shape: (chunk_size, 384)
+ logits_chunk = torch.matmul(x, W_chunk.T) # shape: (1, 32, chunk_size)
+ logit_chunks.append(logits_chunk)
+
+ logits = torch.cat(logit_chunks, dim=-1) # shape: (1, 32, 51865)
+
+ return logits
+
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
@@ -231,20 +251,32 @@ class TextDecoder(nn.Module):
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
+
+ # Interpolate xa to match the original Whisper encoder's output length
+ #xa = F.interpolate(xa.permute(0, 2, 1), size=1500, mode='linear').permute(0, 2, 1) # improves accuracy a bit
+
+
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
+ x = x.unsqueeze(1)
+ x = x.transpose(1, -1)
+ x = x.flatten(2).permute(0, 2, 1)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
- x = self.ln(x)
- logits = (
- x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
- ).float()
+ x = self.ln(x) # Shape: 1, 32, 384
+ #logits = (
+ # x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) # original output Matmul, too big to be supported on Hailo-8
+ #).float()
+
+ logits = self.split_conv2d_method(x) # Shape: 1, 32, 51865
+
+ # print(logits.shape)
return logits
@@ -255,7 +287,7 @@ class Whisper(nn.Module):
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
- self.dims.n_audio_ctx,
+ self.dims.n_audio_ctx // 3, # scaled down the positinal encoding
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,