Qualcomm AI Engine Direct - Optimize Attention block.
Summary:
- Enable MHA2SHA support for multiple attention blocks with shared masking operations.
- Add unit tests to validate the new MHA2SHA functionality.
Original MHA (Multi-head Attention) in Gemma3 prefill graph
```mermaid
graph TB
VSlice@{shape: text}
KSliceOut@{shape: text}
VSliceOut@{shape: text}
```

### Multi-head Attention with MaskedSoftmax via Select

The figure below shows multi-head attention with `MaskedSoftmax` implemented using `Select`. All MHAs with `Select` share the same structure and use the same `NotEqual` output as input to their `Select` operations.

```mermaid
graph TB
Q1 --> |"[B, T, N, H]"| Mul1["Mul"]
K1 --> |"[B, KV_LEN, N, H]"| Mul2["Mul"]
Mul1 --> |"[B, T, N, H]"| Transpose1["Transpose"]
Mul2 --> |"[B, KV_LEN, N, H]"| Transpose2["Transpose"]
Transpose1 --> |"[B, N, T, H]"| MatMul1["MatMul"]
Transpose2 --> |"[B, N, H, KV_LEN]"| MatMul1
MatMul1 --> |"[B, N, T, KV_LEN]"| Select
Select --> |"[B, N, T, KV_LEN]"| Softmax
Softmax --> |"[B, N, T, KV_LEN]"| MatMul2["MatMul
(adj_y = true)"]
V1 --> |"[B, KV_LEN, N, H]"| Transpose3["Transpose"]
Transpose3 --> |"[B, N, H, KV_LEN]"| MatMul2
MatMul2 --> |"[B, N, H, T]"| Transpose4["Transpose"]
Transpose4 --> |"[B, T, N, H]"| Out
Mask --> |"[B, T, KV_LEN]"| Reshape
Reshape --> |"[B, 1, T, KV_LEN]"| NotEqual
NotEqual --> |"[B, 1, T, KV_LEN]"| Select
Q2 --> MHA1["MHAs with Select"]
K2 --> MHA1["MHAs with Select"]
V2 --> MHA1["MHAs with Select"]
MHA1 --> Out2
NotEqual --> MHA1["MHAs with Select"]
NotEqual --> MHA2["..."]
Q1@{ shape: text}
K1@{ shape: text}
V1@{ shape: text}
Q2@{ shape: text}
K2@{ shape: text}
V2@{ shape: text}
Mask@{ shape: text}
Out@{ shape: sm-circ}
Out2@{ shape: sm-circ}
MHA2@{ shape: text }
```
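
For reference, a rough PyTorch sketch of the pattern above. The helper name `mha_with_select` and the per-operand scale factors are illustrative, and the value path is written as a plain transpose-then-`matmul` rather than the graph's `MatMul` with `adj_y = true`; it is a functional equivalent, not the exact op ordering.

```python
import torch

def mha_with_select(q, k, v, mask, q_scale=1.0, k_scale=1.0):
    """Functional sketch of the MHA-with-Select graph above.

    q: [B, T, N, H], k/v: [B, KV_LEN, N, H], mask: [B, T, KV_LEN]
    Returns: [B, T, N, H]
    """
    qt = (q * q_scale).permute(0, 2, 1, 3)   # Transpose1: [B, N, T, H]
    kt = (k * k_scale).permute(0, 2, 3, 1)   # Transpose2: [B, N, H, KV_LEN]
    logits = torch.matmul(qt, kt)            # MatMul1:    [B, N, T, KV_LEN]

    # Reshape -> NotEqual -> Select: keep logits where the mask is non-zero,
    # replace the masked positions with -inf before Softmax.
    keep = (mask != 0).unsqueeze(1)          # [B, 1, T, KV_LEN]
    logits = torch.where(keep, logits, torch.full_like(logits, float("-inf")))

    probs = torch.softmax(logits, dim=-1)    # [B, N, T, KV_LEN]
    vt = v.permute(0, 2, 1, 3)               # [B, N, KV_LEN, H]
    out = torch.matmul(probs, vt)            # MatMul2 equivalent: [B, N, T, H]
    return out.permute(0, 2, 1, 3)           # Transpose4: [B, T, N, H]
```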
The `Reshape → NotEqual → Select` pattern can be optimized into the `Equal → Cast → Mul` sequence shown below. Note that the additive mask produced by the `Mul` operation is reused by all subsequent SHAs that apply `MaskedSoftmax` via `Add`.

```mermaid
graph TB
Mask --> |"[B, T, KV_LEN]"| Equal
Equal --> |"[B, T, KV_LEN]"| Cast
Cast --> |"[B, T, KV_LEN]"| Mul
Mul --> |"[B, T, KV_LEN]"| SHA1["Mask via Add"]
SHA1@{ shape: text }
Mask@{ shape: text}
```
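
A minimal sketch of the rewritten mask path, assuming PyTorch semantics; the helper name `additive_mask` and the large negative constant (`-65504.0`, roughly the fp16 minimum) are illustrative choices, not necessarily what the pass emits.

```python
import torch

def additive_mask(mask, neg_value=-65504.0):
    """Equal -> Cast -> Mul: turn a keep/ignore mask into an additive mask.

    mask: [B, T, KV_LEN], non-zero where attention is allowed.
    Returns a [B, T, KV_LEN] tensor that is 0 at allowed positions and
    neg_value at masked positions.
    """
    ignore = torch.eq(mask, 0)       # Equal: True where the position is masked
    ignore = ignore.to(mask.dtype)   # Cast:  1.0 where masked, 0.0 elsewhere
    return ignore * neg_value        # Mul:   large negative bias for masked slots
```

Because this result no longer depends on the head dimension, a single instance of the subgraph can feed the `Add` inside every SHA.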
With `Mask via Add`, the overall multi-head attention can be transformed into the following SHAs:

```mermaid
graph TB
Q1 --> |"[B, T, N, H]"| Unpack1["Unpack"]
Unpack1 --> |"[B, T, H]"| MHASHA1["SHA with Add"]
Unpack1 --> MHASHA2["SHA with Add"]
Unpack1 --> MHASHA3["..."]
K1 --> |"[B, KV_LEN, N, H]"| Unpack2["Unpack"]
Unpack2 --> |"[B, KV_LEN, H]"| MHASHA1["SHA with Add"]
Unpack2 --> MHASHA2
Unpack2 --> MHASHA3
V1 --> |"[B, KV_LEN, N, H]"| Unpack3["Unpack"]
Unpack3 --> |"[B, KV_LEN, H]"| MHASHA1["SHA with Add"]
```
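
The per-head computation after the transform could look roughly like the PyTorch sketch below; `sha_with_add` and `mha_as_shas` are hypothetical helpers mirroring the `Unpack` and `SHA with Add` nodes above, with the per-head scaling details omitted.

```python
import torch

def sha_with_add(q_h, k_h, v_h, add_mask, q_scale=1.0, k_scale=1.0):
    """Single-head attention using the shared additive mask.

    q_h: [B, T, H], k_h/v_h: [B, KV_LEN, H], add_mask: [B, T, KV_LEN]
    """
    logits = torch.matmul(q_h * q_scale, (k_h * k_scale).transpose(1, 2))  # [B, T, KV_LEN]
    probs = torch.softmax(logits + add_mask, dim=-1)   # MaskedSoftmax via Add
    return torch.matmul(probs, v_h)                    # [B, T, H]

def mha_as_shas(q, k, v, add_mask):
    """Unpack Q/K/V along the head axis and run one SHA per head.

    q: [B, T, N, H], k/v: [B, KV_LEN, N, H] -> output [B, T, N, H]
    """
    heads = [
        sha_with_add(q_h, k_h, v_h, add_mask)
        for q_h, k_h, v_h in zip(q.unbind(dim=2), k.unbind(dim=2), v.unbind(dim=2))
    ]
    return torch.stack(heads, dim=2)
```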