Commit 6bc1c7b (1 parent: 520608e)

fix: Mixed-type quantized matmul for after-configure quantization update

When the quantization information must be updated after configure(), mixed-data-type support returns incorrect results on some IPs. This is because, when arm_gemm has no native mixed-type quantized kernel, the operator flips the signedness of the LHS internally; this was not accounted for correctly in the after-configure update.

Resolves: COMPMID-7833
Change-Id: I88bb88921c569a9ef3b4aa49d09015f9e10247e8
Signed-off-by: Gunes Bayir <[email protected]>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/14670
Benchmark: Arm Jenkins <[email protected]>
Reviewed-by: Dennis Wildmark <[email protected]>
Comments-Addressed: Arm Jenkins <[email protected]>
Tested-by: Arm Jenkins <[email protected]>
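In short: because the operator converts an unsigned LHS to a signed tensor at configure time when no native mixed-type kernel exists, any LHS quantization info supplied later must describe that converted tensor. A minimal sketch of the correction, mirroring the logic the first diff below adds inline; the helper function is hypothetical and not part of the library:

#include "arm_compute/core/QuantizationInfo.h"

using namespace arm_compute;

// Hypothetical helper (illustration only): returns the LHS quantization info
// to forward when the operator has internally flipped the LHS signedness
// (QASYMM8 -> QASYMM8_SIGNED). Only uniform quantization is handled, matching
// the ARM_COMPUTE_ERROR_ON guard in the patch.
QuantizationInfo corrected_lhs_qinfo(const QuantizationInfo &a, bool flip_signedness)
{
    if (!flip_signedness)
        return a; // no conversion happened; pass the info through unchanged

    // The scale is unaffected by the signedness flip; only the zero point
    // moves, by the same 128 offset correction applied at configure time.
    const UniformQuantizationInfo a_uniform = a.uniform();
    return QuantizationInfo(a_uniform.scale, a_uniform.offset + 128);
}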

4 files changed: +52 −15 lines changed

src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp

Lines changed: 15 additions & 1 deletion
@@ -810,7 +810,21 @@ void CpuGemmLowpMatrixMultiplyCore::update_quantization_parameters(const GEMMLow
 {
     auto lowp_os = output_info;
     _gemm_info.set_gemmlowp_output_stage(lowp_os);
-    _asm_glue->update_quantization_parameters(output_info, a, b, is_prepared, negated_offsets);
+
+    const QuantizationInfo *a_to_use = &a;
+    QuantizationInfo        a_signed;
+
+    if (_flip_signedness)
+    {
+        const int32_t                 offset_correction = 128;
+        const UniformQuantizationInfo a_uniform         = a.uniform();
+
+        ARM_COMPUTE_ERROR_ON(a.scale().size() > 1);
+        a_signed = QuantizationInfo(a_uniform.scale, a_uniform.offset + offset_correction);
+        a_to_use = &a_signed;
+    }
+
+    _asm_glue->update_quantization_parameters(output_info, *a_to_use, b, is_prepared, negated_offsets);
     _is_prepared = is_prepared;
 }
 } // namespace cpu

src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp

Lines changed: 2 additions & 1 deletion
@@ -146,7 +146,8 @@ void NEGEMMLowpMatrixMultiplyCore::update_quantization_parameters()
     output_info.gemmlowp_max_bound       = max_activation;
     output_info.is_quantized_per_channel = false;
     output_info.output_data_type         = dst->info()->data_type();
-    quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
+    const Status status = quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
+    ARM_COMPUTE_ERROR_ON(!bool(status));

     _impl->op->update_quantization_parameters(output_info, src->info()->quantization_info(),
                                               wei->info()->quantization_info(), true, true);
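This runtime-side change is defensive: calculate_quantized_multipliers returns a Status that was previously discarded, so a failed multiplier computation would have propagated silently into the updated output stage. A minimal self-contained sketch of the Status idiom used above; the constructed Status values are illustrative, not from the commit:

#include <iostream>
#include "arm_compute/core/Error.h"

using namespace arm_compute;

int main()
{
    const Status ok;                                                     // default Status: ErrorCode::OK
    const Status failed(ErrorCode::RUNTIME_ERROR, "multipliers failed"); // illustrative failure

    for (const Status &s : {ok, failed})
    {
        // Status converts (explicitly) to bool, true on success; this is the
        // same check ARM_COMPUTE_ERROR_ON(!bool(status)) performs. Note the
        // macro is compiled out in builds without asserts enabled.
        if (!bool(s))
            std::cerr << "error: " << s.error_description() << std::endl;
    }
    return 0;
}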

tests/validation/NEON/GEMMLowp.cpp

Lines changed: 24 additions & 4 deletions
@@ -383,7 +383,8 @@ using NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureInt8Fix
     GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureInt8Fixture, framework::DatasetMode::ALL,
                        combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
-                               make("DataType", { DataType::QASYMM8_SIGNED }),
+                               make("DataTypeA", { DataType::QASYMM8_SIGNED }),
+                               make("DataTypeB", { DataType::QASYMM8_SIGNED }),
                                make("reshape_b_only_on_first_run", { false }),
                                make("updated_sq_info_after_config", { true }),
                                QuantizedActivationFunctionsDataset
@@ -393,7 +394,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQua
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureInt8Fixture, framework::DatasetMode::NIGHTLY,
                        combine(datasets::LargeGEMMLowpFusedOffsetOutputUint8Dataset(),
-                               make("DataType", { DataType::QASYMM8_SIGNED }),
+                               make("DataTypeA", { DataType::QASYMM8_SIGNED }),
+                               make("DataTypeB", { DataType::QASYMM8_SIGNED }),
                                make("reshape_b_only_on_first_run", { false }),
                                make("updated_sq_info_after_config", { true }),
                                QuantizedActivationFunctionsDataset
@@ -408,7 +410,8 @@ using NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureUInt8Fi
     GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureUInt8Fixture, framework::DatasetMode::ALL,
                        combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
-                               make("DataType", { DataType::QASYMM8 }),
+                               make("DataTypeA", { DataType::QASYMM8 }),
+                               make("DataTypeB", { DataType::QASYMM8 }),
                                make("reshape_b_only_on_first_run", { false }),
                                make("updated_sq_info_after_config", { true }),
                                QuantizedActivationFunctionsDataset
@@ -418,7 +421,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQua
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureUInt8Fixture, framework::DatasetMode::NIGHTLY,
                        combine(datasets::LargeGEMMLowpFusedOffsetOutputUint8Dataset(),
-                               make("DataType", { DataType::QASYMM8 }),
+                               make("DataTypeA", { DataType::QASYMM8 }),
+                               make("DataTypeB", { DataType::QASYMM8 }),
                                make("reshape_b_only_on_first_run", { false }),
                                make("updated_sq_info_after_config", { true }),
                                QuantizedActivationFunctionsDataset
@@ -427,6 +431,22 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQua
     validate(Accessor(_target), _reference, tolerance_batched, large_test_tolerance_num);
 }
 TEST_SUITE_END() // QASYMM8
+
+TEST_SUITE(MixedQuantizedType)
+using NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureInt8Fixture =
+    GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, int8_t, true>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreForUpdatedStaticQuantInfoAfterConfigureInt8Fixture, framework::DatasetMode::ALL,
+                       combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
+                               make("DataTypeA", { DataType::QASYMM8 }),
+                               make("DataTypeB", { DataType::QASYMM8_SIGNED }),
+                               make("reshape_b_only_on_first_run", { false }),
+                               make("updated_sq_info_after_config", { true }),
+                               QuantizedActivationFunctionsDataset
+                       ))
+{
+    validate(Accessor(_target), _reference, tolerance_batched);
+}
+TEST_SUITE_END() // MixedQuantizedType
 TEST_SUITE_END() // UpdateStaticQuantInfoAfterConfigure

 // Deqaunt tests involve returning FP32 from the MatrixMultiplyCore kernels and is only implemented in aarch64
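The new MixedQuantizedType suite instantiates the fixture with a uint8_t LHS and an int8_t RHS, which is exactly the path this commit fixes. For orientation, a hedged sketch of what such a mixed-type configuration looks like through the public NEON runtime API; shapes and quantization parameters are invented for illustration and are not taken from the test:

#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // LHS is QASYMM8 (unsigned), RHS is QASYMM8_SIGNED: the mixed-type case.
    // Shapes follow ACL's (width, height) convention: a is MxK, b is KxN.
    Tensor a, b, dst;
    a.allocator()->init(TensorInfo(TensorShape(32U, 16U), 1, DataType::QASYMM8,
                                   QuantizationInfo(0.5f, 10)));
    b.allocator()->init(TensorInfo(TensorShape(64U, 32U), 1, DataType::QASYMM8_SIGNED,
                                   QuantizationInfo(0.25f, -3)));
    dst.allocator()->init(TensorInfo(TensorShape(64U, 16U), 1, DataType::S32)); // raw accumulators

    NEGEMMLowpMatrixMultiplyCore gemm;
    gemm.configure(&a, &b, nullptr, &dst, GEMMInfo());

    a.allocator()->allocate();
    b.allocator()->allocate();
    dst.allocator()->allocate();

    gemm.run(); // in real code, fill a and b before running
    return 0;
}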

tests/validation/fixtures/GEMMLowpFixture.h

Lines changed: 11 additions & 9 deletions
@@ -111,7 +111,7 @@ TensorType compute_gemmlowp_target_for_updated_sq_info_after_config(const Tensor
 {
     ARM_COMPUTE_ASSERT((std::is_same<FunctionType, NEGEMMLowpMatrixMultiplyCore>::value == true));
     ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
-    ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
+    ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_b));

     // If unknown, set to sensible defaults
     if (data_type_output == DataType::UNKNOWN) {
@@ -531,25 +531,26 @@ class GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture : publ
      * 2. The data type is quantized asymmetric
      *
      */
-    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type,
+    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type_a, DataType data_type_b,
                bool reshape_b_only_on_first_run, bool updated_sq_info_after_config = false, const ActivationLayerInfo& act_info = ActivationLayerInfo())
     {
         ARM_COMPUTE_ASSERT(output_stage_type != GEMMLowpOutputStageType::NONE);
-        ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type));
+        ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
+        ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_b));

         // Randomized dynamic quantization: randomize quantization info in a way that ensures no result saturation
         // most of the time
         QuantizationInfo a_qinfo;
         QuantizationInfo b_qinfo;
         QuantizationInfo output_qinfo;
         TensorFillInfo finfo;
-        setup_quantization<TI>(data_type, shape_a, shape_b, a_qinfo, b_qinfo, output_qinfo, finfo);
+        setup_quantization<TI>(data_type_a, shape_a, shape_b, a_qinfo, b_qinfo, output_qinfo, finfo);

         GEMMLowpOutputStageInfo output_stage;
-        init_gemmlowp_output_stage_info(data_type, a_qinfo, b_qinfo, output_qinfo, act_info, output_stage_type, output_stage);
+        init_gemmlowp_output_stage_info(data_type_a, a_qinfo, b_qinfo, output_qinfo, act_info, output_stage_type, output_stage);

-        _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type, data_type, output_stage, finfo);
-        _target    = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, data_type, data_type, output_stage, reshape_b_only_on_first_run, finfo, updated_sq_info_after_config, act_info);
+        _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, output_stage, finfo);
+        _target    = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, data_type_a, data_type_b, output_stage, reshape_b_only_on_first_run, finfo, updated_sq_info_after_config, act_info);
     }

 protected:
@@ -687,7 +688,7 @@ class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public GEMM
     void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
     {
         GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b,
-            shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
+            shape_output, output_stage_type, data_type, data_type, reshape_b_only_on_first_run);
     }
 };

@@ -697,7 +698,8 @@ class GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture : public GEMMLow
 public:
     void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
     {
-        GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b, shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
+        GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>
+            ::setup(shape_a, shape_b, shape_output, output_stage_type, data_type, data_type, reshape_b_only_on_first_run);
     }
 };
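The generic fixture's setup now takes separate LHS and RHS data types, while the derived single-type fixtures simply forward one data_type twice so existing instantiations keep compiling. A hedged sketch of a call against the new signature; the fixture object and shapes are illustrative, not taken from the test suite:

// 'fixture' stands for an instantiation of
// GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture (hypothetical here).
fixture.setup(TensorShape(32U, 16U),                             // shape_a (LHS, MxK)
              TensorShape(64U, 32U),                             // shape_b (RHS, KxN)
              TensorShape(64U, 16U),                             // shape_output
              GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT, // fused output stage
              DataType::QASYMM8,                                 // data_type_a
              DataType::QASYMM8_SIGNED,                          // data_type_b
              false,                                             // reshape_b_only_on_first_run
              true);                                             // updated_sq_info_after_config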
