// Copyright (c) Qualcomm Innovation Center, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#include "litert/vendors/qualcomm/core/transformation/embedding_gemma.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <vector>

#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "litert/vendors/qualcomm/core/builders/concatenation_op_builder.h"
#include "litert/vendors/qualcomm/core/builders/reshape_op_builder.h"
#include "litert/vendors/qualcomm/core/builders/split_op_builder.h"
#include "litert/vendors/qualcomm/core/tensor_pool.h"
#include "litert/vendors/qualcomm/core/utils/log.h"
#include "litert/vendors/qualcomm/core/wrappers/op_wrapper.h"
#include "litert/vendors/qualcomm/core/wrappers/tensor_wrapper.h"
#include "QnnTypes.h"  // from @qairt

namespace {
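// Offsets of the ops within the matched MHA pattern, relative to the first
// op (Mul). The connection check in TransformEmbeddingGemma expects the ops
// to appear in exactly this order.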
constexpr size_t kMulIndex = 0;
constexpr size_t kTransposeIndex = 1;
constexpr size_t kReshapeIndex = 2;
constexpr size_t kMatMulIndex = 3;
constexpr size_t kAddIndex = 4;
constexpr size_t kSoftmaxIndex = 5;
constexpr size_t kMatMul2Index = 6;
constexpr size_t kReshape2Index = 7;
constexpr size_t kTranspose2Index = 8;
constexpr size_t kReshape3Index = 9;
}  // namespace

namespace qnn {

// TODO (chunhsue-qti): merge similar utility function
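// Clones `source_op` into `new_ops` and rebinds its tensors: an entry in
// `inputs`/`outputs` replaces the tensor at the same position, while a
// std::nullopt entry keeps the source op's original tensor.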
OpWrapper& EmplaceOpWithIO(
    std::vector<OpWrapper>& new_ops, const OpWrapper& source_op,
    const std::vector<std::optional<qnn::TensorWrapperRef>>& inputs,
    const std::vector<std::optional<qnn::TensorWrapperRef>>& outputs) {
  auto& ret = new_ops.emplace_back(source_op);
  ret.UpdateTensors(inputs, outputs);
  return ret;
}

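// Builds one single-head attention (SHA) branch: Mul -> MatMul -> Add ->
// Softmax -> MatMul, cloning each MHA op with per-head tensor shapes.
// Returns the branch output tensor (the second MatMul's output).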
TensorWrapper& BuildSingleSHA(std::vector<OpWrapper>& new_ops,
                              TensorPool& tensor_pool, TensorWrapper& sha_input,
                              TensorWrapper& mask_input,
                              const OpWrapper& mul_op,
                              const OpWrapper& matmul_op1,
                              const OpWrapper& add_op,
                              const OpWrapper& softmax_op,
                              const OpWrapper& matmul_op2, size_t num_heads) {
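  // Per-head tensors keep the MHA shapes except that the dimension carrying
  // the heads is divided by num_heads.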
  // Mul
  auto& mul_output = tensor_pool.CloneNativeTensorFrom(
      mul_op.GetOutputTensor(0), sha_input.GetDims());

  EmplaceOpWithIO(new_ops, mul_op, {sha_input, std::nullopt}, {mul_output});

  // MatMul 1
  const auto& matmul_op1_output = matmul_op1.GetOutputTensor(0);
  std::vector<uint32_t> new_matmul1_output_dim = matmul_op1_output.GetDims();
  new_matmul1_output_dim[2] /= num_heads;
  auto& new_matmul1_output = tensor_pool.CloneNativeTensorFrom(
      matmul_op1_output, new_matmul1_output_dim);
  EmplaceOpWithIO(new_ops, matmul_op1, {mul_output, std::nullopt},
                  {new_matmul1_output});

  // Add
  auto& new_add_output = tensor_pool.CloneNativeTensorFrom(
      add_op.GetOutputTensor(0), new_matmul1_output_dim);
  EmplaceOpWithIO(new_ops, add_op, {new_matmul1_output, mask_input},
                  {new_add_output});

  // Softmax
  auto& softmax_output = tensor_pool.CloneNativeTensorFrom(
      softmax_op.GetOutputTensor(0), new_add_output.GetDims());
  EmplaceOpWithIO(new_ops, softmax_op, {new_add_output}, {softmax_output});

  // MatMul 2
  auto matmul_op2_out_dim = matmul_op2.GetOutputTensor(0).GetDims();
  matmul_op2_out_dim[2] /= num_heads;
  auto& new_matmul2_output = tensor_pool.CloneNativeTensorFrom(
      matmul_op2.GetOutputTensor(0), matmul_op2_out_dim);
  EmplaceOpWithIO(new_ops, matmul_op2, {softmax_output, std::nullopt},
                  {new_matmul2_output});
  return new_matmul2_output;
}

// TODO (chunhsue-qti): add namespace to each new op
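// Rewrites the matched multi-head attention (MHA) subgraph into num_heads
// single-head attention (SHA) branches: transpose the pattern input, split
// it and the attention mask per head, build one SHA branch per head, then
// concatenate the per-head outputs and reshape back to the original shape.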
std::vector<OpWrapper> MHA2SHA(TensorPool& tensor_pool, const OpWrapper& mul_op,
                               const OpWrapper& transpose_op1,
                               const OpWrapper& matmul_op1,
                               const OpWrapper& add_op,
                               const OpWrapper& softmax_op,
                               const OpWrapper& matmul_op2,
                               const TensorWrapper& pattern_input,
                               const TensorWrapper& pattern_output) {
  std::vector<OpWrapper> new_ops;

  // Transpose
  auto transpose_out_dims = transpose_op1.GetOutputTensor(0).GetDims();
  auto& transpose_output =
      tensor_pool.CloneNativeTensorFrom(pattern_input, transpose_out_dims);
  auto& new_transpose1 = EmplaceOpWithIO(
      new_ops, transpose_op1,
      {const_cast<::qnn::TensorWrapper&>(pattern_input)}, {transpose_output});

  const uint32_t num_heads = pattern_input.GetDim(2);
  const auto& mha_input = new_transpose1.GetOutputTensor(0);  // split_in

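  // Pre-allocate one per-head tensor to receive each Split output.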
  std::vector<::qnn::TensorWrapperRef> sha_inputs;
  sha_inputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; i++) {
    auto sha_input_dims = mha_input.GetDims();  // split_out_dims
    sha_input_dims[1] /= num_heads;
    auto& split_output =
        tensor_pool.CloneNativeTensorFrom(mha_input, sha_input_dims);
    sha_inputs.emplace_back(split_output);
  }

  // Split the transposed input along axis 1 (the head axis); each output
  // feeds the Mul op of one SHA branch.
  const std::array<int32_t, 1> split_axis_data{1};
  auto& split_axis = tensor_pool.CreateStaticTensor(
      QNN_DATATYPE_INT_32, {}, {split_axis_data.size()},
      split_axis_data.size() * sizeof(split_axis_data[0]),
      split_axis_data.data());
  auto split_op = BuildSplitOp(
      tensor_pool, {split_axis, const_cast<TensorWrapper&>(mha_input)},
      sha_inputs, num_heads);

  std::move(split_op.begin(), split_op.end(), std::back_inserter(new_ops));

  // Split the attention mask along axis 2; each output feeds the Add op of
  // one SHA branch.
  auto& mask_input = add_op.GetInputTensor(1);
  std::vector<::qnn::TensorWrapperRef> new_mask_inputs;
  new_mask_inputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; i++) {
    auto new_mask_input_dims = mask_input.GetDims();
    new_mask_input_dims[2] /= num_heads;
    auto& mask_split_output =
        tensor_pool.CloneNativeTensorFrom(mask_input, new_mask_input_dims);
    new_mask_inputs.emplace_back(mask_split_output);
  }

  const std::array<int32_t, 1> mask_split_axis_data{2};
  auto& mask_split_axis = tensor_pool.CreateStaticTensor(
      QNN_DATATYPE_INT_32, {}, {mask_split_axis_data.size()},
      mask_split_axis_data.size() * sizeof(mask_split_axis_data[0]),
      mask_split_axis_data.data());
  auto mask_split_op = BuildSplitOp(
      tensor_pool, {mask_split_axis, const_cast<TensorWrapper&>(mask_input)},
      new_mask_inputs, num_heads);

  std::move(mask_split_op.begin(), mask_split_op.end(),
            std::back_inserter(new_ops));

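  // Build one SHA branch per head.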
  std::vector<TensorWrapperRef> sha_outputs;
  sha_outputs.reserve(num_heads);
  for (size_t i = 0; i < num_heads; ++i) {
    sha_outputs.emplace_back(BuildSingleSHA(
        new_ops, tensor_pool, const_cast<TensorWrapper&>(sha_inputs[i].get()),
        const_cast<TensorWrapper&>(new_mask_inputs[i].get()), mul_op,
        matmul_op1, add_op, softmax_op, matmul_op2, num_heads));
  }

  // Concatenate the per-head outputs along the last axis.
  auto concat_dims = pattern_output.GetDims();
  concat_dims.insert(concat_dims.begin(), 1);
  auto& concat_output =
      tensor_pool.CloneNativeTensorFrom(pattern_output, concat_dims);
  auto concat_final =
      BuildConcatenationOp(tensor_pool, sha_outputs, {concat_output}, 3);
  std::move(concat_final.begin(), concat_final.end(),
            std::back_inserter(new_ops));
  // Reshape back to the original pattern output shape.
  auto reshape =
      BuildReshapeOp(tensor_pool, {concat_output},
                     {const_cast<::qnn::TensorWrapper&>(pattern_output)});
  std::move(reshape.begin(), reshape.end(), std::back_inserter(new_ops));
  return new_ops;
}

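// Tries to match the MHA pattern starting at ops[start_index] and, on
// success, replaces it with the equivalent SHA subgraph. Returns the number
// of ops the caller should advance by: the size of the new subgraph on
// success, or 1 when the pattern does not match or validation fails.
//
// A typical driver loop (a sketch only; kPatternSize stands for whatever
// pattern length the caller uses and is not defined here) might look like:
//   size_t i = 0;
//   while (i + kPatternSize <= ops.size()) {
//     i += TransformEmbeddingGemma(validate_op_config, ops, i, tensor_pool,
//                                  kPatternSize);
//   }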
size_t TransformEmbeddingGemma(
    std::function<bool(OpWrapper&)> validate_op_config,
    std::vector<OpWrapper>& ops, size_t start_index, TensorPool& tensor_pool,
    size_t pattern_size) {
  // Connection check
  auto is_op_connected = [](const OpWrapper& op1,
                            const OpWrapper& op2) -> bool {
    return op1.GetOutputTensor(0) == op2.GetInputTensor(0);
  };

  const auto& mul_op = ops[start_index + kMulIndex];
  const auto& transpose_op1 = ops[start_index + kTransposeIndex];
  const auto& reshape_op1 = ops[start_index + kReshapeIndex];
  const auto& matmul_op1 = ops[start_index + kMatMulIndex];
  const auto& add_op = ops[start_index + kAddIndex];
  const auto& softmax_op = ops[start_index + kSoftmaxIndex];
  const auto& matmul_op2 = ops[start_index + kMatMul2Index];
  const auto& reshape_op2 = ops[start_index + kReshape2Index];
  const auto& transpose_op2 = ops[start_index + kTranspose2Index];
  const auto& reshape_op3 = ops[start_index + kReshape3Index];

  bool is_match = is_op_connected(mul_op, transpose_op1) &&
                  is_op_connected(transpose_op1, reshape_op1) &&
                  is_op_connected(reshape_op1, matmul_op1) &&
                  is_op_connected(matmul_op1, add_op) &&
                  is_op_connected(add_op, softmax_op) &&
                  is_op_connected(softmax_op, matmul_op2) &&
                  is_op_connected(matmul_op2, reshape_op2) &&
                  is_op_connected(reshape_op2, transpose_op2) &&
                  is_op_connected(transpose_op2, reshape_op3);
  if (!is_match) {
    return 1;
  }
  // Graph transform
  QNN_LOG_INFO("[G2G] Transforming MHA to SHA in Embedding Gemma");
  // Construct the new subgraph
  const auto& pattern_input = mul_op.GetInputTensor(0);
  const auto& pattern_output = reshape_op3.GetOutputTensor(0);
  auto new_ops =
      MHA2SHA(tensor_pool, mul_op, transpose_op1, matmul_op1, add_op,
              softmax_op, matmul_op2, pattern_input, pattern_output);
  if (new_ops.empty()) {
    QNN_LOG_WARNING(
        "[G2G] Transformation failed. Rolling back to the original graph.");
    return 1;
  }
  // Validate the new graph.
  bool is_valid =
      std::all_of(new_ops.begin(), new_ops.end(),
                  [validate_op_config](::qnn::OpWrapper& op_wrapper) -> bool {
                    return validate_op_config(op_wrapper);
                  });
  if (is_valid) {
    // Adjust the names to avoid a name collision in the Qnn JSON dump.
    for (size_t i = 0; i < new_ops.size(); ++i) {
      new_ops[i].AddSuffixToName(absl::StrCat("_qcg2g_", i));
    }
    // Replace the matched pattern with the newly generated subgraph.
    size_t step_size = new_ops.size();
    ops.insert(ops.begin() + start_index + pattern_size,
               std::make_move_iterator(new_ops.begin()),
               std::make_move_iterator(new_ops.end()));
    ops.erase(ops.begin() + start_index,
              ops.begin() + start_index + pattern_size);
    QNN_LOG_INFO("[G2G] Done transforming MHA to SHA in Embedding Gemma!");
    return step_size;
  }
  QNN_LOG_WARNING(
      "[G2G] Validation failed. Rolling back to the original graph.");
  return 1;
}

}  // namespace qnn