Commit d05c94f

Merge pull request #3922 from graham0824:dev/chunhsue/2_38_embedding_gemma

LiteRT-PiperOrigin-RevId: 828693184
2 parents 35f45df + 3fd1d26 commit d05c94f
File tree: 6 files changed (+553, -1 lines)
litert/vendors/qualcomm/core/transformation/BUILD

Lines changed: 51 additions & 0 deletions

@@ -48,6 +48,7 @@ cc_library(
         "nobuilder",
     ],
     deps = [
+        ":embedding_gemma",
         ":mask",
         ":matmul_convert",
         ":mha_to_sha",
@@ -110,3 +111,53 @@ cc_library(
         "@qairt//:qnn_lib_headers",
     ],
 )
+
+cc_library(
+    name = "embedding_gemma",
+    srcs = ["embedding_gemma.cc"],
+    hdrs = ["embedding_gemma.h"],
+    tags = [
+        # Don't build/test in OS until qnn is available.
+        "nobuilder",
+    ],
+    deps = [
+        "//litert/vendors/qualcomm/core:tensor_pool",
+        "//litert/vendors/qualcomm/core/builders:concatenation_op_builder",
+        "//litert/vendors/qualcomm/core/builders:reshape_op_builder",
+        "//litert/vendors/qualcomm/core/builders:split_op_builder",
+        "//litert/vendors/qualcomm/core/utils:log",
+        "//litert/vendors/qualcomm/core/wrappers:op_wrapper",
+        "//litert/vendors/qualcomm/core/wrappers:tensor_wrapper",
+        "@com_google_absl//absl/strings",
+        "@qairt//:qnn_lib_headers",
+    ],
+)
+
+cc_test(
+    name = "embedding_gemma_test",
+    srcs = [
+        "embedding_gemma_test.cc",
+    ],
+    tags = [
+        # Don't build/test in OS until qnn is available.
+        "nobuilder",
+    ],
+    deps = [
+        ":graph_to_graph",
+        "//litert/vendors/qualcomm/core:op_code",
+        "//litert/vendors/qualcomm/core:tensor_pool",
+        "//litert/vendors/qualcomm/core/builders:concatenation_op_builder",
+        "//litert/vendors/qualcomm/core/builders:elementwise_op_builder",
+        "//litert/vendors/qualcomm/core/builders:matmul_op_builder",
+        "//litert/vendors/qualcomm/core/builders:op_builder",
+        "//litert/vendors/qualcomm/core/builders:reshape_op_builder",
+        "//litert/vendors/qualcomm/core/builders:softmax_op_builder",
+        "//litert/vendors/qualcomm/core/builders:split_op_builder",
+        "//litert/vendors/qualcomm/core/builders:transpose_op_builder",
+        "//litert/vendors/qualcomm/core/wrappers:op_wrapper",
+        "//litert/vendors/qualcomm/core/wrappers:quantize_params_wrapper",
+        "//litert/vendors/qualcomm/core/wrappers:tensor_wrapper",
+        "@com_google_googletest//:gtest_main",
+        "@qairt//:qnn_lib_headers",
+    ],
+)
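In the first hunk, :embedding_gemma joins the transformation deps next to :mask, :matmul_convert, and :mha_to_sha, while the new embedding_gemma_test depends on :graph_to_graph rather than on :embedding_gemma directly. Both new targets carry the nobuilder tag, so open-source builds skip them until QNN is available.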
litert/vendors/qualcomm/core/transformation/embedding_gemma.cc

Lines changed: 257 additions & 0 deletions (new file)

// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#include "litert/vendors/qualcomm/core/transformation/embedding_gemma.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <vector>

#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "litert/vendors/qualcomm/core/builders/concatenation_op_builder.h"
#include "litert/vendors/qualcomm/core/builders/reshape_op_builder.h"
#include "litert/vendors/qualcomm/core/builders/split_op_builder.h"
#include "litert/vendors/qualcomm/core/tensor_pool.h"
#include "litert/vendors/qualcomm/core/utils/log.h"
#include "litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
#include "litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
#include "QnnTypes.h"  // from @qairt

namespace {
// Offsets of each op in the matched MHA pattern, relative to start_index.
constexpr size_t kMulIndex = 0;
constexpr size_t kTransposeIndex = 1;
constexpr size_t kReshapeIndex = 2;
constexpr size_t kMatMulIndex = 3;
constexpr size_t kAddIndex = 4;
constexpr size_t kSoftmaxIndex = 5;
constexpr size_t kMatMul2Index = 6;
constexpr size_t kReshape2Index = 7;
constexpr size_t kTranspose2Index = 8;
constexpr size_t kReshape3Index = 9;
}  // namespace

namespace qnn {
// TODO (chunhsue-qti): merge similar utility functions
OpWrapper& EmplaceOpWithIO(
    std::vector<OpWrapper>& new_ops, const OpWrapper& source_op,
    const std::vector<std::optional<qnn::TensorWrapperRef>>& inputs,
    const std::vector<std::optional<qnn::TensorWrapperRef>>& outputs) {
  auto& ret = new_ops.emplace_back(source_op);
  ret.UpdateTensors(inputs, outputs);
  return ret;
}

TensorWrapper& BuildSingleSHA(std::vector<OpWrapper>& new_ops,
                              TensorPool& tensor_pool, TensorWrapper& sha_input,
                              TensorWrapper& mask_input,
                              const OpWrapper& mul_op,
                              const OpWrapper& matmul_op1,
                              const OpWrapper& add_op,
                              const OpWrapper& softmax_op,
                              const OpWrapper& matmul_op2, size_t num_heads) {
  // Mul
  auto& mul_output = tensor_pool.CloneNativeTensorFrom(
      mul_op.GetOutputTensor(0), sha_input.GetDims());

  EmplaceOpWithIO(new_ops, mul_op, {sha_input, std::nullopt}, {mul_output});

  // MatMul 1
  const auto& matmul_op1_output = matmul_op1.GetOutputTensor(0);
  std::vector<uint32_t> new_matmul1_output_dim = matmul_op1_output.GetDims();
  new_matmul1_output_dim[2] /= num_heads;
  auto& new_matmul1_output = tensor_pool.CloneNativeTensorFrom(
      matmul_op1_output, new_matmul1_output_dim);
  EmplaceOpWithIO(new_ops, matmul_op1, {mul_output, std::nullopt},
                  {new_matmul1_output});

  // Add
  auto& new_add_output = tensor_pool.CloneNativeTensorFrom(
      add_op.GetOutputTensor(0), new_matmul1_output_dim);
  EmplaceOpWithIO(new_ops, add_op, {new_matmul1_output, mask_input},
                  {new_add_output});

  // Softmax
  auto& softmax_output = tensor_pool.CloneNativeTensorFrom(
      softmax_op.GetOutputTensor(0), new_add_output.GetDims());
  EmplaceOpWithIO(new_ops, softmax_op, {new_add_output}, {softmax_output});

  // MatMul 2
  auto matmul_op2_out_dim = matmul_op2.GetOutputTensor(0).GetDims();
  matmul_op2_out_dim[2] /= num_heads;
  auto& new_matmul2_output = tensor_pool.CloneNativeTensorFrom(
      matmul_op2.GetOutputTensor(0), matmul_op2_out_dim);
  EmplaceOpWithIO(new_ops, matmul_op2, {softmax_output, std::nullopt},
                  {new_matmul2_output});
  return new_matmul2_output;
}
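BuildSingleSHA re-emits the five-op attention chain once per head, cloning each MHA activation with the head axis divided out (new_matmul1_output_dim[2] /= num_heads, and likewise for the second MatMul). A minimal standalone sketch of that shape bookkeeping, with example dimensions that are assumptions for illustration, not values taken from the model:

// Not part of this commit: standalone illustration of the per-head shape math.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::uint32_t num_heads = 4;  // assumed head count
  // Assumed MHA activation shape [batch, seq, heads * head_dim].
  std::vector<std::uint32_t> mha_dims = {1, 128, num_heads * 64};
  std::vector<std::uint32_t> sha_dims = mha_dims;
  sha_dims[2] /= num_heads;  // the same division BuildSingleSHA applies per head
  std::cout << sha_dims[0] << " x " << sha_dims[1] << " x " << sha_dims[2]
            << "\n";  // prints: 1 x 128 x 64
  return 0;
}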
// TODO (chunhsue-qti): add namespace to each new op
std::vector<OpWrapper> MHA2SHA(TensorPool& tensor_pool, const OpWrapper& mul_op,
                               const OpWrapper& transpose_op1,
                               const OpWrapper& matmul_op1,
                               const OpWrapper& add_op,
                               const OpWrapper& softmax_op,
                               const OpWrapper& matmul_op2,
                               const TensorWrapper& pattern_input,
                               const TensorWrapper& pattern_output) {
  std::vector<OpWrapper> new_ops;

  // Transpose
  auto transpose_out_dims = transpose_op1.GetOutputTensor(0).GetDims();
  auto& transpose_output =
      tensor_pool.CloneNativeTensorFrom(pattern_input, transpose_out_dims);
  auto& new_transpose1 = EmplaceOpWithIO(
      new_ops, transpose_op1,
      {const_cast<::qnn::TensorWrapper&>(pattern_input)}, {transpose_output});

  const uint32_t num_heads = pattern_input.GetDim(2);
  const auto& mha_input = new_transpose1.GetOutputTensor(0);  // split_in

  std::vector<::qnn::TensorWrapperRef> sha_inputs;
  sha_inputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; i++) {
    auto sha_input_dims = mha_input.GetDims();  // split_out_dims
    sha_input_dims[1] /= num_heads;
    auto& split_output =
        tensor_pool.CloneNativeTensorFrom(mha_input, sha_input_dims);
    sha_inputs.emplace_back(split_output);
  }

  // split from mul
  const std::array<int32_t, 1> split_axis_data{1};
  auto& split_axis = tensor_pool.CreateStaticTensor(
      QNN_DATATYPE_INT_32, {}, {split_axis_data.size()},
      split_axis_data.size() * sizeof(split_axis_data[0]),
      split_axis_data.data());
  auto split_op = BuildSplitOp(
      tensor_pool, {split_axis, const_cast<TensorWrapper&>(mha_input)},
      sha_inputs, num_heads);

  std::move(split_op.begin(), split_op.end(), std::back_inserter(new_ops));

  // split from mask
  auto& mask_input = add_op.GetInputTensor(1);
  std::vector<::qnn::TensorWrapperRef> new_mask_inputs;
  new_mask_inputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; i++) {
    auto new_mask_input_dims = mask_input.GetDims();
    new_mask_input_dims[2] /= num_heads;
    auto& mask_split_output =
        tensor_pool.CloneNativeTensorFrom(mask_input, new_mask_input_dims);
    new_mask_inputs.emplace_back(mask_split_output);
  }

  const std::array<int32_t, 1> mask_split_axis_data{2};
  auto& mask_split_axis = tensor_pool.CreateStaticTensor(
      QNN_DATATYPE_INT_32, {}, {mask_split_axis_data.size()},
      mask_split_axis_data.size() * sizeof(mask_split_axis_data[0]),
      mask_split_axis_data.data());
  auto mask_split_op = BuildSplitOp(
      tensor_pool, {mask_split_axis, const_cast<TensorWrapper&>(mask_input)},
      new_mask_inputs, num_heads);

  std::move(mask_split_op.begin(), mask_split_op.end(),
            std::back_inserter(new_ops));

  std::vector<TensorWrapperRef> sha_outputs;
  sha_outputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; ++i) {
    sha_outputs.emplace_back(BuildSingleSHA(
        new_ops, tensor_pool, const_cast<TensorWrapper&>(sha_inputs[i].get()),
        const_cast<TensorWrapper&>(new_mask_inputs[i].get()), mul_op,
        matmul_op1, add_op, softmax_op, matmul_op2, num_heads));
  }

  // Concat
  auto concat_dims = pattern_output.GetDims();
  concat_dims.insert(concat_dims.begin(), 1);
  auto& concat_output =
      tensor_pool.CloneNativeTensorFrom(pattern_output, concat_dims);
  auto concat_final =
      BuildConcatenationOp(tensor_pool, sha_outputs, {concat_output}, 3);
  std::move(concat_final.begin(), concat_final.end(),
            std::back_inserter(new_ops));
  // Reshape
  auto reshape =
      BuildReshapeOp(tensor_pool, {concat_output},
                     {const_cast<::qnn::TensorWrapper&>(pattern_output)});
  std::move(reshape.begin(), reshape.end(), std::back_inserter(new_ops));
  return new_ops;
}

size_t TransformEmbeddingGemma(
    std::function<bool(OpWrapper&)> validate_op_config,
    std::vector<OpWrapper>& ops, size_t start_index, TensorPool& tensor_pool,
    size_t pattern_size) {
  // Connection check
  auto is_op_connected = [](const OpWrapper& op1,
                            const OpWrapper& op2) -> bool {
    return op1.GetOutputTensor(0) == op2.GetInputTensor(0);
  };

  const auto& mul_op = ops[start_index + kMulIndex];
  const auto& transpose_op1 = ops[start_index + kTransposeIndex];
  const auto& reshape_op1 = ops[start_index + kReshapeIndex];
  const auto& matmul_op1 = ops[start_index + kMatMulIndex];
  const auto& add_op = ops[start_index + kAddIndex];
  const auto& softmax_op = ops[start_index + kSoftmaxIndex];
  const auto& matmul_op2 = ops[start_index + kMatMul2Index];
  const auto& reshape_op2 = ops[start_index + kReshape2Index];
  const auto& transpose_op2 = ops[start_index + kTranspose2Index];
  const auto& reshape_op3 = ops[start_index + kReshape3Index];

  bool is_match = is_op_connected(mul_op, transpose_op1) &&
                  is_op_connected(transpose_op1, reshape_op1) &&
                  is_op_connected(reshape_op1, matmul_op1) &&
                  is_op_connected(matmul_op1, add_op) &&
                  is_op_connected(add_op, softmax_op) &&
                  is_op_connected(softmax_op, matmul_op2) &&
                  is_op_connected(matmul_op2, reshape_op2) &&
                  is_op_connected(reshape_op2, transpose_op2) &&
                  is_op_connected(transpose_op2, reshape_op3);
  if (!is_match) {
    return 1;
  }
  // Graph transform
  QNN_LOG_INFO("[G2G] Transforming MHA to SHA in Embedding Gemma");
  // Construct the new subgraph
  const auto& pattern_input = mul_op.GetInputTensor(0);
  const auto& pattern_output = reshape_op3.GetOutputTensor(0);
  auto new_ops = MHA2SHA(tensor_pool, mul_op, transpose_op1, matmul_op1,
                         add_op, softmax_op, matmul_op2, pattern_input,
                         pattern_output);
  if (new_ops.empty()) {
    QNN_LOG_WARNING(
        "[G2G] Transformation failed. Rolling back to the original graph.");
    return 1;
  }
  // Validate new graph.
  bool is_valid =
      std::all_of(new_ops.begin(), new_ops.end(),
                  [validate_op_config](::qnn::OpWrapper& op_wrapper) -> bool {
                    return validate_op_config(op_wrapper);
                  });
  if (is_valid) {
    // Adjust the name to avoid a name collision in the Qnn JSON dump.
    for (size_t i = 0; i < new_ops.size(); ++i) {
      new_ops[i].AddSuffixToName(absl::StrCat("_qcg2g_", i));
    }
    // Replace the matched pattern with the newly generated subgraph.
    size_t step_size = new_ops.size();
    ops.insert(ops.begin() + start_index + pattern_size,
               std::make_move_iterator(new_ops.begin()),
               std::make_move_iterator(new_ops.end()));
    ops.erase(ops.begin() + start_index,
              ops.begin() + start_index + pattern_size);
    QNN_LOG_INFO("[G2G] Done transforming MHA to SHA in Embedding Gemma!");
    return step_size;
  }
  QNN_LOG_WARNING(
      "[G2G] Validation failed. Rolling back to the original graph.");
  return 1;
}

}  // namespace qnn
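On a successful rewrite, TransformEmbeddingGemma splices the generated ops over the matched window: it inserts the new ops immediately after the window, erases the window itself, and returns the insertion count as the scan step. A minimal standalone illustration of that vector surgery (toy strings stand in for OpWrapper; not part of this commit):

#include <cstddef>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ops = {"a", "b", "c", "d"};  // "b", "c" is the match
  std::vector<std::string> new_ops = {"x", "y", "z"};   // replacement subgraph
  const std::size_t start_index = 1, pattern_size = 2;
  // Insert after the matched window, then erase the window, as in the commit.
  ops.insert(ops.begin() + start_index + pattern_size,
             std::make_move_iterator(new_ops.begin()),
             std::make_move_iterator(new_ops.end()));
  ops.erase(ops.begin() + start_index,
            ops.begin() + start_index + pattern_size);
  for (const auto& op : ops) std::cout << op << " ";  // prints: a x y z d
  return 0;
}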
litert/vendors/qualcomm/core/transformation/embedding_gemma.h

Lines changed: 20 additions & 0 deletions (new file)

// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#ifndef ODML_LITERT_LITERT_VENDORS_QUALCOMM_CORE_TRANSFORMATION_EMBEDDING_GEMMA_H_
#define ODML_LITERT_LITERT_VENDORS_QUALCOMM_CORE_TRANSFORMATION_EMBEDDING_GEMMA_H_

#include <cstddef>
#include <functional>
#include <vector>

#include "litert/vendors/qualcomm/core/tensor_pool.h"
#include "litert/vendors/qualcomm/core/wrappers/op_wrapper.h"

namespace qnn {
size_t TransformEmbeddingGemma(
    std::function<bool(OpWrapper&)> validate_op_config,
    std::vector<OpWrapper>& ops, size_t start_index, TensorPool& tensor_pool,
    size_t pattern_size);
}  // namespace qnn
#endif  // ODML_LITERT_LITERT_VENDORS_QUALCOMM_CORE_TRANSFORMATION_EMBEDDING_GEMMA_H_
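Only TransformEmbeddingGemma's signature comes from this header; the pass wrapper below, its name, and the always-true validator are assumptions sketched for illustration. A caller scans the linearized op list and advances by the returned step, which is 1 when the window does not match and otherwise the size of the inserted subgraph:

// Hypothetical driver, not part of this commit.
#include <cstddef>
#include <vector>

#include "litert/vendors/qualcomm/core/tensor_pool.h"
#include "litert/vendors/qualcomm/core/transformation/embedding_gemma.h"
#include "litert/vendors/qualcomm/core/wrappers/op_wrapper.h"

namespace qnn {

void RunEmbeddingGemmaPass(std::vector<OpWrapper>& ops,
                           TensorPool& tensor_pool) {
  constexpr size_t kPatternSize = 10;  // Mul through Reshape3, per the constants
  auto validate = [](OpWrapper&) -> bool { return true; };  // permissive stub
  size_t i = 0;
  while (i + kPatternSize <= ops.size()) {
    // Advance by 1 on no match; otherwise skip past the rewritten subgraph.
    i += TransformEmbeddingGemma(validate, ops, i, tensor_pool, kPatternSize);
  }
}

}  // namespace qnn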
