Skip to content
71 changes: 27 additions & 44 deletions kernels/quantized/cpu/op_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,55 +282,38 @@ Tensor& quantize_per_channel_out(

check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);

// a list contains all dimensions except axis
int64_t dims[kTensorDimensionLimit];
for (int64_t i = 0; i < input.dim() - 1; i++) {
if (i < axis) {
dims[i] = i;
} else {
dims[i] = i - 1;
}
}
const double* scale_data = scale.const_data_ptr<double>();
const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();

std::optional<executorch::aten::ArrayRef<int64_t>> optional_dim_list{
executorch::aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};

// Actual quantization logic
// input, out are the input and output tensors
// channel_ix is the index along the axis dimension. 0 <= channel_ix <
// input.size(axis).
// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
// will be 0, 1, 2, ... C-1
// in_ix is the flat index of the element you are quantizing.
// in other words you are quantizing in_data[in_ix]
// High-performance single loop with direct channel calculation
#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
case ScalarType::out_dtype: \
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
double _scale = scale_data[channel_ix]; \
int64_t _zero_point = zero_point_data[channel_ix]; \
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
apply_over_dim_list( \
[input_data_ptr, \
out_data_ptr, \
_scale, \
_zero_point, \
quant_min, \
quant_max](size_t in_ix) { \
out_data_ptr[in_ix] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
_scale, \
_zero_point, \
input_data_ptr[in_ix], \
quant_min, \
quant_max); \
}, \
input, \
optional_dim_list, \
channel_ix); \
case ScalarType::out_dtype: { \
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
const int64_t input_numel = input.numel(); \
const int64_t axis_size = input.size(axis); \
\
/* Calculate the stride pattern for efficient channel index calculation */ \
int64_t axis_block_size = 1; \
for (int64_t i = axis + 1; i < input.dim(); i++) { \
axis_block_size *= input.size(i); \
} \
break;
\
/* Single loop over all elements */ \
for (int64_t i = 0; i < input_numel; i++) { \
/* Calculate which channel this element belongs to */ \
int64_t channel_idx = (i / axis_block_size) % axis_size; \
\
/* Get quantization parameters for this channel */ \
double _scale = scale_data[channel_idx]; \
int64_t _zero_point = zero_point_data[channel_idx]; \
\
/* Apply quantization */ \
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
_scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
} \
} break;

#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
case ScalarType::in_dtype: \
switch (out.scalar_type()) { \
Expand Down
240 changes: 240 additions & 0 deletions kernels/quantized/test/op_quantize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, QuantizePerChannel) {

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Byte> byte_factory;

  // Quantize a 3x2 tensor along axis 0, so each row carries its own
  // scale / zero-point pair.
  Tensor input = float_factory.full({3, 2}, 4);
  Tensor scale = double_factory.make({3}, {0.5, 1.0, 2.0});
  Tensor zero_point = long_factory.make({3}, {100, 50, 25});
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor out = byte_factory.zeros({3, 2});
  // Row 0: 4 / 0.5 + 100 = 108
  // Row 1: 4 / 1.0 + 50 = 54
  // Row 2: 4 / 2.0 + 25 = 27
  Tensor expected = byte_factory.make({3, 2}, {108, 108, 54, 54, 27, 27});

  quantize_per_channel_out(
      input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannel3D) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Char> out_factory;

  // Quantize along the middle axis of a 2x3x4 tensor: every run of 4
  // innermost elements shares the parameters of its channel, and the
  // per-channel pattern repeats identically for both outer batches.
  Tensor input = float_factory.full({2, 3, 4}, 6);
  Tensor scale = double_factory.make({3}, {0.5, 1.0, 1.5});
  Tensor zero_point = long_factory.make({3}, {10, 20, 30});
  const int64_t quant_min = -128;
  const int64_t quant_max = 127;

  Tensor out = out_factory.zeros({2, 3, 4});
  // Channel 0: 6 / 0.5 + 10 = 22
  // Channel 1: 6 / 1.0 + 20 = 26
  // Channel 2: 6 / 1.5 + 30 = 34
  Tensor expected = out_factory.make(
      {2, 3, 4},
      {22, 22, 22, 22, 26, 26, 26, 26, 34, 34, 34, 34,
       22, 22, 22, 22, 26, 26, 26, 26, 34, 34, 34, 34});

  quantize_per_channel_out(
      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannel4D) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  // 4D tensor quantized along axis=2 — the H dimension of an N,C,H,W
  // layout, NOT the conv channel axis. This exercises per-"channel"
  // quantization on an inner axis, where elements sharing quantization
  // parameters are small non-contiguous runs (of size W) rather than one
  // contiguous block.
  Tensor input = tf_float.full({2, 2, 3, 2}, 8);
  Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0});
  Tensor zero_point = tf_long.make({3}, {0, 10, 20});
  int64_t quant_min = -128;
  int64_t quant_max = 127;

  TensorFactory<ScalarType::Char> tfo;
  Tensor out = tfo.zeros({2, 2, 3, 2});
  // Per-h quantized values:
  //   h == 0: 8 / 0.25 + 0 = 32
  //   h == 1: 8 / 0.5 + 10 = 26
  //   h == 2: 8 / 1.0 + 20 = 28
  // Each value repeats W (= 2) times, and the whole H-pattern repeats for
  // every (n, c) pair.
  constexpr int8_t kExpectedPerH[3] = {32, 26, 28};
  std::vector<int8_t> expected_data;
  expected_data.reserve(2 * 2 * 3 * 2);
  for (int nc = 0; nc < 2 * 2; nc++) {
    for (int h = 0; h < 3; h++) {
      for (int w = 0; w < 2; w++) {
        expected_data.push_back(kExpectedPerH[h]);
      }
    }
  }
  Tensor expected = tfo.make({2, 2, 3, 2}, expected_data);
  quantize_per_channel_out(
      input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Byte> byte_factory;

  // A negative axis of -1 must resolve to the last dimension, i.e. axis 1
  // for this 2-D input.
  Tensor input = float_factory.full({2, 3}, 5);
  Tensor scale = double_factory.make({3}, {0.5, 1.0, 2.0});
  Tensor zero_point = long_factory.make({3}, {0, 10, 20});
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor out = byte_factory.zeros({2, 3});
  // Channel 0: 5 / 0.5 + 0 = 10
  // Channel 1: 5 / 1.0 + 10 = 15
  // Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5)
  Tensor expected = byte_factory.make({2, 3}, {10, 15, 22, 10, 15, 22});

  quantize_per_channel_out(
      input,
      scale,
      zero_point,
      -1,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Byte> byte_factory;

  // Degenerate case: the quantization axis has size 1, so one
  // scale / zero-point pair applies to every element.
  Tensor input = float_factory.full({3, 1, 4}, 7);
  Tensor scale = double_factory.make({1}, {0.5});
  Tensor zero_point = long_factory.make({1}, {128});
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor out = byte_factory.zeros({3, 1, 4});
  // 7 / 0.5 + 128 = 142 everywhere.
  Tensor expected = byte_factory.full({3, 1, 4}, 142);

  quantize_per_channel_out(
      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) {
  TensorFactory<ScalarType::Double> input_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Char> out_factory;

  // Double-precision input tensor; both channels map far above quant_max
  // and must saturate at the upper bound.
  Tensor input = input_factory.full({2, 2}, 3.14159);
  Tensor scale = double_factory.make({2}, {0.01, 0.02});
  Tensor zero_point = long_factory.make({2}, {0, 100});
  const int64_t quant_min = -128;
  const int64_t quant_max = 127;

  Tensor out = out_factory.zeros({2, 2});
  // Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127
  // Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127
  Tensor expected = out_factory.make({2, 2}, {127, 127, 127, 127});

  quantize_per_channel_out(
      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Short> out_factory;

  // 16-bit signed output: the expected values (1010, 2005) do not fit an
  // 8-bit type, so this checks that wider output dtypes are honored.
  Tensor input = float_factory.full({2, 2}, 10);
  Tensor scale = double_factory.make({2}, {1.0, 2.0});
  Tensor zero_point = long_factory.make({2}, {1000, 2000});
  const int64_t quant_min = -32768;
  const int64_t quant_max = 32767;

  Tensor out = out_factory.zeros({2, 2});
  // Channel 0: 10 / 1.0 + 1000 = 1010
  // Channel 1: 10 / 2.0 + 2000 = 2005
  Tensor expected = out_factory.make({2, 2}, {1010, 2005, 1010, 2005});

  quantize_per_channel_out(
      input,
      scale,
      zero_point,
      1,
      quant_min,
      quant_max,
      ScalarType::Short,
      out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Byte> byte_factory;

  // Every element holds a distinct value, so the result depends on both
  // the element and its channel's parameters.
  Tensor input = float_factory.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
  Tensor scale = double_factory.make({3}, {0.5, 1.0, 1.5});
  Tensor zero_point = long_factory.make({3}, {10, 20, 30});
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor out = byte_factory.zeros({2, 3});
  // Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32]
  // Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34]
  Tensor expected = byte_factory.make({2, 3}, {12, 22, 32, 18, 25, 34});

  quantize_per_channel_out(
      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;
  TensorFactory<ScalarType::Char> out_factory;

  // With unit scale and zero offset, raw values land directly on the
  // quantized axis; inputs past either bound must clamp to it.
  Tensor input = float_factory.make({1, 3}, {-100.0, 0.0, 100.0});
  Tensor scale = double_factory.make({3}, {1.0, 1.0, 1.0});
  Tensor zero_point = long_factory.make({3}, {0, 0, 0});
  const int64_t quant_min = -10;
  const int64_t quant_max = 10;

  Tensor out = out_factory.zeros({1, 3});
  // [-100, 0, 100] clamps to [-10, 0, 10].
  Tensor expected = out_factory.make({1, 3}, {-10, 0, 10});

  quantize_per_channel_out(
      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);

  EXPECT_TENSOR_EQ(out, expected);
}
Loading