
Commit eee643b

seanx92 authored and meta-codesync[bot] committed
support scale_bias_last and quant_padding_float_type for cpu dequant kernel (#4943)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1963
Pull Request resolved: #4943

scale_bias_last: decides whether the scale/bias padding sits at the front or the end of each row.
quant_padding_float_type: decides whether the scale/bias is represented as FP32 or FP16.

This matches the functionality of the CUDA kernel implementation and allows CPU dequantization with front-padded FP16 scale/bias.

Reviewed By: q10

Differential Revision: D83405212

fbshipit-source-id: 34628568cb26dc66de24a9f02e9fb1161f20ace9
1 parent f75b360 commit eee643b
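
For context, a minimal Python sketch of how the new flags are meant to be exercised through the Torch op surface. The op names and keyword arguments come from the schemas registered in quantize_ops_cpu.cpp below, and the fp32-to-fp16 padding conversion mirrors the test helper in this diff; the tensor shapes and the tolerance at the end are illustrative assumptions, not from the source.

import torch
import fbgemm_gpu  # noqa: F401  (assumed installed; loading it registers torch.ops.fbgemm)

rows, cols = 4, 128
input_data = torch.rand(rows, cols, dtype=torch.float32)

# Default layout: each row is [cols x uint8 data | fp32 scale | fp32 bias].
quantized = torch.ops.fbgemm.FloatOrHalfToFused8BitRowwiseQuantized(input_data)

# Dequantize with the defaults (scale/bias at the end of the row, fp32 padding).
dequantized = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
    quantized,
    0,  # output_dtype: 0 selects FP32 (SparseType.FP32)
    scale_bias_last=True,
    quant_padding_float_type=True,
)

# Front-padded fp16 layout: reinterpret the trailing 8 bytes (fp32 scale/bias)
# as 4 bytes of fp16 and move them to the start of each row, as the test does.
pad_fp16 = quantized[:, -8:].view(torch.float).to(torch.half).view(torch.uint8)
front_padded = torch.cat([pad_fp16, quantized[:, :-8]], dim=1)

dequantized_fp16_pad = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
    front_padded,
    0,
    scale_bias_last=False,           # scale/bias sit at the front of each row
    quant_padding_float_type=False,  # scale/bias stored as fp16
)

# The two results agree up to the precision lost by rounding scale/bias to fp16.
torch.testing.assert_close(dequantized, dequantized_fp16_pad, rtol=1e-2, atol=1e-2)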

File tree: 9 files changed, +271 -64 lines changed

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Lines changed: 11 additions & 2 deletions

@@ -469,10 +469,19 @@ at::Tensor _fusednbitrowwise_to_float_or_half_gpu(
     const int64_t output_dtype);
 at::Tensor& _fused8bitrowwise_to_float_cpu_out(
     at::Tensor& output,
-    const at::Tensor& input);
+    const at::Tensor& input,
+    const bool scale_bias_last = true,
+    const bool quant_padding_float_type = true);
+at::Tensor& fused8bitrowwise_to_half_cpu_out(
+    at::Tensor& output,
+    const at::Tensor& input,
+    const bool scale_bias_last = true,
+    const bool quant_padding_float_type = true);
 at::Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
     at::Tensor& output,
-    const at::Tensor& input);
+    const at::Tensor& input,
+    const bool scale_bias_last = true,
+    const bool quant_padding_float_type = true);
 at::Tensor& _float_to_fused8bitrowwise_cpu_out(
     at::Tensor& output,
     const at::Tensor& input);

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
Lines changed: 40 additions & 15 deletions

@@ -58,18 +58,26 @@ Tensor& _float_to_fused8bitrowwise_cpu_out_t(
 template <typename output_t, bool is_uint16_t_of_type_bf16 = false>
 Tensor& _fused8bitrowwise_to_float_cpu_out_t(
     Tensor& output,
-    const Tensor& input) {
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   TENSOR_ON_CPU(input);
   TORCH_CHECK(
       input.dim() >= 2,
       "Tensor 'input' must have >= 2 dimension(s). Found ",
       input.ndimension());
+  TORCH_CHECK(
+      quant_padding_float_type == true || scale_bias_last == false,
+      "2-byte padding (quant_padding_float_type=false) only works with scale_bias_last=false")
+
+  const int quant_padding_size =
+      (quant_padding_float_type) ? sizeof(float) : sizeof(fbgemm::float16);
 
   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
   const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes);
   const int32_t ncols = input_sizes[last_dim];
-  const int32_t output_columns = ncols - 2 * sizeof(float);
+  const int32_t output_columns = ncols - 2 * quant_padding_size;
 
   auto output_dims = input_sizes.vec();
   output_dims[last_dim] = output_columns;
@@ -81,7 +89,12 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
   fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
       output_t,
       is_uint16_t_of_type_bf16>(
-      input.data_ptr<uint8_t>(), nrows, ncols, output_data);
+      input.data_ptr<uint8_t>(),
+      nrows,
+      ncols,
+      output_data,
+      scale_bias_last,
+      quant_padding_float_type);
 
   return output;
 }
@@ -218,20 +231,29 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
 ///
 Tensor& _fused8bitrowwise_to_float_cpu_out(
     Tensor& output,
-    const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<float, false>(output, input);
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
+  return _fused8bitrowwise_to_float_cpu_out_t<float, false>(
+      output, input, scale_bias_last, quant_padding_float_type);
 }
 
-Tensor& fused8bitrowwise_to_half_cpu_out(Tensor& output, const Tensor& input) {
+Tensor& fused8bitrowwise_to_half_cpu_out(
+    Tensor& output,
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16, false>(
-      output, input);
+      output, input, scale_bias_last, quant_padding_float_type);
 }
 
 Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
     Tensor& output,
-    const Tensor& input) {
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16, true>(
-      output, input);
+      output, input, scale_bias_last, quant_padding_float_type);
 }
 
 /// @ingroup quantize-data-cpu
@@ -307,24 +329,27 @@ Tensor fused8bitrowwise_to_bfloat16_cpu(const Tensor& input) {
 Tensor fused8bitrowwise_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t output_dtype,
-    [[maybe_unused]] const bool scale_bias_last,
-    [[maybe_unused]] const bool quant_padding_float_type) {
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   Tensor output;
   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
   switch (output_sparse_dtype) {
     case SparseType::FP32:
       output = at::empty({0}, input.options().dtype(at::kFloat));
 
-      output = _fused8bitrowwise_to_float_cpu_out(output, input);
+      output = _fused8bitrowwise_to_float_cpu_out(
+          output, input, scale_bias_last, quant_padding_float_type);
 
       break;
     case SparseType::FP16:
       output = at::empty({0}, input.options().dtype(at::kHalf));
-      output = fused8bitrowwise_to_half_cpu_out(output, input);
+      output = fused8bitrowwise_to_half_cpu_out(
+          output, input, scale_bias_last, quant_padding_float_type);
       break;
     case SparseType::BF16:
       output = at::empty({0}, input.options().dtype(at::kBFloat16));
-      output = _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+      output = _fused8bitrowwise_to_bfloat16_cpu_out(
+          output, input, scale_bias_last, quant_padding_float_type);
       break;
     default:
       TORCH_CHECK(false);
@@ -607,7 +632,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
-      "Fused8BitRowwiseQuantizedToFloatOut(Tensor output, Tensor input) -> Tensor");
+      "Fused8BitRowwiseQuantizedToFloatOut(Tensor output, Tensor input, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
       "Fused8BitRowwiseQuantizedToFloatMixedDim(Tensor input, Tensor D_offsets, int output_dtype) -> Tensor");
   m.def(
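
As a concrete reading of the output_columns = ncols - 2 * quant_padding_size line in the first hunk above: each padded row carries one scale and one bias, so the padding occupies 2 * 4 bytes when they are fp32 and 2 * 2 bytes when they are fp16. A small sketch of that size relationship (the helper name below is hypothetical, not from the source):

def dequantized_width(ncols: int, quant_padding_float_type: bool) -> int:
    """Number of dequantized elements per row of a fused 8-bit rowwise tensor.

    Each row stores the uint8 payload plus a (scale, bias) pair, which takes
    2 * sizeof(float) bytes for fp32 padding or 2 * sizeof(float16) for fp16.
    """
    quant_padding_size = 4 if quant_padding_float_type else 2
    return ncols - 2 * quant_padding_size

assert dequantized_width(136, quant_padding_float_type=True) == 128   # fp32 scale/bias
assert dequantized_width(132, quant_padding_float_type=False) == 128  # fp16 scale/bias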

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py
Lines changed: 35 additions & 4 deletions

@@ -144,14 +144,37 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901
         # cpu path only supports bf16 dequantization
         if output_dtype == SparseType.BF16:
             input_data = input_data.float()
+        if not test_generic_op and not quant_padding_float_type:
+            return
+        if not quant_padding_float_type and output_dtype == SparseType.FP32:
+            return
         if test_generic_op:
-            quantized_data = (
+            quantized_data_ref = (
                 torch.ops.fbgemm.FloatOrHalfToFused8BitRowwiseQuantized(input_data)
             )
+            # fbgemm weight 2byte storages are scale_bias first layout
+            if quant_padding_float_type is False:
+                scale_bias_last = False
+                quant_pad = quantized_data_ref[:, -8:]
+                quant_data = quantized_data_ref[:, :-8]
+                quantized_data = torch.cat(
+                    [
+                        quant_pad.view(torch.float)
+                        .to(torch.half)
+                        .view(torch.uint8),
+                        quant_data,
+                    ],
+                    dim=1,
+                )
+            else:
+                scale_bias_last = True
+                quantized_data = quantized_data_ref
             dequantized_data = (
                 torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatOrHalf(
                     quantized_data,
                     output_dtype.as_int(),
+                    quant_padding_float_type=quant_padding_float_type,
+                    scale_bias_last=scale_bias_last,
                 )
             )
         else:
@@ -187,9 +210,17 @@ def quantize_and_dequantize_op_test_helper( # noqa: C901
             assert dequantized_data.numel() == 0
             return
 
-        reference = torch.from_numpy(
-            fused_rowwise_8bit_dequantize_reference(quantized_data.numpy())
-        )
+        quantize_data_numpy = quantized_data.numpy()
+        if quant_padding_float_type:
+            reference = torch.from_numpy(
+                fused_rowwise_8bit_dequantize_reference(quantize_data_numpy)
+            )
+        else:
+            reference = torch.from_numpy(
+                fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference(
+                    quantize_data_numpy
+                )
+            )
        if output_dtype == SparseType.FP32:
            torch.testing.assert_close(dequantized_data.float(), reference.float())
        elif output_dtype == SparseType.FP16:
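
The reference helper fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference used above is not shown in this diff view. A minimal NumPy sketch of what such a reference could look like, assuming each row is laid out as [fp16 scale | fp16 bias | uint8 payload]; the function below is an illustrative assumption, not the actual test utility:

import numpy as np

def dequantize_2byte_padding_scale_bias_first(fused: np.ndarray) -> np.ndarray:
    # Hypothetical reference: rows are [fp16 scale | fp16 bias | uint8 payload].
    scale_bias = fused[:, :4].copy().view(np.float16).astype(np.float32)  # shape (rows, 2)
    scale = scale_bias[:, 0:1]
    bias = scale_bias[:, 1:2]
    payload = fused[:, 4:].astype(np.float32)
    # Standard fused 8-bit rowwise dequantization: q * scale + bias.
    return payload * scale + bias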

include/fbgemm/QuantUtils.h
Lines changed: 6 additions & 2 deletions

@@ -336,7 +336,9 @@ FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
     const uint8_t* input,
     size_t input_rows,
     int input_columns,
-    OutputType* output);
+    OutputType* output,
+    const bool scale_bias_last = true,
+    const bool quant_padding_float_type = true);
 
 /**
  * Same as ToFusedNBitRowwiseQuantizedSBHalf but unoptimized.
@@ -383,6 +385,8 @@ FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
     const uint8_t* input,
     size_t input_rows,
     int input_columns,
-    OutputType* output);
+    OutputType* output,
+    const bool scale_bias_last = true,
+    const bool quant_padding_float_type = true);
 
 } // namespace fbgemm

include/fbgemm/QuantUtilsAvx2.h
Lines changed: 4 additions & 1 deletion

@@ -166,7 +166,10 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
     int input_columns,
     OutputType* output);
 
-template <typename OutputType>
+template <
+    typename OutputType,
+    bool scale_bias_last = true,
+    bool quant_padding_float_type = true>
 void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
     const std::uint8_t* input,
     size_t input_rows,

include/fbgemm/QuantUtilsAvx512.h
Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ FBGEMM_API void requantizeOutputProcessingGConvAvx512(
     int ld_in,
     const requantizationParams_t<BIAS_TYPE>& r);
 
+template <bool scale_bias_last = true, bool quant_padding_float_type = true>
 void Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512(
     const std::uint8_t* input,
     size_t input_rows,
