@@ -5669,6 +5669,52 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight) {
56695669 test_case (true );
56705670}
56715671
5672+ TEST (QDQTransformerTests, WeightBiasQuantization_Gemm_HandleNegativeDqAxis) {
5673+ auto test_case = [](bool use_contrib_qdq) {
5674+ auto build_test_case = [&](ModelTestBuilder& builder) {
5675+ NodeArg* input_arg =
5676+ builder.MakeInput <uint8_t >({2 , 16 }, std::numeric_limits<uint8_t >::min (), std::numeric_limits<uint8_t >::max ());
5677+ NodeArg* weight_arg = builder.MakeInitializer <uint8_t >({16 , 16 }, std::numeric_limits<uint8_t >::min (),
5678+ std::numeric_limits<uint8_t >::max ());
5679+ NodeArg* bias_arg = builder.MakeInitializer <float >({16 }, -0 .1f , 0 .1f );
5680+
5681+ NodeArg* input_dq_arg = builder.MakeIntermediate ();
5682+ NodeArg* weight_dq_arg = builder.MakeIntermediate ();
5683+ NodeArg* gemm_dq_arg = builder.MakeIntermediate ();
5684+ NodeArg* output_arg = builder.MakeOutput ();
5685+
5686+ builder.AddDequantizeLinearNode <uint8_t >(input_arg, 0 .001f , static_cast <uint8_t >(0 ), input_dq_arg, use_contrib_qdq);
5687+
5688+ // Per-channel quantized weight with negative axis as DQ attribute
5689+ std::vector<float > scales = std::vector<float >(16 , 0 .05f );
5690+ std::vector<uint8_t > zp = std::vector<uint8_t >(16 , static_cast <uint8_t >(0 ));
5691+ auto & dq_node = builder.AddDequantizeLinearNode <uint8_t >(weight_arg, scales, zp, weight_dq_arg, nullptr , use_contrib_qdq);
5692+ dq_node.AddAttribute (" axis" , static_cast <int64_t >(-1 ));
5693+
5694+ builder.AddNode (" Gemm" , {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg});
5695+ builder.AddQuantizeLinearNode <uint8_t >(gemm_dq_arg, 0 .144f , static_cast <uint8_t >(69 ), output_arg, use_contrib_qdq);
5696+ };
5697+
5698+ auto check_transformed_graph = [](InferenceSessionWrapper& session) {
5699+ auto op_to_count = CountOpsInGraph (session.GetGraph ());
5700+ EXPECT_EQ (op_to_count[" QuantizeLinear" ] + op_to_count[" com.microsoft.QuantizeLinear" ], 1 );
5701+ EXPECT_EQ (op_to_count[" DequantizeLinear" ] + op_to_count[" com.microsoft.DequantizeLinear" ], 2 + 1 );
5702+ };
5703+
5704+ TransformerTester (build_test_case,
5705+ check_transformed_graph,
5706+ TransformerLevel::Default,
5707+ TransformerLevel::Level1,
5708+ /* opset_version=*/ 20 ,
5709+ /* per_sample_tolerance=*/ 0.01 ,
5710+ /* relative_per_sample_tolerance=*/ 0.01 ,
5711+ /* transformer=*/ std::make_unique<WeightBiasQuantization>());
5712+ };
5713+
5714+ test_case (false );
5715+ test_case (true );
5716+ }
5717+
56725718TEST (QDQTransformerTests, WeightBiasQuantization_Gemm_Weight_Bias) {
56735719 auto test_case = [](bool use_contrib_qdq) {
56745720 auto build_test_case = [&](ModelTestBuilder& builder) {
0 commit comments