
Commit 3fd1d26

Qualcomm AI Engine Direct - Optimize Embedding Gemma.
Summary:
- Add MHA to SHA optimization.
- Add tests for the optimization.
1 parent 55e6099 commit 3fd1d26

File tree

6 files changed: +550 -1 lines changed
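
Before the per-file diffs, a word on what the optimization does: the pass matches a ten-op multi-head attention (MHA) subgraph (Mul → Transpose → Reshape → MatMul → Add → Softmax → MatMul → Reshape → Transpose → Reshape), splits the activations and the attention mask along the head axis, rebuilds one single-head attention (SHA) chain per head, and concatenates the per-head results back to the original output shape. The sketch below is a minimal, self-contained illustration of that split / per-head / concat skeleton on toy dense tensors; the `Tensor` alias and helper names are hypothetical stand-ins for the `qnn::TensorWrapper` / `OpWrapper` machinery used in the commit.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

using Tensor = std::vector<float>;  // flattened [seq, heads * head_dim]

// Split along the head axis: one [seq, head_dim] slice per head.
std::vector<Tensor> SplitHeads(const Tensor& mha, std::size_t seq,
                               std::size_t heads, std::size_t head_dim) {
  std::vector<Tensor> out(heads, Tensor(seq * head_dim));
  for (std::size_t s = 0; s < seq; ++s)
    for (std::size_t h = 0; h < heads; ++h)
      for (std::size_t d = 0; d < head_dim; ++d)
        out[h][s * head_dim + d] = mha[(s * heads + h) * head_dim + d];
  return out;
}

// Concatenate per-head outputs back to [seq, heads * head_dim].
Tensor ConcatHeads(const std::vector<Tensor>& heads_out, std::size_t seq,
                   std::size_t head_dim) {
  const std::size_t heads = heads_out.size();
  Tensor out(seq * heads * head_dim);
  for (std::size_t s = 0; s < seq; ++s)
    for (std::size_t h = 0; h < heads; ++h)
      for (std::size_t d = 0; d < head_dim; ++d)
        out[(s * heads + h) * head_dim + d] = heads_out[h][s * head_dim + d];
  return out;
}

int main() {
  const std::size_t seq = 2, heads = 4, head_dim = 3;
  Tensor mha(seq * heads * head_dim);
  for (std::size_t i = 0; i < mha.size(); ++i) mha[i] = static_cast<float>(i);

  // Each head becomes an independent single-head computation; the real pass
  // runs Mul -> MatMul -> Add(mask) -> Softmax -> MatMul per head here.
  auto per_head = SplitHeads(mha, seq, heads, head_dim);
  for (auto& head : per_head)
    for (auto& x : head) x *= 0.5f;  // stand-in for the per-head SHA chain

  Tensor merged = ConcatHeads(per_head, seq, head_dim);
  std::cout << "merged[5] = " << merged[5] << "\n";  // prints 2.5 (= 0.5 * 5)
  return 0;
}
```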

litert/vendors/qualcomm/core/transformation/BUILD

Lines changed: 50 additions & 0 deletions
```diff
@@ -48,6 +48,7 @@ cc_library(
         "nobuilder",
     ],
     deps = [
+        ":embedding_gemma",
         ":mask",
         ":matmul_convert",
         ":mha_to_sha",
@@ -110,3 +111,52 @@ cc_library(
         "@qairt//:qnn_lib_headers",
     ],
 )
+
+cc_library(
+    name = "embedding_gemma",
+    srcs = ["embedding_gemma.cc"],
+    hdrs = ["embedding_gemma.h"],
+    tags = [
+        # Don't build/test in OS until qnn is available.
+        "nobuilder",
+    ],
+    deps = [
+        "//litert/vendors/qualcomm/core:tensor_pool",
+        "//litert/vendors/qualcomm/core/builders:concatenation_op_builder",
+        "//litert/vendors/qualcomm/core/builders:reshape_op_builder",
+        "//litert/vendors/qualcomm/core/builders:split_op_builder",
+        "//litert/vendors/qualcomm/core/wrappers:op_wrapper",
+        "//litert/vendors/qualcomm/core/wrappers:tensor_wrapper",
+        "@com_google_absl//absl/strings",
+        "@qairt//:qnn_lib_headers",
+    ],
+)
+
+cc_test(
+    name = "embedding_gemma_test",
+    srcs = [
+        "embedding_gemma_test.cc",
+    ],
+    tags = [
+        # Don't build/test in OS until qnn is available.
+        "nobuilder",
+    ],
+    deps = [
+        ":graph_to_graph",
+        "//litert/vendors/qualcomm/core:op_code",
+        "//litert/vendors/qualcomm/core:tensor_pool",
+        "//litert/vendors/qualcomm/core/builders:concatenation_op_builder",
+        "//litert/vendors/qualcomm/core/builders:elementwise_op_builder",
+        "//litert/vendors/qualcomm/core/builders:matmul_op_builder",
+        "//litert/vendors/qualcomm/core/builders:op_builder",
+        "//litert/vendors/qualcomm/core/builders:reshape_op_builder",
+        "//litert/vendors/qualcomm/core/builders:softmax_op_builder",
+        "//litert/vendors/qualcomm/core/builders:split_op_builder",
+        "//litert/vendors/qualcomm/core/builders:transpose_op_builder",
+        "//litert/vendors/qualcomm/core/wrappers:op_wrapper",
+        "//litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper",
+        "//litert/vendors/qualcomm/core/wrappers:tensor_wrapper",
+        "@com_google_googletest//:gtest_main",
+        "@qairt//:qnn_lib_headers",
+    ],
+)
```
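
Assuming a typical Bazel setup (and noting that the `nobuilder` tag keeps both targets out of open-source CI until qnn is available), the new test target would presumably be invoked as:

```
bazel test //litert/vendors/qualcomm/core/transformation:embedding_gemma_test
```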
litert/vendors/qualcomm/core/transformation/embedding_gemma.cc

Lines changed: 252 additions & 0 deletions (new file)

```cpp
// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#include "litert/vendors/qualcomm/core/transformation/embedding_gemma.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <vector>

#include "QnnTypes.h"  // from @qairt
#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "litert/vendors/qualcomm/core/builders/concatenation_op_builder.h"
#include "litert/vendors/qualcomm/core/builders/reshape_op_builder.h"
#include "litert/vendors/qualcomm/core/builders/split_op_builder.h"
#include "litert/vendors/qualcomm/core/tensor_pool.h"
#include "litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
#include "litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"

namespace {
// Offsets of each op in the matched MHA pattern, relative to start_index.
constexpr size_t kMulIndex = 0;
constexpr size_t kTransposeIndex = 1;
constexpr size_t kReshapeIndex = 2;
constexpr size_t kMatMulIndex = 3;
constexpr size_t kAddIndex = 4;
constexpr size_t kSoftmaxIndex = 5;
constexpr size_t kMatMul2Index = 6;
constexpr size_t kReshape2Index = 7;
constexpr size_t kTranspose2Index = 8;
constexpr size_t kReshape3Index = 9;
}  // namespace

namespace qnn {

// TODO (chunhsue-qti): merge similar utility functions
OpWrapper& EmplaceOpWithIO(
    std::vector<OpWrapper>& new_ops, const OpWrapper& source_op,
    const std::vector<std::optional<qnn::TensorWrapperRef>>& inputs,
    const std::vector<std::optional<qnn::TensorWrapperRef>>& outputs) {
  auto& ret = new_ops.emplace_back(source_op);
  ret.UpdateTensors(inputs, outputs);
  return ret;
}

TensorWrapper& BuildSingleSHA(std::vector<OpWrapper>& new_ops,
                              TensorPool& tensor_pool, TensorWrapper& sha_input,
                              TensorWrapper& mask_input,
                              const OpWrapper& mul_op,
                              const OpWrapper& matmul_op1,
                              const OpWrapper& add_op,
                              const OpWrapper& softmax_op,
                              const OpWrapper& matmul_op2, size_t num_heads) {
  // Mul
  auto& mul_output = tensor_pool.CloneNativeTensorFrom(
      mul_op.GetOutputTensor(0), sha_input.GetDims());
  EmplaceOpWithIO(new_ops, mul_op, {sha_input, std::nullopt}, {mul_output});

  // MatMul 1: shrink the head-packed dimension to a single head's slice.
  const auto& matmul_op1_output = matmul_op1.GetOutputTensor(0);
  std::vector<uint32_t> new_matmul1_output_dim = matmul_op1_output.GetDims();
  new_matmul1_output_dim[2] /= num_heads;
  auto& new_matmul1_output = tensor_pool.CloneNativeTensorFrom(
      matmul_op1_output, new_matmul1_output_dim);
  EmplaceOpWithIO(
      new_ops, matmul_op1, {mul_output, std::nullopt}, {new_matmul1_output});

  // Add
  auto& new_add_output = tensor_pool.CloneNativeTensorFrom(
      add_op.GetOutputTensor(0), new_matmul1_output_dim);
  EmplaceOpWithIO(
      new_ops, add_op, {new_matmul1_output, mask_input}, {new_add_output});

  // Softmax
  auto& softmax_output = tensor_pool.CloneNativeTensorFrom(
      softmax_op.GetOutputTensor(0), new_add_output.GetDims());
  EmplaceOpWithIO(new_ops, softmax_op, {new_add_output}, {softmax_output});

  // MatMul 2
  auto matmul_op2_out_dim = matmul_op2.GetOutputTensor(0).GetDims();
  matmul_op2_out_dim[2] /= num_heads;
  auto& new_matmul2_output = tensor_pool.CloneNativeTensorFrom(
      matmul_op2.GetOutputTensor(0), matmul_op2_out_dim);
  EmplaceOpWithIO(new_ops, matmul_op2, {softmax_output, std::nullopt},
                  {new_matmul2_output});
  return new_matmul2_output;
}

// TODO (chunhsue-qti): add namespace to each new op
std::vector<OpWrapper> MHA2SHA(TensorPool& tensor_pool, const OpWrapper& mul_op,
                               const OpWrapper& transpose_op1,
                               const OpWrapper& matmul_op1,
                               const OpWrapper& add_op,
                               const OpWrapper& softmax_op,
                               const OpWrapper& matmul_op2,
                               const TensorWrapper& pattern_input,
                               const TensorWrapper& pattern_output) {
  std::vector<OpWrapper> new_ops;

  // Transpose
  auto transpose_out_dims = transpose_op1.GetOutputTensor(0).GetDims();
  auto& transpose_output =
      tensor_pool.CloneNativeTensorFrom(pattern_input, transpose_out_dims);
  auto& new_transpose1 = EmplaceOpWithIO(
      new_ops, transpose_op1,
      {const_cast<::qnn::TensorWrapper&>(pattern_input)}, {transpose_output});

  const uint32_t num_heads = pattern_input.GetDim(2);
  const auto& mha_input = new_transpose1.GetOutputTensor(0);  // split_in

  std::vector<::qnn::TensorWrapperRef> sha_inputs;
  sha_inputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; i++) {
    auto sha_input_dims = mha_input.GetDims();  // split_out_dims
    sha_input_dims[1] /= num_heads;
    auto& split_output =
        tensor_pool.CloneNativeTensorFrom(mha_input, sha_input_dims);
    sha_inputs.emplace_back(split_output);
  }

  // split from mul
  const std::array<int32_t, 1> split_axis_data{1};
  auto& split_axis = tensor_pool.CreateStaticTensor(
      QNN_DATATYPE_INT_32, {}, {split_axis_data.size()},
      split_axis_data.size() * sizeof(split_axis_data[0]),
      split_axis_data.data());
  auto split_op = BuildSplitOp(
      tensor_pool, {split_axis, const_cast<TensorWrapper&>(mha_input)},
      sha_inputs, num_heads);
  std::move(split_op.begin(), split_op.end(), std::back_inserter(new_ops));

  // split from mask
  auto& mask_input = add_op.GetInputTensor(1);
  std::vector<::qnn::TensorWrapperRef> new_mask_inputs;
  new_mask_inputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; i++) {
    auto new_mask_input_dims = mask_input.GetDims();
    new_mask_input_dims[2] /= num_heads;
    auto& mask_split_output =
        tensor_pool.CloneNativeTensorFrom(mask_input, new_mask_input_dims);
    new_mask_inputs.emplace_back(mask_split_output);
  }

  const std::array<int32_t, 1> mask_split_axis_data{2};
  auto& mask_split_axis = tensor_pool.CreateStaticTensor(
      QNN_DATATYPE_INT_32, {}, {mask_split_axis_data.size()},
      mask_split_axis_data.size() * sizeof(mask_split_axis_data[0]),
      mask_split_axis_data.data());
  auto mask_split_op = BuildSplitOp(
      tensor_pool, {mask_split_axis, const_cast<TensorWrapper&>(mask_input)},
      new_mask_inputs, num_heads);
  std::move(mask_split_op.begin(), mask_split_op.end(),
            std::back_inserter(new_ops));

  // Build one SHA chain per head.
  std::vector<TensorWrapperRef> sha_outputs;
  sha_outputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; ++i) {
    sha_outputs.emplace_back(BuildSingleSHA(
        new_ops, tensor_pool, const_cast<TensorWrapper&>(sha_inputs[i].get()),
        const_cast<TensorWrapper&>(new_mask_inputs[i].get()), mul_op,
        matmul_op1, add_op, softmax_op, matmul_op2, num_heads));
  }

  // Concat
  auto concat_dims = pattern_output.GetDims();
  concat_dims.insert(concat_dims.begin(), 1);
  auto& concat_output =
      tensor_pool.CloneNativeTensorFrom(pattern_output, concat_dims);
  auto concat_final =
      BuildConcatenationOp(tensor_pool, sha_outputs, {concat_output}, 3);
  std::move(concat_final.begin(), concat_final.end(),
            std::back_inserter(new_ops));
  // Reshape
  auto reshape =
      BuildReshapeOp(tensor_pool, {concat_output},
                     {const_cast<::qnn::TensorWrapper&>(pattern_output)});
  std::move(reshape.begin(), reshape.end(), std::back_inserter(new_ops));
  return new_ops;
}

size_t TransformEmbeddingGemma(
    std::function<bool(OpWrapper&)> validate_op_config,
    std::vector<OpWrapper>& ops, size_t start_index, TensorPool& tensor_pool,
    size_t pattern_size) {
  // Connection check
  auto is_op_connected = [](const OpWrapper& op1,
                            const OpWrapper& op2) -> bool {
    return op1.GetOutputTensor(0) == op2.GetInputTensor(0);
  };

  const auto& mul_op = ops[start_index + kMulIndex];
  const auto& transpose_op1 = ops[start_index + kTransposeIndex];
  const auto& reshape_op1 = ops[start_index + kReshapeIndex];
  const auto& matmul_op1 = ops[start_index + kMatMulIndex];
  const auto& add_op = ops[start_index + kAddIndex];
  const auto& softmax_op = ops[start_index + kSoftmaxIndex];
  const auto& matmul_op2 = ops[start_index + kMatMul2Index];
  const auto& reshape_op2 = ops[start_index + kReshape2Index];
  const auto& transpose_op2 = ops[start_index + kTranspose2Index];
  const auto& reshape_op3 = ops[start_index + kReshape3Index];

  bool is_match = is_op_connected(mul_op, transpose_op1) &&
                  is_op_connected(transpose_op1, reshape_op1) &&
                  is_op_connected(reshape_op1, matmul_op1) &&
                  is_op_connected(matmul_op1, add_op) &&
                  is_op_connected(add_op, softmax_op) &&
                  is_op_connected(softmax_op, matmul_op2) &&
                  is_op_connected(matmul_op2, reshape_op2) &&
                  is_op_connected(reshape_op2, transpose_op2) &&
                  is_op_connected(transpose_op2, reshape_op3);
  if (!is_match) {
    return 1;
  }
  // Graph transform
  QNN_LOG_INFO("[G2G] Transforming MHA to SHA in Embedding Gemma");
  // Construct the new subgraph
  const auto& pattern_input = mul_op.GetInputTensor(0);
  const auto& pattern_output = reshape_op3.GetOutputTensor(0);
  auto new_ops = MHA2SHA(tensor_pool, mul_op, transpose_op1, matmul_op1,
                         add_op, softmax_op, matmul_op2, pattern_input,
                         pattern_output);
  if (new_ops.empty()) {
    QNN_LOG_WARNING(
        "[G2G] Transformation failed. Rolling back to the original graph.");
    return 1;
  }
  // Validate new graph.
  bool is_valid =
      std::all_of(new_ops.begin(), new_ops.end(),
                  [validate_op_config](::qnn::OpWrapper& op_wrapper) -> bool {
                    return validate_op_config(op_wrapper);
                  });
  if (is_valid) {
    // Adjust the name to avoid a name collision in the Qnn JSON dump.
    for (size_t i = 0; i < new_ops.size(); ++i) {
      new_ops[i].AddSuffixToName(absl::StrCat("_qcg2g_", i));
    }
    // Replace the matched pattern with a newly generated subgraph.
    size_t step_size = new_ops.size();
    ops.insert(ops.begin() + start_index + pattern_size,
               std::make_move_iterator(new_ops.begin()),
               std::make_move_iterator(new_ops.end()));
    ops.erase(ops.begin() + start_index,
              ops.begin() + start_index + pattern_size);
    QNN_LOG_INFO("[G2G] Done transforming MHA to SHA in Embedding Gemma!");
    return step_size;
  }
  QNN_LOG_WARNING(
      "[G2G] Validation failed. Rolling back to the original graph.");
  return 1;
}

}  // namespace qnn
```
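
The return-value contract above (1 on a pattern miss or rollback, otherwise the size of the inserted subgraph) implies a caller that scans the op list and advances by whatever the transformation returns. That graph_to_graph driver isn't part of this diff; the sketch below is a hypothetical stand-in showing only how the step size would be consumed, with `Op` and `Transform` as stub types rather than the real qnn wrappers.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

struct Op {
  std::string code;  // stand-in for qnn::OpWrapper
};

// Stand-in for a transformation such as TransformEmbeddingGemma: given the
// op list and a start index, it either rewrites the ops in place and returns
// the number of ops it inserted, or returns 1 to advance past a non-match.
using Transform = std::function<std::size_t(std::vector<Op>&, std::size_t)>;

void RunGraphToGraph(std::vector<Op>& ops, const Transform& transform,
                     std::size_t pattern_size) {
  std::size_t i = 0;
  while (i + pattern_size <= ops.size()) {
    i += transform(ops, i);  // step_size on a rewrite, 1 on miss/rollback
  }
}

int main() {
  std::vector<Op> ops = {{"Mul"}, {"Transpose"}, {"Reshape"}, {"MatMul"}};
  Transform no_match = [](std::vector<Op>&, std::size_t) {
    return std::size_t{1};
  };
  RunGraphToGraph(ops, no_match, /*pattern_size=*/4);  // scans, rewrites nothing
  return 0;
}
```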
litert/vendors/qualcomm/core/transformation/embedding_gemma.h

Lines changed: 20 additions & 0 deletions (new file)

```cpp
// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#ifndef ODML_LITERT_LITERT_VENDORS_QUALCOMM_CORE_TRANSFORMATION_EMBEDDING_GEMMA_H_
#define ODML_LITERT_LITERT_VENDORS_QUALCOMM_CORE_TRANSFORMATION_EMBEDDING_GEMMA_H_

#include <cstddef>
#include <functional>
#include <vector>

#include "litert/vendors/qualcomm/core/tensor_pool.h"
#include "litert/vendors/qualcomm/core/wrappers/op_wrapper.h"

namespace qnn {
size_t TransformEmbeddingGemma(
    std::function<bool(OpWrapper&)> validate_op_config,
    std::vector<OpWrapper>& ops, size_t start_index, TensorPool& tensor_pool,
    size_t pattern_size);
}  // namespace qnn
#endif  // ODML_LITERT_LITERT_VENDORS_QUALCOMM_CORE_TRANSFORMATION_EMBEDDING_GEMMA_H_
```
