
Commit 4121e3e

Author: morelos (committed)
Update on "[ET-VK] double, short, and uint16 dtype runtime support"
Creating support for the double, short, and uint16 dtypes in the quantization ops. Registering the short keyword, since support for it already exists. Also changing the CPU implementation to support half.

Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/)

[ghstack-poisoned]
2 parents 7a5b348 + f5229f9 commit 4121e3e
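
For context, these kernels perform the affine dequantization out = (q - zero_point) * scale, and with this change the result can be cast to float, double, or Half on output. Below is a minimal standalone sketch of that computation (editor's illustration only; dequantize_element is a made-up name and not a function in the ExecuTorch sources):

#include <cstdint>
#include <cstdio>

// Illustrative only: mirrors the shape of the kernels' inner loop, with the
// output element type as a template parameter so float, double, and Half all
// share one path.
template <typename CTYPE_IN, typename CTYPE_OUT>
CTYPE_OUT dequantize_element(CTYPE_IN q, int64_t zero_point, double scale) {
  // out = (q - zero_point) * scale, cast to the requested output dtype.
  return static_cast<CTYPE_OUT>(
      (q - static_cast<int32_t>(zero_point)) * static_cast<float>(scale));
}

int main() {
  // Same values as the new test: Byte input 100, scale 0.5, zero point 30.
  std::printf("%.1f\n", dequantize_element<uint8_t, float>(100, 30, 0.5));
  // Prints 35.0, i.e. (100 - 30) * 0.5.
  return 0;
}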

3 files changed: +119 -22 lines

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 24 additions & 22 deletions
@@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out(
           static_cast<float>(scale)); \
     } \
   } break;
-#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
-  case ScalarType::in_dtype: \
-    switch (out.scalar_type()) { \
-      ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
-      default: \
-        ET_CHECK_MSG( \
-            false, \
-            "Unhandled output dtype %" PRId8, \
-            static_cast<int8_t>(out.scalar_type())); \
-    } \
+#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
+  case ScalarType::in_dtype: \
+    switch (out.scalar_type()) { \
+      ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
+      default: \
+        ET_CHECK_MSG( \
+            false, \
+            "Unhandled output dtype %" PRId8, \
+            static_cast<int8_t>(out.scalar_type())); \
+    } \
     break;
 
   switch (input.scalar_type()) {
@@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out(
             } \
             out_data_ptr[current_ix] = \
                 static_cast<CTYPE_OUT>( \
-                    input_data_ptr[current_ix] - zero_point) * \
+                    input_data_ptr[current_ix] - \
+                    static_cast<int32_t>(zero_point)) * \
                 _scale; \
           } \
         }, \
@@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out(
       apply_over_dim_list( \
           [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \
             out_data_ptr[in_ix] = static_cast<CTYPE_OUT>( \
-                (input_data_ptr[in_ix] - _zero_point) * _scale); \
+                (input_data_ptr[in_ix] - static_cast<int32_t>(_zero_point)) * \
+                _scale); \
           }, \
           input, \
           optional_dim_list, \
           channel_ix); \
     } \
     break;
-#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
-  case ScalarType::in_dtype: \
-    switch (out.scalar_type()) { \
-      ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
-      default: \
-        ET_CHECK_MSG( \
-            false, \
-            "Unhandled output dtype %" PRId8, \
-            static_cast<int8_t>(out.scalar_type())); \
-    } \
+#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
+  case ScalarType::in_dtype: \
+    switch (out.scalar_type()) { \
+      ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
+      default: \
+        ET_CHECK_MSG( \
+            false, \
+            "Unhandled output dtype %" PRId8, \
+            static_cast<int8_t>(out.scalar_type())); \
+    } \
     break;
 
   switch (input.scalar_type()) {
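
Two things change in this file: the output-dtype switch now iterates ET_FORALL_FLOATH_TYPES_WITH (adding a Half case alongside Float and Double), and the per-channel paths subtract static_cast<int32_t>(zero_point) instead of the raw int64_t value. For zero points that fit in 32 bits the narrowing is value-preserving; a quick standalone check (editor's sketch, not part of the diff) using the values from the new tests:

#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t q = 10;
  const int64_t zero_point = 100000;  // fits comfortably in int32_t
  const double scale = 0.5;

  const double wide = (q - zero_point) * scale;                          // 64-bit subtraction
  const double narrow = (q - static_cast<int32_t>(zero_point)) * scale;  // narrowed, as in the diff

  std::printf("wide=%.1f narrow=%.1f\n", wide, narrow);  // both print -49995.0
  return 0;
}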

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 90 additions & 0 deletions
@@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) {
   test_dtype<ScalarType::Int>();
 }
 
+/// Test all supported output dtypes for dequantization
+template <ScalarType OUT_DTYPE>
+void test_output_dtype() {
+  TensorFactory<ScalarType::Byte> tf;
+
+  Tensor input = tf.full({3, 5}, 100);
+  double scale = 0.5;
+  int64_t zero_point = 30;
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<OUT_DTYPE> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  // (100 - 30) * 0.5 = 35
+  Tensor expected = tfo.full({3, 5}, 35);
+  dequantize_per_tensor_out(
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(OUT_DTYPE),
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpDequantizeOutTest, AllOutputDtypesSupported) {
+  et_pal_init();
+  test_output_dtype<ScalarType::Float>();
+  test_output_dtype<ScalarType::Double>();
+  test_output_dtype<ScalarType::Half>();
+}
+
+TEST(OpDequantizeOutTest, HalfOutput) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tf;
+
+  Tensor input = tf.full({3, 5}, 10);
+  double scale = 0.5;
+  int64_t zero_point = 100000;
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Half> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  // (10 - 100000) * 0.5 = -49995
+  dequantize_per_tensor_out(
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(ScalarType::Half),
+      out);
+
+  // The expected result should be (10 - 100000) * 0.5 = -49995
+  Tensor expected = tfo.full({3, 5}, -49995);
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpDequantizeOutTest, DoubleOutput) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tf;
+
+  Tensor input = tf.full({3, 5}, 10);
+  double scale = 0.5;
+  int64_t zero_point = 100000;
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Double> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  dequantize_per_tensor_out(
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(ScalarType::Double),
+      out);
+
+  // The expected result should be (10 - 100000) * 0.5 = -49995
+  Tensor expected = tfo.full({3, 5}, -49995);
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
 TEST(OpDequantizeOutTest, NonWholeNumbers) {
   et_pal_init();
   TensorFactory<ScalarType::Byte> tf;
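
One subtlety in the HalfOutput test: -49995 is not exactly representable in fp16 (values at this magnitude are spaced 32 apart), but the comparison still holds because both the kernel output and tfo.full({3, 5}, -49995) round through the same Half type. A rough standalone illustration of that rounding (editor's sketch; it only emulates fp16 spacing and does not use the real Half type):

#include <cmath>
#include <cstdio>

int main() {
  const double exact = (10 - 100000) * 0.5;  // -49995
  // fp16 has a 10-bit mantissa, so values in [32768, 65536) are 32 apart.
  const double ulp = 32.0;
  const double fp16_rounded = std::round(exact / ulp) * ulp;
  std::printf("exact=%.1f  nearest fp16=%.1f\n", exact, fp16_rounded);
  // Prints exact=-49995.0  nearest fp16=-49984.0; both sides of the test
  // round to this same value, so EXPECT_TENSOR_EQ still passes.
  return 0;
}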

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 5 additions & 0 deletions
@@ -199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
   _(ANOTHER_INPUT, float, Float) \
   _(ANOTHER_INPUT, double, Double)
 
+#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \
+  _(ANOTHER_INPUT, float, Float) \
+  _(ANOTHER_INPUT, double, Double) \
+  _(ANOTHER_INPUT, ::executorch::aten::Half, Half)
+
 #define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
   _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \
   _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)
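
ET_FORALL_FLOATH_TYPES_WITH follows the same X-macro pattern as the existing ET_FORALL_FLOAT_TYPES_WITH: the caller passes a macro that is invoked once per (ctype, ScalarType) pair, with ANOTHER_INPUT forwarded as the first argument. A small standalone sketch of how it expands (editor's illustration; PRINT_ROW is a made-up callback, and the macro is re-declared locally so the snippet compiles on its own):

#include <cstdio>

// Local re-declaration of the new macro from scalar_type_util.h, for the sketch.
#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \
  _(ANOTHER_INPUT, float, Float) \
  _(ANOTHER_INPUT, double, Double) \
  _(ANOTHER_INPUT, ::executorch::aten::Half, Half)

// Made-up callback: one line per (input ctype, output ctype, output dtype) row.
// It only stringizes its arguments, so the named types need not be defined here.
#define PRINT_ROW(IN_CTYPE, OUT_CTYPE, out_dtype) \
  std::printf("%s -> %s (ScalarType::%s)\n", #IN_CTYPE, #OUT_CTYPE, #out_dtype);

int main() {
  // Expands to three printf statements: uint8_t -> float, double, and Half.
  ET_FORALL_FLOATH_TYPES_WITH(uint8_t, PRINT_ROW)
  return 0;
}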
