Skip to content
4 changes: 2 additions & 2 deletions kernels/quantized/cpu/op_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out(
break;

switch (input.scalar_type()) {
ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
default:
ET_CHECK_MSG(
false,
Expand Down Expand Up @@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out(
break;

switch (input.scalar_type()) {
ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
default:
ET_CHECK_MSG(
false,
Expand Down
65 changes: 65 additions & 0 deletions kernels/quantized/test/op_quantize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,32 @@ void test_dtype() {
EXPECT_TENSOR_EQ(out, expected);
}

template <ScalarType INPUT_DTYPE>
void test_input_dtype() {
TensorFactory<INPUT_DTYPE> tf_input;

Tensor input = tf_input.full({3, 5}, 4);
double scale = 0.5;
int64_t zero_point = 108;
int64_t quant_min = 0;
int64_t quant_max = 127;

TensorFactory<ScalarType::Char> tfo;
Tensor out = tfo.zeros({3, 5});
// 4 / 0.5 + 108 = 116
Tensor expected = tfo.full({3, 5}, 116);
quantize_per_tensor_out(
input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);

EXPECT_TENSOR_EQ(out, expected);
}

// Runs the shared per-tensor quantize check once for each input dtype
// exercised by this suite: Float, Half and Double.
TEST(OpQuantizeOutTest, AllInputDtypesSupported) {
  test_input_dtype<ScalarType::Float>();
  test_input_dtype<ScalarType::Half>();
  test_input_dtype<ScalarType::Double>();
}

TEST(OpQuantizeOutTest, AllDtypesSupported) {
test_dtype<ScalarType::Byte>();
test_dtype<ScalarType::Char>();
Expand All @@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
test_dtype<ScalarType::Int>();
}

// Double-precision input quantized to Byte, using a non-trivial value, a small
// scale and a negative zero point.
TEST(OpQuantizeOutTest, DoubleInputTest) {
  const double scale = 0.01;
  const int64_t zero_point = -100;
  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  TensorFactory<ScalarType::Double> input_factory;
  TensorFactory<ScalarType::Byte> output_factory;

  // Test with a more complex value that might have precision differences
  Tensor input = input_factory.full({2, 3}, 3.14159265359);
  Tensor out = output_factory.zeros({2, 3});

  // 3.14159265359 / 0.01 - 100 = 214.159265359, expected to quantize to 214.
  Tensor expected = output_factory.full({2, 3}, 214);

  quantize_per_tensor_out(
      input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);

  EXPECT_TENSOR_EQ(out, expected);
}

// Half-precision input quantized to Char over the range [-128, 127].
TEST(OpQuantizeOutTest, HalfInputTest) {
  const double scale = 0.5;
  const int64_t zero_point = 10;
  const int64_t quant_min = -128;
  const int64_t quant_max = 127;

  TensorFactory<ScalarType::Half> input_factory;
  TensorFactory<ScalarType::Char> output_factory;

  Tensor input = input_factory.full({2, 3}, 2.5);
  Tensor out = output_factory.zeros({2, 3});

  // 2.5 / 0.5 + 10 = 15
  Tensor expected = output_factory.full({2, 3}, 15);

  quantize_per_tensor_out(
      input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);

  EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, TensorArgOverload) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
Expand Down
Loading