@@ -1042,10 +1042,11 @@ TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) {
 // cast_input -> Cast -> Q -> DQ ----
 //                                  |
 // input2 -> Q -> DQ -> Add -> Q -> DQ -> output
-static GetTestModelFn BuildCastAddTestCase() {
-  return [](ModelTestBuilder& builder) {
+template <typename InputType, typename QuantType>
+static GetTestQDQModelFn<QuantType> BuildCastAddQDQTestCase() {
+  return [](ModelTestBuilder& builder, std::vector<QuantParams<QuantType>>& output_qparams) {
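+    // InputType: element type of the Cast input; QuantType: integer type used for the Q/DQ pairs.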
     // Create Cast node int32 -> float32
-    NodeArg* cast_input = MakeTestInput(builder, TestInputDef<int32_t>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+    NodeArg* cast_input = MakeTestInput(builder, TestInputDef<InputType>({2, 3}, false, {0, 1, 0, 1, 0, 1}));

     auto* cast_output = builder.MakeIntermediate();
     Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output});
@@ -1054,18 +1055,36 @@ static GetTestModelFn BuildCastAddTestCase() {
     // Create Add node
     std::vector<float> data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f};
     gsl::span<float> data_range = gsl::make_span(data);
-    QuantParams<uint8_t> q_parameter = GetDataQuantParams<uint8_t>(data_range);
-    auto* add_input1_qdq = AddQDQNodePair<uint8_t>(builder, cast_output, q_parameter.scale, q_parameter.zero_point);
+    QuantParams<QuantType> q_parameter = GetDataQuantParams<QuantType>(data_range);
+    auto* add_input1_qdq = AddQDQNodePair<QuantType>(builder, cast_output, q_parameter.scale, q_parameter.zero_point);

     NodeArg* add_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, data));
-    auto* add_input2_qdq = AddQDQNodePair<uint8_t>(builder, add_input2, q_parameter.scale, q_parameter.zero_point);
+    auto* add_input2_qdq = AddQDQNodePair<QuantType>(builder, add_input2, q_parameter.scale, q_parameter.zero_point);

     auto* add_output = builder.MakeIntermediate();

     builder.AddNode("Add", {add_input1_qdq, add_input2_qdq}, {add_output});

     // add_output -> Q -> DQ -> output
-    AddQDQNodePairWithOutputAsGraphOutput<uint8_t>(builder, add_output, q_parameter.scale, q_parameter.zero_point);
+    AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, add_output, output_qparams[0].scale, output_qparams[0].zero_point);
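+    // Note: output_qparams[0] is supplied by the caller (e.g., the QDQ accuracy-test harness), so the
+    // graph output is quantized with the expected output range rather than the input-derived q_parameter.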
+  };
+}
+
+template <typename InputType>
+static GetTestModelFn BuildCastAddTestCase() {
+  return [](ModelTestBuilder& builder) {
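+    // Float32 (non-QDQ) model; presumably used by TestQDQModelAccuracy as the CPU baseline
+    // that the QDQ model's output is compared against.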
+    // Create Cast node int32 -> float32
+    NodeArg* cast_input = MakeTestInput(builder, TestInputDef<InputType>({2, 3}, false, {0, 1, 0, 1, 0, 1}));
+
+    auto* cast_output = builder.MakeIntermediate();
+    Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output});
+    cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT));
+
+    // Create Add node
+    NodeArg* add_input2 = MakeTestInput(builder, TestInputDef<float>({2, 3}, false, {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f}));
+    auto* add_output = builder.MakeOutput();
+
+    builder.AddNode("Add", {cast_output, add_input2}, {add_output});
   };
 }

@@ -1091,19 +1110,53 @@ TEST_F(QnnHTPBackendTests, ProfilingTest) {
                   0.008f);
 }

-TEST_F(QnnHTPBackendTests, CastAddHTPAccuracyTest) {
+TEST_F(QnnHTPBackendTests, CastAddQDQU8) {
   ProviderOptions provider_options;
-#if defined(_WIN32)
-  provider_options["backend_path"] = "QnnHtp.dll";
-#else
-  provider_options["backend_path"] = "libQnnHtp.so";
-#endif
+  provider_options["backend_type"] = "htp";
   provider_options["offload_graph_io_quantization"] = "0";

-  RunQnnModelTest(BuildCastAddTestCase(),
-                  provider_options,
-                  13,  // opset
-                  ExpectedEPNodeAssignment::All);
+  TestQDQModelAccuracy<uint8_t>(BuildCastAddTestCase<uint8_t>(),
+                                BuildCastAddQDQTestCase<uint8_t, uint8_t>(),
+                                provider_options,
+                                21,
+                                ExpectedEPNodeAssignment::All);
+}
+
+TEST_F(QnnHTPBackendTests, CastAddQDQU16) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  TestQDQModelAccuracy<uint16_t>(BuildCastAddTestCase<uint8_t>(),
+                                 BuildCastAddQDQTestCase<uint8_t, uint16_t>(),
+                                 provider_options,
+                                 21,
+                                 ExpectedEPNodeAssignment::All);
+}
+
+TEST_F(QnnHTPBackendTests, CastAddQDQS8) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  TestQDQModelAccuracy<int8_t>(BuildCastAddTestCase<uint8_t>(),
+                               BuildCastAddQDQTestCase<uint8_t, int8_t>(),
+                               provider_options,
+                               21,
+                               ExpectedEPNodeAssignment::All);
+}
+
+TEST_F(QnnHTPBackendTests, CastAddQDQS16) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  TestQDQModelAccuracy<int16_t>(BuildCastAddTestCase<uint8_t>(),
+                                BuildCastAddQDQTestCase<uint8_t, int16_t>(),
+                                provider_options,
+                                21,
+                                // QNN does not yet support S16 Quantize/Dequantize
+                                ExpectedEPNodeAssignment::Some);
 }

 // Test float32 model with FP16 precision