Commit 2825849

morelos committed
Update base for Update on "[ET-VK] double, short, and uint16 dtype runtime support"
Adds support for the double, short, and uint16 dtypes to the quantization ops. The short keyword is registered since support for it already exists. Also changes the CPU implementation to support half inputs.

Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/)

[ghstack-poisoned]
1 parent 8abd26a commit 2825849

2 files changed: +67, -2 lines


kernels/quantized/cpu/op_quantize.cpp

Lines changed: 2 additions & 2 deletions
@@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out(
       break;
 
   switch (input.scalar_type()) {
-    ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
+    ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
     default:
       ET_CHECK_MSG(
           false,
@@ -347,7 +347,7 @@ Tensor& quantize_per_channel_out(
       break;
 
   switch (input.scalar_type()) {
-    ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE);
+    ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE);
     default:
       ET_CHECK_MSG(
           false,
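
For context on the one-line change in both hunks above: the ET_FORALL_*_TYPES helpers are X-macros that expand CALCULATE_FLOAT_TYPE once per supported floating-point input dtype, and the FLOATH variant also covers Half, which is what lets the CPU kernel accept fp16 inputs. Below is a minimal, hypothetical sketch of that dispatch pattern; the *_SKETCH names are illustrative, not the real macros from the ExecuTorch headers.

// Hypothetical, simplified illustration of the X-macro dispatch pattern used
// in op_quantize.cpp; the *_SKETCH names are not the real ExecuTorch macros.
#include <cstdio>

enum class ScalarType { Float, Double, Half, Int };

// One (C type, ScalarType) entry per supported floating-point input dtype.
// Adding the Half entry is what widens the switch below to fp16 inputs.
#define FORALL_FLOATH_TYPES_SKETCH(_) \
  _(float, Float)                     \
  _(double, Double)                   \
  _(float, Half) /* stand-in C type; the real code uses a Half type */

#define CALCULATE_FLOAT_TYPE_SKETCH(CTYPE, DTYPE)           \
  case ScalarType::DTYPE:                                   \
    std::puts("quantizing a " #DTYPE " input as " #CTYPE);  \
    break;

void dispatch(ScalarType input_dtype) {
  switch (input_dtype) {
    FORALL_FLOATH_TYPES_SKETCH(CALCULATE_FLOAT_TYPE_SKETCH)
    default:
      std::puts("unsupported input dtype");
  }
}

int main() {
  dispatch(ScalarType::Half);  // handled now, instead of hitting the default
}

In the real kernel, each generated case instantiates the templated quantization loop for that C type rather than printing.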

kernels/quantized/test/op_quantize_test.cpp

Lines changed: 65 additions & 0 deletions
@@ -49,6 +49,32 @@ void test_dtype() {
   EXPECT_TENSOR_EQ(out, expected);
 }
 
+template <ScalarType INPUT_DTYPE>
+void test_input_dtype() {
+  TensorFactory<INPUT_DTYPE> tf_input;
+
+  Tensor input = tf_input.full({3, 5}, 4);
+  double scale = 0.5;
+  int64_t zero_point = 108;
+  int64_t quant_min = 0;
+  int64_t quant_max = 127;
+
+  TensorFactory<ScalarType::Char> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  // 4 / 0.5 + 108 = 116
+  Tensor expected = tfo.full({3, 5}, 116);
+  quantize_per_tensor_out(
+      input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, AllInputDtypesSupported) {
+  test_input_dtype<ScalarType::Float>();
+  test_input_dtype<ScalarType::Half>();
+  test_input_dtype<ScalarType::Double>();
+}
+
 TEST(OpQuantizeOutTest, AllDtypesSupported) {
   test_dtype<ScalarType::Byte>();
   test_dtype<ScalarType::Char>();
@@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
   test_dtype<ScalarType::Int>();
 }
 
+TEST(OpQuantizeOutTest, DoubleInputTest) {
+  TensorFactory<ScalarType::Double> tf_double;
+
+  // Test with a more complex value that might have precision differences
+  Tensor input = tf_double.full({2, 3}, 3.14159265359);
+  double scale = 0.01;
+  int64_t zero_point = -100;
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor out = tfo.zeros({2, 3});
+  // 3.14159265359 / 0.01 - 100 = 214.159265359
+  Tensor expected = tfo.full({2, 3}, 214);
+  quantize_per_tensor_out(
+      input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, HalfInputTest) {
+  TensorFactory<ScalarType::Half> tf_half;
+
+  Tensor input = tf_half.full({2, 3}, 2.5);
+  double scale = 0.5;
+  int64_t zero_point = 10;
+  int64_t quant_min = -128;
+  int64_t quant_max = 127;
+
+  TensorFactory<ScalarType::Char> tfo;
+  Tensor out = tfo.zeros({2, 3});
+  // 2.5 / 0.5 + 10 = 15
+  Tensor expected = tfo.full({2, 3}, 15);
+  quantize_per_tensor_out(
+      input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
 TEST(OpQuantizeOutTest, TensorArgOverload) {
   TensorFactory<ScalarType::Float> tf_float;
   TensorFactory<ScalarType::Double> tf_double;
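
The expected values in these tests follow the standard per-tensor affine quantization formula q = clamp(round(x / scale) + zero_point, quant_min, quant_max). As a quick sanity check of the constants above, here is a small standalone reference sketch (not the kernel itself; the exact rounding mode is assumed, though none of these inputs land near a rounding tie):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Standalone reference for the affine quantization the tests above exercise.
int64_t quantize_ref(
    double x,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max) {
  const int64_t q =
      static_cast<int64_t>(std::nearbyint(x / scale)) + zero_point;
  return std::clamp(q, quant_min, quant_max);
}

int main() {
  assert(quantize_ref(4.0, 0.5, 108, 0, 127) == 116);              // test_input_dtype
  assert(quantize_ref(3.14159265359, 0.01, -100, 0, 255) == 214);  // DoubleInputTest
  assert(quantize_ref(2.5, 0.5, 10, -128, 127) == 15);             // HalfInputTest
  return 0;
}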
