@@ -785,5 +785,301 @@ TEST(MaskTransformTest, Gemma3) {
785785 ASSERT_EQ (quant_param_2.GetZeroPoint (), mul_zero_point);
786786}
787787
788+ TEST (MHASHATest, FastVlm) {
789+ // G2G Test case: MHA -> SHA
790+
791+ // ------------------- Before ---------------------
792+ // In0
793+ // |
794+ // Mul
795+ // |
796+ // Reshape0 In1 In2
797+ // / \ \ /
798+ // QIn / \ Add0
799+ // \ / \ /
800+ // Matmul0 Matmul1
801+ // \ /
802+ // Concat
803+ // |
804+ // Reshape1
805+ // |
806+ // Mask |
807+ // \ |
808+ // Add1
809+ // |
810+ // Reshape2
811+ // |
812+ // Softmax
813+ // / \
814+ // Slice0 Slice1 In3
815+ // | | |
816+ // VIn | | Transpose0
817+ // \ | | /
818+ // Matmul2 Matmul3
819+ // \ /
820+ // Add2
821+ // |
822+ // Reshape3
823+ // |
824+ // Transpose1
825+ // |
826+ // Reshape4
827+ // |
828+ // Out
829+ //
830+ // -------------------- After ---------------------
831+ // In0
832+ // |
833+ // Mul In1 In2
834+ // / \ \ /
835+ // QIn / \ Add0
836+ // \ / \ /
837+ // Matmul0 Matmul1
838+ // \ /
839+ // Concat
840+ // |
841+ // Mask |
842+ // \ |
843+ // Add1
844+ // |
845+ // Reshape
846+ // |
847+ // Softmax
848+ // / \
849+ // Slice0 Slice1
850+ // | |
851+ // VIn | | In3
852+ // \ | | /
853+ // Matmul2 Matmul3
854+ // \ /
855+ // Add2
856+ // |
857+ // Out
858+ TensorPool tensor_pool;
859+ QuantizeParamsWrapperVariant quant_param;
860+ quant_param.emplace <ScaleOffsetQuantizeParamsWrapper>(1e-4f , 0 );
861+ std::vector<OpWrapper> op_wrappers;
862+
863+ // Add0
864+ auto & input1 = tensor_pool.CreateNativeTensor (QNN_DATATYPE_SFIXED_POINT_16,
865+ quant_param, {1 , 2 , 128 , 64 });
866+ auto & input2 = tensor_pool.CreateNativeTensor (QNN_DATATYPE_SFIXED_POINT_16,
867+ quant_param, {1 , 2 , 128 , 64 });
868+ auto & add0_output =
869+ tensor_pool.CloneNativeTensorFrom (input1, {1 , 2 , 128 , 64 });
870+ auto add0 =
871+ BuildElementwiseAddOp (tensor_pool, {input1, input2}, {add0_output});
872+ std::move (add0.begin (), add0.end (), std::back_inserter (op_wrappers));
873+
874+ // Transpose0
875+ auto & input3 = tensor_pool.CreateNativeTensor (QNN_DATATYPE_SFIXED_POINT_16,
876+ quant_param, {1 , 128 , 2 , 64 });
877+ std::array<int32_t , 4 > transpose0_val = {0 , 2 , 3 , 1 };
878+ auto & transpose0_perm = tensor_pool.CreateStaticTensor (
879+ QNN_DATATYPE_INT_32, quant_param, {transpose0_val.size ()},
880+ transpose0_val.size () * sizeof (transpose0_val[0 ]), transpose0_val.data ());
881+ auto & transpose0_output =
882+ tensor_pool.CloneNativeTensorFrom (add0_output, {1 , 2 , 64 , 128 });
883+ auto transpose0 = BuildTransposeOp (tensor_pool, {input3, transpose0_perm},
884+ {transpose0_output});
885+ std::move (transpose0.begin (), transpose0.end (),
886+ std::back_inserter (op_wrappers));
887+
888+ // Mul
889+ auto & input0 = tensor_pool.CreateNativeTensor (QNN_DATATYPE_SFIXED_POINT_16,
890+ quant_param, {1 , 14 , 128 , 64 });
891+ std::array<int16_t , 1 > mul_val = {32767 };
892+ auto & mul_const = tensor_pool.CreateStaticTensor (
893+ QNN_DATATYPE_SFIXED_POINT_16, quant_param, {mul_val.size ()},
894+ mul_val.size () * sizeof (mul_val[0 ]), mul_val.data ());
895+ auto & mul_output =
896+ tensor_pool.CloneNativeTensorFrom (input0, {1 , 14 , 128 , 64 });
897+ auto mul =
898+ BuildElementwiseMulOp (tensor_pool, {input0, mul_const}, {mul_output});
899+ std::move (mul.begin (), mul.end (), std::back_inserter (op_wrappers));
900+
901+ // Reshape0
902+ auto & reshape0_output =
903+ tensor_pool.CloneNativeTensorFrom (mul_output, {1 , 2 , 896 , 64 });
904+ auto reshape0 = BuildReshapeOp (tensor_pool, {mul_output}, {reshape0_output});
905+ std::move (reshape0.begin (), reshape0.end (), std::back_inserter (op_wrappers));
906+
907+ // MatMul0
908+ auto & q_in = tensor_pool.CreateNativeTensor (QNN_DATATYPE_SFIXED_POINT_16,
909+ quant_param, {1 , 2 , 1280 , 64 });
910+ auto & matmul0_output = tensor_pool.CreateNativeTensor (
911+ QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1 , 2 , 896 , 1280 });
912+ auto matmul0 = BuildMatmulOp (tensor_pool, {reshape0_output, q_in},
913+ {matmul0_output}, false , true );
914+ std::move (matmul0.begin (), matmul0.end (), std::back_inserter (op_wrappers));
915+
916+ // MatMul1
917+ auto & matmul1_output = tensor_pool.CreateNativeTensor (
918+ QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1 , 2 , 896 , 128 });
919+ auto matmul1 = BuildMatmulOp (tensor_pool, {reshape0_output, add0_output},
920+ {matmul1_output}, false , true );
921+ std::move (matmul1.begin (), matmul1.end (), std::back_inserter (op_wrappers));
922+
923+ // Concat
924+ auto & concat_output =
925+ tensor_pool.CloneNativeTensorFrom (matmul0_output, {1 , 2 , 896 , 1408 });
926+ auto concat = BuildConcatenationOp (
927+ tensor_pool, {matmul0_output, matmul1_output}, {concat_output}, 3 );
928+ std::move (concat.begin (), concat.end (), std::back_inserter (op_wrappers));
929+
930+ // Reshape1
931+ auto & reshape1_output =
932+ tensor_pool.CloneNativeTensorFrom (concat_output, {2 , 7 , 128 , 1408 });
933+ auto reshape1 =
934+ BuildReshapeOp (tensor_pool, {concat_output}, {reshape1_output});
935+ std::move (reshape1.begin (), reshape1.end (), std::back_inserter (op_wrappers));
936+
937+ // Add1
938+ auto & mask = tensor_pool.CreateNativeTensor (QNN_DATATYPE_SFIXED_POINT_16,
939+ quant_param, {1 , 1 , 128 , 1408 });
940+ auto & add1_output = tensor_pool.CloneNativeTensorFrom (reshape1_output);
941+ auto add1 = BuildElementwiseAddOp (tensor_pool, {reshape1_output, mask},
942+ {add1_output});
943+ std::move (add1.begin (), add1.end (), std::back_inserter (op_wrappers));
944+
945+ // Reshape2
946+ auto & reshape2_output =
947+ tensor_pool.CloneNativeTensorFrom (add1_output, {1 , 2 , 896 , 1408 });
948+ auto reshape2 = BuildReshapeOp (tensor_pool, {add1_output}, {reshape2_output});
949+ std::move (reshape2.begin (), reshape2.end (), std::back_inserter (op_wrappers));
950+
951+ // Softmax
952+ auto & softmax_output = tensor_pool.CloneNativeTensorFrom (reshape2_output);
953+ auto softmax =
954+ BuildSoftmaxOp (tensor_pool, {reshape2_output}, {softmax_output}, 1 .0f );
955+ std::move (softmax.begin (), softmax.end (), std::back_inserter (op_wrappers));
956+
957+ // Slice0
958+ const std::array<int32_t , 4 > slice0_begin_data{0 , 0 , 0 , 0 };
959+ auto & slice0_begin = tensor_pool.CreateStaticTensor (
960+ QNN_DATATYPE_INT_32, {}, {slice0_begin_data.size ()},
961+ slice0_begin_data.size () * sizeof (slice0_begin_data[0 ]),
962+ slice0_begin_data.data ());
963+ const std::array<int32_t , 4 > slice0_size_data{1 , 2 , 896 , 1280 };
964+ auto & slice0_size = tensor_pool.CreateStaticTensor (
965+ QNN_DATATYPE_INT_32, {}, {slice0_size_data.size ()},
966+ slice0_size_data.size () * sizeof (slice0_size_data[0 ]),
967+ slice0_size_data.data ());
968+ auto & slice0_output =
969+ tensor_pool.CloneNativeTensorFrom (softmax_output, {1 , 2 , 896 , 1280 });
970+ auto slice0 =
971+ BuildSliceOp (tensor_pool, {softmax_output, slice0_begin, slice0_size},
972+ {slice0_output});
973+ std::move (slice0.begin (), slice0.end (), std::back_inserter (op_wrappers));
974+
975+ // Slice1
976+ const std::array<int32_t , 4 > slice1_begin_data{0 , 0 , 0 , 1280 };
977+ auto & slice1_begin = tensor_pool.CreateStaticTensor (
978+ QNN_DATATYPE_INT_32, {}, {slice1_begin_data.size ()},
979+ slice1_begin_data.size () * sizeof (slice1_begin_data[0 ]),
980+ slice1_begin_data.data ());
981+ const std::array<int32_t , 4 > slice1_size_data{1 , 2 , 896 , 128 };
982+ auto & slice1_size = tensor_pool.CreateStaticTensor (
983+ QNN_DATATYPE_INT_32, {}, {slice1_size_data.size ()},
984+ slice1_size_data.size () * sizeof (slice1_size_data[0 ]),
985+ slice1_size_data.data ());
986+ auto & slice1_output =
987+ tensor_pool.CloneNativeTensorFrom (softmax_output, {1 , 2 , 896 , 128 });
988+ auto slice1 =
989+ BuildSliceOp (tensor_pool, {softmax_output, slice1_begin, slice1_size},
990+ {slice1_output});
991+ std::move (slice1.begin (), slice1.end (), std::back_inserter (op_wrappers));
992+
993+ // MatMul2
994+ auto & v_in = tensor_pool.CreateNativeTensor (QNN_DATATYPE_SFIXED_POINT_16,
995+ quant_param, {1 , 2 , 64 , 1280 });
996+ auto & matmul2_output = tensor_pool.CreateNativeTensor (
997+ QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1 , 2 , 896 , 64 });
998+ auto matmul2 = BuildMatmulOp (tensor_pool, {slice0_output, v_in},
999+ {matmul2_output}, false , true );
1000+ std::move (matmul2.begin (), matmul2.end (), std::back_inserter (op_wrappers));
1001+
1002+ // MatMul3
1003+ auto & matmul3_output = tensor_pool.CreateNativeTensor (
1004+ QNN_DATATYPE_SFIXED_POINT_16, quant_param, {1 , 2 , 896 , 64 });
1005+ auto matmul3 = BuildMatmulOp (tensor_pool, {slice1_output, transpose0_output},
1006+ {matmul3_output}, false , true );
1007+ std::move (matmul3.begin (), matmul3.end (), std::back_inserter (op_wrappers));
1008+
1009+ // Add2
1010+ auto & add2_output = tensor_pool.CloneNativeTensorFrom (matmul3_output);
1011+ auto add2 = BuildElementwiseAddOp (
1012+ tensor_pool, {matmul2_output, matmul3_output}, {add2_output});
1013+ std::move (add2.begin (), add2.end (), std::back_inserter (op_wrappers));
1014+
1015+ // Reshape3
1016+ auto & reshape3_output =
1017+ tensor_pool.CloneNativeTensorFrom (add2_output, {1 , 14 , 128 , 64 });
1018+ auto reshape3 = BuildReshapeOp (tensor_pool, {add2_output}, {reshape3_output});
1019+ std::move (reshape3.begin (), reshape3.end (), std::back_inserter (op_wrappers));
1020+
1021+ // Transpose1
1022+ std::array<int32_t , 4 > transpose1_val = {0 , 2 , 1 , 3 };
1023+ auto & transpose1_perm = tensor_pool.CreateStaticTensor (
1024+ QNN_DATATYPE_INT_32, quant_param, {transpose1_val.size ()},
1025+ transpose1_val.size () * sizeof (transpose1_val[0 ]), transpose1_val.data ());
1026+ auto & transpose1_output =
1027+ tensor_pool.CloneNativeTensorFrom (reshape3_output, {1 , 128 , 14 , 64 });
1028+ auto transpose1 = BuildTransposeOp (
1029+ tensor_pool, {reshape3_output, transpose1_perm}, {transpose1_output});
1030+ std::move (transpose1.begin (), transpose1.end (),
1031+ std::back_inserter (op_wrappers));
1032+
1033+ // Reshape4
1034+ auto & reshape4_output =
1035+ tensor_pool.CloneNativeTensorFrom (transpose1_output, {1 , 128 , 896 });
1036+ auto reshape4 =
1037+ BuildReshapeOp (tensor_pool, {transpose1_output}, {reshape4_output});
1038+ std::move (reshape4.begin (), reshape4.end (), std::back_inserter (op_wrappers));
1039+
1040+ ASSERT_EQ (op_wrappers.size (), 19 );
1041+
1042+ const ::qnn::G2GConfig g2g_option = ::qnn::G2GConfig::kMHAOptPrefill ;
1043+ GraphToGraphTransform (g2g_option, op_wrappers, tensor_pool,
1044+ [](OpWrapper& op) { return true ; });
1045+ // Check total size after G2G
1046+ ASSERT_EQ (op_wrappers.size (), 191 );
1047+
1048+ // Check OpCode after G2G
1049+ const size_t num_unpack = 6 ;
1050+ const size_t num_head = 14 ;
1051+ const size_t sha_size = 13 ;
1052+
1053+ ASSERT_TRUE (op_wrappers[0 ].IsOpCode (QnnOpCode::kElementWiseAdd ));
1054+ ASSERT_TRUE (op_wrappers[1 ].IsOpCode (QnnOpCode::kTranspose ));
1055+
1056+ for (size_t i = 0 ; i < num_unpack; ++i) {
1057+ ASSERT_TRUE (op_wrappers[2 + i].IsOpCode (QnnOpCode::kUnPack ));
1058+ }
1059+
1060+ for (size_t i = 0 ; i < num_head; ++i) {
1061+ ASSERT_TRUE (op_wrappers[8 + sha_size * i].IsOpCode (
1062+ QnnOpCode::kElementWiseMultiply ));
1063+ ASSERT_TRUE (op_wrappers[9 + sha_size * i].IsOpCode (QnnOpCode::kMatMul ));
1064+ ASSERT_TRUE (
1065+ op_wrappers[10 + sha_size * i].IsOpCode (QnnOpCode::kElementWiseAdd ));
1066+ ASSERT_TRUE (op_wrappers[11 + sha_size * i].IsOpCode (QnnOpCode::kMatMul ));
1067+ ASSERT_TRUE (op_wrappers[12 + sha_size * i].IsOpCode (QnnOpCode::kConcat ));
1068+ ASSERT_TRUE (
1069+ op_wrappers[13 + sha_size * i].IsOpCode (QnnOpCode::kElementWiseAdd ));
1070+ ASSERT_TRUE (op_wrappers[14 + sha_size * i].IsOpCode (QnnOpCode::kReshape ));
1071+ ASSERT_TRUE (op_wrappers[15 + sha_size * i].IsOpCode (QnnOpCode::kSoftmax ));
1072+ ASSERT_TRUE (
1073+ op_wrappers[16 + sha_size * i].IsOpCode (QnnOpCode::kStridedSlice ));
1074+ ASSERT_TRUE (
1075+ op_wrappers[17 + sha_size * i].IsOpCode (QnnOpCode::kStridedSlice ));
1076+ ASSERT_TRUE (op_wrappers[18 + sha_size * i].IsOpCode (QnnOpCode::kMatMul ));
1077+ ASSERT_TRUE (op_wrappers[19 + sha_size * i].IsOpCode (QnnOpCode::kMatMul ));
1078+ ASSERT_TRUE (
1079+ op_wrappers[20 + sha_size * i].IsOpCode (QnnOpCode::kElementWiseAdd ));
1080+ }
1081+ ASSERT_TRUE (op_wrappers[op_wrappers.size () - 1 ].IsOpCode (QnnOpCode::kConcat ));
1082+ }
1083+
7881084} // namespace
7891085} // namespace qnn
0 commit comments