
Commit ccc129d

Qualcomm AI Engine Direct - Optimize fast vlm
Summary:
- Add MHA to SHA optimization
- Add test for Fast vlm with G2G transformation
1 parent d05c94f

File tree

5 files changed: +710 / -2 lines


litert/vendors/qualcomm/core/transformation/BUILD

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ cc_library(
         "//litert/vendors/qualcomm/core/builders:concatenation_op_builder",
         "//litert/vendors/qualcomm/core/builders:reshape_op_builder",
         "//litert/vendors/qualcomm/core/builders:split_op_builder",
+        "//litert/vendors/qualcomm/core/builders:unpack_op_builder",
         "//litert/vendors/qualcomm/core/utils:log",
         "//litert/vendors/qualcomm/core/wrappers:op_wrapper",
         "//litert/vendors/qualcomm/core/wrappers:tensor_wrapper",

litert/vendors/qualcomm/core/transformation/graph_to_graph.cc

Lines changed: 25 additions & 0 deletions
@@ -156,6 +156,8 @@ void GraphToGraphTransform(const G2GConfig g2g_option,
     Transform(validate_op_config, ops, tensor_pool, gemma3_mha_prefill,
               OptimizeMHAPrefill);
   }
+
+  // Mask Gemma Optimization
   const std::vector<QnnOpCode> gemma3_mask = {
       QnnOpCode::kElementWiseNot,
       QnnOpCode::kCast,
@@ -165,6 +167,7 @@ void GraphToGraphTransform(const G2GConfig g2g_option,
   Transform(validate_op_config, ops, tensor_pool, gemma3_mask,
             TransformQuantizeInMask);
 
+  // Embedding Gemma Optimization
   const std::vector<QnnOpCode> embedding_gemma = {
       QnnOpCode::kElementWiseMultiply,
      QnnOpCode::kTranspose,
@@ -179,5 +182,27 @@ void GraphToGraphTransform(const G2GConfig g2g_option,
   };
   Transform(validate_op_config, ops, tensor_pool, embedding_gemma,
             TransformEmbeddingGemma);
+
+  // Fast Vlm Optimization
+  const std::vector<QnnOpCode> fast_vlm_mha_prefill = {
+      QnnOpCode::kElementWiseMultiply,
+      QnnOpCode::kReshape,
+      QnnOpCode::kMatMul,
+      QnnOpCode::kMatMul,
+      QnnOpCode::kConcat,
+      QnnOpCode::kReshape,
+      QnnOpCode::kElementWiseAdd,
+      QnnOpCode::kReshape,
+      QnnOpCode::kSoftmax,
+      QnnOpCode::kStridedSlice,
+      QnnOpCode::kStridedSlice,
+      QnnOpCode::kMatMul,
+      QnnOpCode::kMatMul,
+      QnnOpCode::kElementWiseAdd,
+      QnnOpCode::kReshape,
+      QnnOpCode::kTranspose,
+      QnnOpCode::kReshape};
+  Transform(validate_op_config, ops, tensor_pool, fast_vlm_mha_prefill,
+            OptimizeMHAFastVlmPrefill);
 }
 }  // namespace qnn
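The test below drives this new pattern end to end. As a minimal sketch of how a caller exercises the rewrite (mirroring the test's setup; the always-true validation callback is a stand-in for whatever backend op-config check is used in production):

    // Run G2G with the prefill MHA optimization enabled. The lambda here
    // accepts every candidate op; a real caller would pass a callback that
    // validates each rewritten op before the swap is committed.
    const ::qnn::G2GConfig g2g_option = ::qnn::G2GConfig::kMHAOptPrefill;
    GraphToGraphTransform(g2g_option, op_wrappers, tensor_pool,
                          [](OpWrapper& op) { return true; });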

litert/vendors/qualcomm/core/transformation/graph_to_graph_test.cc

Lines changed: 296 additions & 0 deletions
@@ -785,5 +785,301 @@ TEST(MaskTransformTest, Gemma3) {
   ASSERT_EQ(quant_param_2.GetZeroPoint(), mul_zero_point);
 }
 
+TEST(MHASHATest, FastVlm) {
+  // G2G Test case: MHA -> SHA
+
+  // ------------------- Before ---------------------
+  //                 In0
+  //                  |
+  //                 Mul
+  //                  |
+  //              Reshape0        In1   In2
+  //               /     \          \   /
+  //       QIn    /       \          Add0
+  //         \   /         \        /
+  //        Matmul0          Matmul1
+  //              \          /
+  //                Concat
+  //                  |
+  //               Reshape1
+  //                  |
+  //      Mask        |
+  //         \        |
+  //             Add1
+  //               |
+  //           Reshape2
+  //               |
+  //            Softmax
+  //            /      \
+  //       Slice0      Slice1       In3
+  //         |           |           |
+  //  VIn    |           |       Transpose0
+  //    \    |           |       /
+  //    Matmul2         Matmul3
+  //          \         /
+  //             Add2
+  //               |
+  //           Reshape3
+  //               |
+  //          Transpose1
+  //               |
+  //           Reshape4
+  //               |
+  //              Out
+  //
+  // -------------------- After ---------------------
+  //                 In0
+  //                  |
+  //                 Mul          In1   In2
+  //               /     \          \   /
+  //       QIn    /       \          Add0
+  //         \   /         \        /
+  //        Matmul0          Matmul1
+  //              \          /
+  //                Concat
+  //                  |
+  //      Mask        |
+  //         \        |
+  //             Add1
+  //               |
+  //            Reshape
+  //               |
+  //            Softmax
+  //            /      \
+  //       Slice0      Slice1
+  //         |           |
+  //  VIn    |           |      In3
+  //    \    |           |      /
+  //    Matmul2         Matmul3
+  //          \         /
+  //             Add2
+  //               |
+  //              Out
+  TensorPool tensor_pool;
+  QuantizeParamsWrapperVariant quant_param;
+  quant_param.emplace<ScaleOffsetQuantizeParamsWrapper>(1e-4f, 0);
+  std::vector<OpWrapper> op_wrappers;
+
+  // Add0
+  auto& input1 = tensor_pool.CreateNativeTensor(QNN_DATATYPE_SFIXED_POINT_16,
+                                                quant_param, {1, 2, 128, 64});
+  auto& input2 = tensor_pool.CreateNativeTensor(QNN_DATATYPE_SFIXED_POINT_16,
+                                                quant_param, {1, 2, 128, 64});
+  auto& add0_output =
+      tensor_pool.CloneNativeTensorFrom(input1, {1, 2, 128, 64});
+  auto add0 =
+      BuildElementwiseAddOp(tensor_pool, {input1, input2}, {add0_output});
+  std::move(add0.begin(), add0.end(), std::back_inserter(op_wrappers));
+
+  // Transpose0
+  auto& input3 = tensor_pool.CreateNativeTensor(QNN_DATATYPE_SFIXED_POINT_16,
+                                                quant_param, {1, 128, 2, 64});
+  std::array<int32_t, 4> transpose0_val = {0, 2, 3, 1};
+  auto& transpose0_perm = tensor_pool.CreateStaticTensor(
+      QNN_DATATYPE_INT_32, quant_param, {transpose0_val.size()},
+      transpose0_val.size() * sizeof(transpose0_val[0]), transpose0_val.data());
+  auto& transpose0_output =
+      tensor_pool.CloneNativeTensorFrom(add0_output, {1, 2, 64, 128});
+  auto transpose0 = BuildTransposeOp(tensor_pool, {input3, transpose0_perm},
+                                     {transpose0_output});
+  std::move(transpose0.begin(), transpose0.end(),
+            std::back_inserter(op_wrappers));
+
+  // Mul
+  auto& input0 = tensor_pool.CreateNativeTensor(QNN_DATATYPE_SFIXED_POINT_16,
+                                                quant_param, {1, 14, 128, 64});
+  std::array<int16_t, 1> mul_val = {32767};
+  auto& mul_const = tensor_pool.CreateStaticTensor(
+      QNN_DATATYPE_SFIXED_POINT_16, quant_param, {mul_val.size()},
+      mul_val.size() * sizeof(mul_val[0]), mul_val.data());
+  auto& mul_output =
+      tensor_pool.CloneNativeTensorFrom(input0, {1, 14, 128, 64});
+  auto mul =
+      BuildElementwiseMulOp(tensor_pool, {input0, mul_const}, {mul_output});
+  std::move(mul.begin(), mul.end(), std::back_inserter(op_wrappers));
+
+  // Reshape0
+  auto& reshape0_output =
+      tensor_pool.CloneNativeTensorFrom(mul_output, {1, 2, 896, 64});
+  auto reshape0 = BuildReshapeOp(tensor_pool, {mul_output}, {reshape0_output});
+  std::move(reshape0.begin(), reshape0.end(), std::back_inserter(op_wrappers));
+
+  // MatMul0
+  auto& q_in = tensor_pool.CreateNativeTensor(QNN_DATATYPE_SFIXED_POINT_16,
+                                              quant_param, {1, 2, 1280, 64});
+  auto& matmul0_output = tensor_pool.CreateNativeTensor(
+      QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1, 2, 896, 1280});
+  auto matmul0 = BuildMatmulOp(tensor_pool, {reshape0_output, q_in},
+                               {matmul0_output}, false, true);
+  std::move(matmul0.begin(), matmul0.end(), std::back_inserter(op_wrappers));
+
+  // MatMul1
+  auto& matmul1_output = tensor_pool.CreateNativeTensor(
+      QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1, 2, 896, 128});
+  auto matmul1 = BuildMatmulOp(tensor_pool, {reshape0_output, add0_output},
+                               {matmul1_output}, false, true);
+  std::move(matmul1.begin(), matmul1.end(), std::back_inserter(op_wrappers));
+
+  // Concat
+  auto& concat_output =
+      tensor_pool.CloneNativeTensorFrom(matmul0_output, {1, 2, 896, 1408});
+  auto concat = BuildConcatenationOp(
+      tensor_pool, {matmul0_output, matmul1_output}, {concat_output}, 3);
+  std::move(concat.begin(), concat.end(), std::back_inserter(op_wrappers));
+
+  // Reshape1
+  auto& reshape1_output =
+      tensor_pool.CloneNativeTensorFrom(concat_output, {2, 7, 128, 1408});
+  auto reshape1 =
+      BuildReshapeOp(tensor_pool, {concat_output}, {reshape1_output});
+  std::move(reshape1.begin(), reshape1.end(), std::back_inserter(op_wrappers));
+
+  // Add1
+  auto& mask = tensor_pool.CreateNativeTensor(QNN_DATATYPE_SFIXED_POINT_16,
+                                              quant_param, {1, 1, 128, 1408});
+  auto& add1_output = tensor_pool.CloneNativeTensorFrom(reshape1_output);
+  auto add1 = BuildElementwiseAddOp(tensor_pool, {reshape1_output, mask},
+                                    {add1_output});
+  std::move(add1.begin(), add1.end(), std::back_inserter(op_wrappers));
+
+  // Reshape2
+  auto& reshape2_output =
+      tensor_pool.CloneNativeTensorFrom(add1_output, {1, 2, 896, 1408});
+  auto reshape2 = BuildReshapeOp(tensor_pool, {add1_output}, {reshape2_output});
+  std::move(reshape2.begin(), reshape2.end(), std::back_inserter(op_wrappers));
+
+  // Softmax
+  auto& softmax_output = tensor_pool.CloneNativeTensorFrom(reshape2_output);
+  auto softmax =
+      BuildSoftmaxOp(tensor_pool, {reshape2_output}, {softmax_output}, 1.0f);
+  std::move(softmax.begin(), softmax.end(), std::back_inserter(op_wrappers));
+
+  // Slice0
+  const std::array<int32_t, 4> slice0_begin_data{0, 0, 0, 0};
+  auto& slice0_begin = tensor_pool.CreateStaticTensor(
+      QNN_DATATYPE_INT_32, {}, {slice0_begin_data.size()},
+      slice0_begin_data.size() * sizeof(slice0_begin_data[0]),
+      slice0_begin_data.data());
+  const std::array<int32_t, 4> slice0_size_data{1, 2, 896, 1280};
+  auto& slice0_size = tensor_pool.CreateStaticTensor(
+      QNN_DATATYPE_INT_32, {}, {slice0_size_data.size()},
+      slice0_size_data.size() * sizeof(slice0_size_data[0]),
+      slice0_size_data.data());
+  auto& slice0_output =
+      tensor_pool.CloneNativeTensorFrom(softmax_output, {1, 2, 896, 1280});
+  auto slice0 =
+      BuildSliceOp(tensor_pool, {softmax_output, slice0_begin, slice0_size},
+                   {slice0_output});
+  std::move(slice0.begin(), slice0.end(), std::back_inserter(op_wrappers));
+
+  // Slice1
+  const std::array<int32_t, 4> slice1_begin_data{0, 0, 0, 1280};
+  auto& slice1_begin = tensor_pool.CreateStaticTensor(
+      QNN_DATATYPE_INT_32, {}, {slice1_begin_data.size()},
+      slice1_begin_data.size() * sizeof(slice1_begin_data[0]),
+      slice1_begin_data.data());
+  const std::array<int32_t, 4> slice1_size_data{1, 2, 896, 128};
+  auto& slice1_size = tensor_pool.CreateStaticTensor(
+      QNN_DATATYPE_INT_32, {}, {slice1_size_data.size()},
+      slice1_size_data.size() * sizeof(slice1_size_data[0]),
+      slice1_size_data.data());
+  auto& slice1_output =
+      tensor_pool.CloneNativeTensorFrom(softmax_output, {1, 2, 896, 128});
+  auto slice1 =
+      BuildSliceOp(tensor_pool, {softmax_output, slice1_begin, slice1_size},
+                   {slice1_output});
+  std::move(slice1.begin(), slice1.end(), std::back_inserter(op_wrappers));
+
+  // MatMul2
+  auto& v_in = tensor_pool.CreateNativeTensor(QNN_DATATYPE_SFIXED_POINT_16,
+                                              quant_param, {1, 2, 64, 1280});
+  auto& matmul2_output = tensor_pool.CreateNativeTensor(
+      QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1, 2, 896, 64});
+  auto matmul2 = BuildMatmulOp(tensor_pool, {slice0_output, v_in},
+                               {matmul2_output}, false, true);
+  std::move(matmul2.begin(), matmul2.end(), std::back_inserter(op_wrappers));
+
+  // MatMul3
+  auto& matmul3_output = tensor_pool.CreateNativeTensor(
+      QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1, 2, 896, 64});
+  auto matmul3 = BuildMatmulOp(tensor_pool, {slice1_output, transpose0_output},
+                               {matmul3_output}, false, true);
+  std::move(matmul3.begin(), matmul3.end(), std::back_inserter(op_wrappers));
+
+  // Add2
+  auto& add2_output = tensor_pool.CloneNativeTensorFrom(matmul3_output);
+  auto add2 = BuildElementwiseAddOp(
+      tensor_pool, {matmul2_output, matmul3_output}, {add2_output});
+  std::move(add2.begin(), add2.end(), std::back_inserter(op_wrappers));
+
+  // Reshape3
+  auto& reshape3_output =
+      tensor_pool.CloneNativeTensorFrom(add2_output, {1, 14, 128, 64});
+  auto reshape3 = BuildReshapeOp(tensor_pool, {add2_output}, {reshape3_output});
+  std::move(reshape3.begin(), reshape3.end(), std::back_inserter(op_wrappers));
+
+  // Transpose1
+  std::array<int32_t, 4> transpose1_val = {0, 2, 1, 3};
+  auto& transpose1_perm = tensor_pool.CreateStaticTensor(
+      QNN_DATATYPE_INT_32, quant_param, {transpose1_val.size()},
+      transpose1_val.size() * sizeof(transpose1_val[0]), transpose1_val.data());
+  auto& transpose1_output =
+      tensor_pool.CloneNativeTensorFrom(reshape3_output, {1, 128, 14, 64});
+  auto transpose1 = BuildTransposeOp(
+      tensor_pool, {reshape3_output, transpose1_perm}, {transpose1_output});
+  std::move(transpose1.begin(), transpose1.end(),
+            std::back_inserter(op_wrappers));
+
+  // Reshape4
+  auto& reshape4_output =
+      tensor_pool.CloneNativeTensorFrom(transpose1_output, {1, 128, 896});
+  auto reshape4 =
+      BuildReshapeOp(tensor_pool, {transpose1_output}, {reshape4_output});
+  std::move(reshape4.begin(), reshape4.end(), std::back_inserter(op_wrappers));
+
+  ASSERT_EQ(op_wrappers.size(), 19);
+
+  const ::qnn::G2GConfig g2g_option = ::qnn::G2GConfig::kMHAOptPrefill;
+  GraphToGraphTransform(g2g_option, op_wrappers, tensor_pool,
+                        [](OpWrapper& op) { return true; });
+  // Check total size after G2G
+  ASSERT_EQ(op_wrappers.size(), 191);
+
+  // Check OpCode after G2G
+  const size_t num_unpack = 6;
+  const size_t num_head = 14;
+  const size_t sha_size = 13;
+
+  ASSERT_TRUE(op_wrappers[0].IsOpCode(QnnOpCode::kElementWiseAdd));
+  ASSERT_TRUE(op_wrappers[1].IsOpCode(QnnOpCode::kTranspose));
+
+  for (size_t i = 0; i < num_unpack; ++i) {
+    ASSERT_TRUE(op_wrappers[2 + i].IsOpCode(QnnOpCode::kUnPack));
+  }
+
+  for (size_t i = 0; i < num_head; ++i) {
+    ASSERT_TRUE(op_wrappers[8 + sha_size * i].IsOpCode(
+        QnnOpCode::kElementWiseMultiply));
+    ASSERT_TRUE(op_wrappers[9 + sha_size * i].IsOpCode(QnnOpCode::kMatMul));
+    ASSERT_TRUE(
+        op_wrappers[10 + sha_size * i].IsOpCode(QnnOpCode::kElementWiseAdd));
+    ASSERT_TRUE(op_wrappers[11 + sha_size * i].IsOpCode(QnnOpCode::kMatMul));
+    ASSERT_TRUE(op_wrappers[12 + sha_size * i].IsOpCode(QnnOpCode::kConcat));
+    ASSERT_TRUE(
+        op_wrappers[13 + sha_size * i].IsOpCode(QnnOpCode::kElementWiseAdd));
+    ASSERT_TRUE(op_wrappers[14 + sha_size * i].IsOpCode(QnnOpCode::kReshape));
+    ASSERT_TRUE(op_wrappers[15 + sha_size * i].IsOpCode(QnnOpCode::kSoftmax));
+    ASSERT_TRUE(
+        op_wrappers[16 + sha_size * i].IsOpCode(QnnOpCode::kStridedSlice));
+    ASSERT_TRUE(
+        op_wrappers[17 + sha_size * i].IsOpCode(QnnOpCode::kStridedSlice));
+    ASSERT_TRUE(op_wrappers[18 + sha_size * i].IsOpCode(QnnOpCode::kMatMul));
+    ASSERT_TRUE(op_wrappers[19 + sha_size * i].IsOpCode(QnnOpCode::kMatMul));
+    ASSERT_TRUE(
+        op_wrappers[20 + sha_size * i].IsOpCode(QnnOpCode::kElementWiseAdd));
+  }
+  ASSERT_TRUE(op_wrappers[op_wrappers.size() - 1].IsOpCode(QnnOpCode::kConcat));
+}
+
 }  // namespace
 }  // namespace qnn
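For reference, the expected size of 191 ops after G2G is consistent with the structure the assertions walk: the untouched Add0 and Transpose0, six Unpack ops (indices 2 through 7), thirteen single-head-attention ops for each of the 14 heads, and one trailing Concat. A hedged restatement of that arithmetic as it could appear in the test (kExpectedOps is a name introduced here for illustration, not part of the change):

    // 2 leading ops (Add0, Transpose0) + 6 Unpack + 14 heads * 13 SHA ops
    // + 1 trailing Concat = 191.
    constexpr size_t kExpectedOps = 2 + 6 + 14 * 13 + 1;
    ASSERT_EQ(op_wrappers.size(), kExpectedOps);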
