/*
 * Copyright (c) 2017-2021, 2024-2025 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -90,20 +90,44 @@ TEST_CASE(ProperlyRoundedRequantization, framework::DatasetMode::ALL)
9090 validate (Accessor (output), ref, zero_tolerance_s8);
9191}
9292
93+ TEST_CASE (QSymm8_per_channel_validate_scales, framework::DatasetMode::ALL)
94+ {
95+ // In this test we make sure validate does not raise an error when we pass a properly initialized vector of scales matching
96+ // the number of channels
97+ const auto input_info = TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::F32);
98+ auto output_info = TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM8_PER_CHANNEL);
99+ Tensor input = create_tensor<Tensor>(input_info);
100+ std::vector<float > scale (16 ,0 .5f );
101+ Tensor output = create_tensor<Tensor>(output_info.tensor_shape (), DataType::QSYMM8_PER_CHANNEL, 1 , QuantizationInfo (scale));
102+ ARM_COMPUTE_EXPECT (bool (NEQuantizationLayer::validate (
103+ & input.info ()->clone ()->set_is_resizable (false ),
104+ & output.info ()->clone ()->set_is_resizable (false ))) == true , framework::LogLevel::ERRORS);
105+ }
106+
93107// *INDENT-OFF*
94108// clang-format off
95109DATA_TEST_CASE (Validate, framework::DatasetMode::ALL, zip(zip(
96110 framework::dataset::make (" InputInfo" , { TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QASYMM8), // Wrong output data type
97111 TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::F32), // Wrong output data type
98112 TensorInfo (TensorShape (16U , 16U , 2U , 5U ), 1 , DataType::F32), // Missmatching shapes
99113 TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::F32), // Valid
114+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QASYMM8), // PER_CHANNEL only supported for F32
115+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM8), // PER_CHANNEL only supported for F32
116+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM16), // PER_CHANNEL only supported for F32
117+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::F16), // PER_CHANNEL only supported for F32
118+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::F32), // Quantization info's scales not initialized
100119 }),
101120 framework::dataset::make(" OutputInfo" ,{ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::F32),
102121 TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::U16),
103122 TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QASYMM8),
104123 TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QASYMM8),
124+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM8_PER_CHANNEL),
125+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM8_PER_CHANNEL),
126+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM8_PER_CHANNEL),
127+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM8_PER_CHANNEL),
128+ TensorInfo (TensorShape (16U , 16U , 16U , 5U ), 1 , DataType::QSYMM8_PER_CHANNEL),
105129 })),
106- framework::dataset::make(" Expected" , { false , false , false , true })),
130+ framework::dataset::make(" Expected" , { false , false , false , true , false , false , false , false , false })),
107131 input_info, output_info, expected)
108132{
109133 ARM_COMPUTE_EXPECT (bool (NEQuantizationLayer::validate (&input_info.clone ()->set_is_resizable (false ), &output_info.clone ()->set_is_resizable (false ))) == expected, framework::LogLevel::ERRORS);
@@ -117,6 +141,8 @@ template <typename T>
117141using NEQuantizationLayerQASYMM8SignedFixture = QuantizationValidationFixture<Tensor, Accessor, NEQuantizationLayer, T, int8_t >;
118142template <typename T>
119143using NEQuantizationLayerQASYMM16Fixture = QuantizationValidationFixture<Tensor, Accessor, NEQuantizationLayer, T, uint16_t >;
144+ template <typename T>
145+ using NEQuantizationLayerQSYMM8_PER_CHANNEL_Fixture = QuantizationValidationFixture<Tensor, Accessor, NEQuantizationLayer, T, int8_t >;
120146
121147TEST_SUITE (Float)
122148TEST_SUITE(FP32)
@@ -160,6 +186,17 @@ FIXTURE_DATA_TEST_CASE(RunLargeQASYMM16, NEQuantizationLayerQASYMM16Fixture<floa
160186 // Validate output
161187 validate (Accessor (_target), _reference, tolerance_u16);
162188}
189+
190+
191+ FIXTURE_DATA_TEST_CASE (RunSmallQSYMM8_PER_CHANNEL, NEQuantizationLayerQSYMM8_PER_CHANNEL_Fixture<float >, framework::DatasetMode::PRECOMMIT, combine(combine(combine(QuantizationSmallShapes,
192+ framework::dataset::make (" DataType" , DataType::F32)),
193+ framework::dataset::make(" DataTypeOut" , { DataType::QSYMM8_PER_CHANNEL })),
194+ framework::dataset::make(" QuantizationInfoIgnored" , { QuantizationInfo () })))
195+ {
196+ // Validate output
197+ validate (Accessor (_target), _reference, tolerance_s8);
198+ }
199+
163200TEST_SUITE_END () // FP32
164201#ifdef ARM_COMPUTE_ENABLE_FP16
165202TEST_SUITE (FP16)
0 commit comments