Commit 81ca1f2
Merge pull request #3977 from graham0824:dev/jiunkaiy/multi_attn
LiteRT-PiperOrigin-RevId: 829045132
2 parents b760e6f + 8f2fb00 commit 81ca1f2

File tree: 6 files changed (+616, −2 lines)


litert/vendors/qualcomm/core/transformation/BUILD

Lines changed: 5 additions & 0 deletions

```diff
@@ -28,6 +28,7 @@ cc_test(
         "//litert/vendors/qualcomm/core/builders:matmul_op_builder",
         "//litert/vendors/qualcomm/core/builders:quantize_op_builder",
         "//litert/vendors/qualcomm/core/builders:reshape_op_builder",
+        "//litert/vendors/qualcomm/core/builders:select_op_builder",
         "//litert/vendors/qualcomm/core/builders:slice_op_builder",
         "//litert/vendors/qualcomm/core/builders:softmax_op_builder",
         "//litert/vendors/qualcomm/core/builders:transpose_op_builder",
@@ -84,9 +85,13 @@ cc_library(
     deps = [
         "//litert/vendors/qualcomm/core:op_code",
         "//litert/vendors/qualcomm/core:tensor_pool",
+        "//litert/vendors/qualcomm/core/builders:cast_op_builder",
         "//litert/vendors/qualcomm/core/builders:concatenation_op_builder",
+        "//litert/vendors/qualcomm/core/builders:elementwise_op_builder",
+        "//litert/vendors/qualcomm/core/builders:pack_op_builder",
         "//litert/vendors/qualcomm/core/builders:reshape_op_builder",
         "//litert/vendors/qualcomm/core/builders:split_op_builder",
+        "//litert/vendors/qualcomm/core/builders:transpose_op_builder",
         "//litert/vendors/qualcomm/core/builders:unpack_op_builder",
         "//litert/vendors/qualcomm/core/utils:log",
         "//litert/vendors/qualcomm/core/wrappers:op_wrapper",
```

litert/vendors/qualcomm/core/transformation/README.md

Lines changed: 110 additions & 0 deletions

The first hunk inserts a `### Gemma3 (Prefill)` heading above the existing "Original MHA (Multi-head Attention) in Gemma3 prefill graph" diagram. The second hunk appends the following new section at the end of the file:
### Multi-head Attention with MaskedSoftmax via Select

The figure below shows multi-head attention with `MaskedSoftmax` implemented using `Select`. All MHAs with `Select` share the same structure and use the same `NotEqual` output as the input to their `Select` operations.

```mermaid
graph TB
    Q1 --> |"[B, T, N, H]"| Mul1["Mul"]
    K1 --> |"[B, KV_LEN, N, H]"| Mul2["Mul"]
    Mul1 --> |"[B, T, N, H]"| Transpose1["Transpose"]
    Mul2 --> |"[B, KV_LEN, N, H]"| Transpose2["Transpose"]
    Transpose1 --> |"[B, N, T, H]"| MatMul1["MatMul"]
    Transpose2 --> |"[B, N, H, KV_LEN]"| MatMul1
    MatMul1 --> |"[B, N, T, KV_LEN]"| Select
    Select --> |"[B, N, T, KV_LEN]"| Softmax
    Softmax --> |"[B, N, T, KV_LEN]"| MatMul2["MatMul
    (adj_y = true)"]
    V1 --> |"[B, KV_LEN, N, H]"| Transpose3["Transpose"]
    Transpose3 --> |"[B, N, H, KV_LEN]"| MatMul2
    MatMul2 --> |"[B, N, H, T]"| Transpose4["Transpose"]
    Transpose4 --> |"[B, T, N, H]"| Out
    Mask --> |"[B, T, KV_LEN]"| Reshape
    Reshape --> |"[B, 1, T, KV_LEN]"| NotEqual
    NotEqual --> |"[B, 1, T, KV_LEN]"| Select
    Q2 --> MHA1["MHAs with Select"]
    K2 --> MHA1
    V2 --> MHA1
    MHA1 --> Out2
    NotEqual --> MHA1
    NotEqual --> MHA2["..."]
    Q1@{ shape: text}
    K1@{ shape: text}
    V1@{ shape: text}
    Q2@{ shape: text}
    K2@{ shape: text}
    V2@{ shape: text}
    Mask@{ shape: text}
    Out@{ shape: sm-circ}
    Out2@{ shape: sm-circ}
    MHA2@{ shape: text }
```
190+
The `Reshape → NotEqual → Select` pattern can be optimized into the following sequence of operations. Note that the mask produced by the `Mul` operation is reused by all subsequent SHAs that apply `MaskedSoftmax` via `Add`.

```mermaid
graph TB
    Mask --> |"[B, T, KV_LEN]"| Equal
    Equal --> |"[B, T, KV_LEN]"| Cast
    Cast --> |"[B, T, KV_LEN]"| Mul
    Mul --> |"[B, T, KV_LEN]"| SHA1["Mask via Add"]
    SHA1@{ shape: text }
    Mask@{ shape: text}
```
200+
With `Mask via Add`, the overall multi-head attention can be transformed into the following SHAs (single-head attentions):

```mermaid
graph TB
    Q1 --> |"[B, T, N, H]"| Unpack1["Unpack"]
    Unpack1 --> |"[B, T, H]"| MHASHA1["SHA with Add"]
    Unpack1 --> MHASHA2["SHA with Add"]
    Unpack1 --> MHASHA3["..."]
    K1 --> |"[B, KV_LEN, N, H]"| Unpack2["Unpack"]
    Unpack2 --> |"[B, KV_LEN, H]"| MHASHA1
    Unpack2 --> MHASHA2
    Unpack2 --> MHASHA3
    V1 --> |"[B, KV_LEN, N, H]"| Unpack3["Unpack"]
    Unpack3 --> |"[B, KV_LEN, H]"| MHASHA1
    Unpack3 --> MHASHA2
    Unpack3 --> MHASHA3
    Mask --> MHASHA1
    Mask --> MHASHA2
    Mask --> MHASHA3
    Mask --> MHA1["MHAs with Add"]
    Q2 --> MHA1
    K2 --> MHA1
    V2 --> MHA1
    MHA1 --> Out2
    Mask --> MHA2["..."]
    MHASHA1 --> |"[B, T, H]"| Pack
    MHASHA2 --> Pack
    MHASHA3 --> Pack
    Pack --> |"[B, T, N, H]"| Out
    Q1@{ shape: text}
    K1@{ shape: text}
    V1@{ shape: text}
    Q2@{ shape: text}
    K2@{ shape: text}
    V2@{ shape: text}
    Mask(Mask via Add)@{ shape: text}
    Out@{ shape: sm-circ}
    Out2@{ shape: sm-circ}
    MHASHA3@{ shape: text }
    MHA2@{ shape: text }
```
240+
where each `SHA with Add` has the following structure:

```mermaid
graph TB
    Q_unpack --> |"[B, T, H]"| Mul1["Mul"]
    K_unpack --> |"[B, KV_LEN, H]"| Mul2["Mul"]
    Mul2 --> |"[B, KV_LEN, H]"| Transpose2["Transpose"]
    Mul1 --> |"[B, T, H]"| MatMul1["MatMul"]
    Transpose2 --> |"[B, H, KV_LEN]"| MatMul1
    MatMul1 --> |"[B, T, KV_LEN]"| Add
    Mask --> Add
    Add --> |"[B, T, KV_LEN]"| Softmax
    Softmax --> |"[B, T, KV_LEN]"| MatMul2["MatMul"]
    V_unpack --> |"[B, KV_LEN, H]"| MatMul2
    MatMul2 --> |"[B, T, H]"| Out
    Q_unpack@{ shape: text}
    K_unpack@{ shape: text}
    V_unpack@{ shape: text}
    Mask(Mask via Add)@{ shape: text}
    Out@{ shape: sm-circ}
```

litert/vendors/qualcomm/core/transformation/graph_to_graph.cc

Lines changed: 17 additions & 0 deletions

```diff
@@ -204,5 +204,22 @@ void GraphToGraphTransform(const G2GConfig g2g_option,
                                       QnnOpCode::kReshape};
   Transform(validate_op_config, ops, tensor_pool, fast_vlm_mha_prefill,
             OptimizeMHAFastVlmPrefill);
+
+  // Attention Optimization
+  const std::vector<QnnOpCode> attn = {
+      QnnOpCode::kElementWiseMultiply,
+      QnnOpCode::kElementWiseMultiply,
+      QnnOpCode::kTranspose,
+      QnnOpCode::kTranspose,
+      QnnOpCode::kMatMul,
+      QnnOpCode::kReshape,
+      QnnOpCode::kElementWiseBinary,
+      QnnOpCode::kElementWiseSelect,
+      QnnOpCode::kSoftmax,
+      QnnOpCode::kTranspose,
+      QnnOpCode::kMatMul,
+      QnnOpCode::kTranspose,
+  };
+  Transform(validate_op_config, ops, tensor_pool, attn, OptimizeMHAAttn);
 }
 }  // namespace qnn
```
