
Commit 68b449b

[Executorch][quant] Optimize per channel dequantize

Pull Request resolved: #5670

When using the quantized KV cache, the dequantization routine takes a significant amount of time. This diff vectorizes per-channel dequantization for the common case.

ghstack-source-id: 253887459
@exported-using-ghexport
Differential Revision: [D63338858](https://our.internmc.facebook.com/intern/diff/D63338858/)
1 parent 54899fe commit 68b449b
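
For context, per-channel dequantization applies `out = (in - zero_point[ch]) * scale[ch]` independently to each slice along the channel axis. The snippet below is a minimal scalar sketch of that computation, not ExecuTorch code; the helper name `dequantize_per_channel_ref` and the flat `[channels, inner]` layout are assumptions for illustration. The optimized kernel added in this commit computes the same values, 16 int8 elements at a time with NEON.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference: contiguous int8 data laid out as
// [channels, inner], one scale/zero_point pair per channel.
void dequantize_per_channel_ref(
    const int8_t* in,
    float* out,
    size_t channels,
    size_t inner,
    const std::vector<float>& scale,
    const std::vector<int64_t>& zero_point) {
  for (size_t ch = 0; ch < channels; ++ch) {
    for (size_t i = 0; i < inner; ++i) {
      size_t idx = ch * inner + i;
      out[idx] = static_cast<float>(in[idx] - zero_point[ch]) * scale[ch];
    }
  }
}

int main() {
  // Mirrors the test values below: input 100, scales {0.5, 0.75, 1},
  // zero points {30, 50, 60} -> 35, 37.5, 40 per channel.
  std::vector<int8_t> in(3 * 4, 100);
  std::vector<float> out(in.size());
  dequantize_per_channel_ref(
      in.data(), out.data(), 3, 4, {0.5f, 0.75f, 1.0f}, {30, 50, 60});
  return 0;
}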

File tree: 2 files changed (+238, -21 lines)

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 196 additions & 13 deletions
@@ -11,6 +11,9 @@
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
+#if defined(__aarch64__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
 
 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 using ScalarType = exec_aten::ScalarType;
+using StridesType = exec_aten::StridesType;
+using SizesType = exec_aten::SizesType;
 
 namespace {
 
@@ -62,6 +67,183 @@ void check_dequantize_per_tensor_args(
       quant_max);
 }
 
+/**
+ * Applies the function `fn` over the input tensor `in`, once for every index
+ * of the dimension `dim` within every combination of the leading (outer)
+ * dimensions. `fn` should have the following signature:
+ *   void fn(const size_t inner_size, const size_t outer_idx,
+ *       const size_t unpacked_dim_idx)
+ * where `inner_size` is the number of contiguous elements trailing `dim`,
+ * `outer_idx` indexes the flattened leading dimensions, and
+ * `unpacked_dim_idx` is the current index along `dim`.
+ */
+template <typename Fn>
+void apply_over_unpacked_dim(
+    const Fn& fn,
+    const exec_aten::Tensor& in,
+    const int64_t& dim) {
+  if (in.numel() == 0) {
+    return;
+  }
+
+  ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
+  ET_CHECK_VALID_DIM(dim, in.dim());
+
+  const size_t d = ET_NORMALIZE_IX(dim, in.dim());
+  const size_t dim_size = in.size(d);
+  const size_t outer_size = getLeadingDims(in, d);
+  const size_t inner_size = getTrailingDims(in, d);
+  // Loop through all outer dimensions
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
+         ++unpacked_dim_idx) {
+      fn(inner_size, outer_idx, unpacked_dim_idx);
+    }
+  }
+}
+
+void dequantize_optimized(
+    const int8_t* in,
+    const double scale,
+    const int64_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  ET_CHECK_MSG(
+      zero_point >= quant_min,
+      "zero_point %" PRId64 " must be >= quant_min %" PRId64,
+      zero_point,
+      quant_min);
+  ET_CHECK_MSG(
+      zero_point <= quant_max,
+      "zero_point %" PRId64 " must be <= quant_max %" PRId64,
+      zero_point,
+      quant_max);
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  for (; i < numel; i++) {
+    out[i] = (in[i] - zero_point) * scale;
+  }
+}
+
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
+bool can_use_optimized_dequantize_per_channel(
+    const Tensor& in,
+    const ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  bool is_contiguous = false;
+#ifdef USE_ATEN_LIB
+  is_contiguous = in.is_contiguous();
+#else
+  is_contiguous = executorch::runtime::is_contiguous_dim_order(
+      in.dim_order().data(), in.dim());
+#endif
+  if (!is_contiguous || (in_dtype != ScalarType::Char) ||
+      (out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
+    return false;
+  }
+  return true;
+}
+
+void dequantize_per_channel_optimized(
+    const Tensor& in,
+    const Tensor& scales,
+    const optional<Tensor>& opt_zero_points,
+    Tensor& out,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType in_dtype,
+    exec_aten::optional<ScalarType>& out_dtype) {
+  check_dequantize_per_tensor_args(
+      in, quant_min, quant_max, in_dtype, out_dtype, out);
+  ET_CHECK_MSG(
+      in_dtype == ScalarType::Char,
+      "in.scalar_type() %" PRId8 " is not supported",
+      static_cast<int8_t>(in.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Only float output is supported");
+  }
+  const int8_t* in_data = in.const_data_ptr<int8_t>();
+  float* out_data = out.mutable_data_ptr<float>();
+  const int64_t* zero_points_data = nullptr;
+  if (opt_zero_points.has_value()) {
+    zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  }
+  const StridesType axis_stride = in.strides()[axis];
+  const StridesType outer_stride = in.size(axis) * axis_stride;
+  apply_over_unpacked_dim(
+      [in_data,
+       out_data,
+       &scales,
+       zero_points_data,
+       axis_stride,
+       outer_stride,
+       quant_min,
+       quant_max](
+          SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
+        const int8_t* in_data_local =
+            in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
+        const double scale = get_scale(scales, unpacked_dim_idx);
+        const int64_t zero_point = zero_points_data != nullptr
+            ? zero_points_data[unpacked_dim_idx]
+            : 0;
+        float* out_data_local = out_data + outer_idx * outer_stride +
+            unpacked_dim_idx * axis_stride;
+        dequantize_optimized(
+            in_data_local,
+            scale,
+            zero_point,
+            out_data_local,
+            quant_min,
+            quant_max,
+            numel);
+      },
+      in,
+      axis);
+}
+
 } // namespace
 
 /**
@@ -170,19 +352,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }
 
-float get_scale(const Tensor& scale, size_t channel_ix) {
-  ET_CHECK_MSG(
-      (scale.scalar_type() == ScalarType::Double) ||
-          (scale.scalar_type() == ScalarType::Float),
-      "scale.scalar_type() %" PRId8 " is not double or float type",
-      static_cast<int8_t>(scale.scalar_type()));
-  if (scale.scalar_type() == ScalarType::Double) {
-    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
-  } else {
-    return scale.const_data_ptr<float>()[channel_ix];
-  }
-}
-
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -227,6 +396,20 @@ Tensor& dequantize_per_channel_out(
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
 
+  if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
+    dequantize_per_channel_optimized(
+        input,
+        scale,
+        opt_zero_points,
+        out,
+        axis,
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype);
+    return out;
+  }
+
   // a list contains all dimensions except axis
   int64_t dims[kTensorDimensionLimit];
   for (int64_t i = 0; i < input.dim() - 1; i++) {
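
A note on the traversal above: `apply_over_unpacked_dim` visits, for every combination of the leading dimensions and every index along `axis`, one contiguous run of trailing elements; `dequantize_per_channel_optimized` hands that run to `dequantize_optimized`, which processes 16 int8 values per NEON iteration and finishes the remainder with the scalar loop. The standalone sketch below uses assumed names and a plain contiguous layout, not ExecuTorch APIs, and only illustrates the iteration order and chunking.

#include <cstddef>
#include <cstdio>

// Illustrative walk over a contiguous [outer, axis, inner] buffer: one
// contiguous run per (outer index, channel index), mirroring
// apply_over_unpacked_dim, then 16-wide chunks plus a scalar tail as in
// dequantize_optimized.
int main() {
  const size_t outer_size = 2, axis_size = 3, inner_size = 19;
  const size_t kVecSize = 16;
  for (size_t outer = 0; outer < outer_size; ++outer) {
    for (size_t ch = 0; ch < axis_size; ++ch) {
      size_t base = outer * axis_size * inner_size + ch * inner_size;
      size_t full_chunks = inner_size / kVecSize;        // vectorized iterations
      size_t tail = inner_size - full_chunks * kVecSize; // scalar remainder
      std::printf(
          "outer %zu, channel %zu: base %zu, %zu x16 chunk(s) + %zu tail\n",
          outer, ch, base, full_chunks, tail);
    }
  }
  return 0;
}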

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 42 additions & 8 deletions
@@ -122,13 +122,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
   EXPECT_TENSOR_EQ(out, expected);
 }
 
-TEST(OpDequantizeOutTest, DequantizePerChannel) {
-  et_pal_init();
-  TensorFactory<ScalarType::Byte> tf_byte;
+template <ScalarType DTYPE>
+void test_per_channel_dtype() {
+  TensorFactory<DTYPE> tf;
   TensorFactory<ScalarType::Double> tf_double;
   TensorFactory<ScalarType::Long> tf_long;
 
-  Tensor input = tf_byte.full({3, 2}, 100);
+  Tensor input = tf.full({3, 2}, 100);
   Tensor scale = tf_double.make({2}, {0.5, 1});
   Tensor zero_point = tf_long.make({2}, {30, 60});
   int64_t quant_min = 0;
@@ -146,7 +146,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/1,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);
 
@@ -167,15 +167,15 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
       optional<ScalarType>(),
       out);
 
   EXPECT_TENSOR_EQ(out, expected);
 
   // Test with a different axis
   out = tfo.zeros({3});
-  input = tf_byte.make({3}, {100, 100, 100});
+  input = tf.make({3}, {100, 100, 100});
   scale = tf_double.make({3}, {0.5, 0.75, 1});
   zero_point = tf_long.make({3}, {30, 50, 60});
   // (100 - 30) * 0.5
@@ -189,8 +189,42 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       /*axis=*/0,
       quant_min,
       quant_max,
-      ScalarType::Byte,
+      DTYPE,
+      optional<ScalarType>(),
+      out);
+  EXPECT_TENSOR_EQ(out, expected);
+
+  // Test with a larger inner dimension so the vectorized path is exercised
+  input = tf.full({3, 19}, 100);
+  out = tfo.zeros({3, 19});
+  scale = tf_double.make({3}, {0.5, 0.75, 1});
+  zero_point = tf_long.make({3}, {30, 50, 60});
+  // (100 - 30) * 0.5
+  // (100 - 50) * 0.75
+  // (100 - 60) * 1
+  expected = tfo.make(
+      {3, 19},
+      {35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,   35,
+       35,   35,   35,   35,   35,   35,   35,   37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5,
+       37.5, 37.5, 40,   40,   40,   40,   40,   40,   40,   40,   40,   40,
+       40,   40,   40,   40,   40,   40,   40,   40,   40});
+  dequantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      /*axis=*/0,
+      quant_min,
+      quant_max,
+      DTYPE,
       optional<ScalarType>(),
       out);
+
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpDequantizeOutTest, DequantizePerChannel) {
+  et_pal_init();
+  test_per_channel_dtype<ScalarType::Byte>();
+  test_per_channel_dtype<ScalarType::Char>();
+}
