
Commit 513902e

Update on "Add update_quantized_cache op"
Why?
- Functionalization introduces a ton of copies.
- Without such custom in-place ops, mutable buffer support results in giant copies at the end.
- Making in-place ops work will likely take longer, and there is no clear safe path.

Differential Revision: [D62301838](https://our.internmc.facebook.com/intern/diff/D62301838/)

[ghstack-poisoned]
2 parents 2c67a95 + b051652 commit 513902e

File tree

1 file changed: +16 −9 lines

kernels/quantized/cpu/op_dequantize.cpp (16 additions, 9 deletions)
@@ -168,6 +168,19 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }
 
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -189,11 +202,6 @@ Tensor& dequantize_per_channel_out(
     axis += nonzero_dim(input);
   }
 
-  ET_CHECK_MSG(
-      scale.scalar_type() == ScalarType::Double,
-      "scale.scalar_type() %" PRId8 " is not double type",
-      static_cast<int8_t>(scale.scalar_type()));
-
   ET_CHECK_MSG(
       scale.numel() == input.size(axis),
       "scale.numel() %zd != input.size(axis) %zd",
@@ -226,7 +234,6 @@ Tensor& dequantize_per_channel_out(
       dims[i] = i + 1;
     }
   }
-  const double* scale_data = scale.const_data_ptr<double>();
   const int64_t* zero_point_data;
   if (opt_zero_points.has_value()) {
     zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
@@ -254,11 +261,11 @@ Tensor& dequantize_per_channel_out(
       axis == 0, "Axis must be 0 for a single dimensional tensors"); \
   const optional<int64_t> dim; \
   apply_over_dim( \
-      [input_data_ptr, out_data_ptr, scale_data, zero_point_data]( \
+      [input_data_ptr, out_data_ptr, zero_point_data, &scale]( \
          size_t numel, size_t stride, size_t base_ix) { \
         for (size_t i = 0; i < numel; i++) { \
           size_t current_ix = base_ix * stride + i; \
-          float _scale = static_cast<float>(scale_data[current_ix]); \
+          float _scale = get_scale(scale, current_ix); \
           int64_t zero_point = 0; \
           if (zero_point_data != nullptr) { \
             zero_point = zero_point_data[current_ix]; \
@@ -274,7 +281,7 @@ Tensor& dequantize_per_channel_out(
       break; \
     } \
   for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
-    float _scale = static_cast<float>(scale_data[channel_ix]); \
+    float _scale = get_scale(scale, channel_ix); \
    int64_t _zero_point = 0; \
    if (zero_point_data != nullptr) { \
      _zero_point = zero_point_data[channel_ix]; \
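
For context, the net effect of this patch is that dequantize_per_channel_out now accepts its per-channel scale tensor as either double or float, reading each value through get_scale instead of a raw const double* pointer; the per-element math is unchanged: out = (input - zero_point) * scale. Below is a minimal standalone C++ sketch of that same per-channel loop using plain std::vector buffers in place of ExecuTorch tensors; read_scale and all buffer names here are illustrative stand-ins, not the actual kernel API.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative stand-in for the kernel's get_scale: the scale buffer may hold
// float or double values; either way a float is returned for the math below.
template <typename ScaleT>
float read_scale(const std::vector<ScaleT>& scales, size_t channel_ix) {
  return static_cast<float>(scales[channel_ix]);
}

int main() {
  // Quantized input: 2 channels x 3 elements, with per-channel scale and zero point.
  std::vector<int8_t> input = {10, 12, 14, -2, 0, 2};
  std::vector<float> scales = {0.5f, 0.25f};  // could equally be std::vector<double>
  std::vector<int64_t> zero_points = {8, -4};
  std::vector<float> out(input.size());

  const size_t channels = scales.size();
  const size_t per_channel = input.size() / channels;
  for (size_t c = 0; c < channels; ++c) {
    float scale = read_scale(scales, c);
    int64_t zp = zero_points[c];
    for (size_t i = 0; i < per_channel; ++i) {
      size_t ix = c * per_channel + i;
      // Same formula the kernel applies: (q - zero_point) * scale.
      out[ix] = static_cast<float>(input[ix] - zp) * scale;
    }
  }

  for (float v : out) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}

Running the sketch prints 1 2 3 0.5 1 1.5: each channel is dequantized with its own scale and zero point, which is the behavior the patched macro preserves while relaxing the scale dtype check.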
