@@ -58,18 +58,26 @@ Tensor& _float_to_fused8bitrowwise_cpu_out_t(
 template <typename output_t, bool is_uint16_t_of_type_bf16 = false>
 Tensor& _fused8bitrowwise_to_float_cpu_out_t(
     Tensor& output,
-    const Tensor& input) {
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   TENSOR_ON_CPU(input);
   TORCH_CHECK(
       input.dim() >= 2,
       "Tensor 'input' must have >= 2 dimension(s). Found ",
       input.ndimension());
+  TORCH_CHECK(
+      quant_padding_float_type == true || scale_bias_last == false,
+      "2-byte padding (quant_padding_float_type=false) only works with scale_bias_last=false")
+
+  const int quant_padding_size =
+      (quant_padding_float_type) ? sizeof(float) : sizeof(fbgemm::float16);
 
   const auto input_sizes = input.sizes();
   const auto last_dim = input_sizes.size() - 1;
   const int64_t nrows = c10::size_to_dim_(last_dim, input_sizes);
   const int32_t ncols = input_sizes[last_dim];
-  const int32_t output_columns = ncols - 2 * sizeof(float);
+  const int32_t output_columns = ncols - 2 * quant_padding_size;
 
   auto output_dims = input_sizes.vec();
   output_dims[last_dim] = output_columns;
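The hunk above changes the per-row padding from a fixed 2 * sizeof(float) to 2 * quant_padding_size: the scale and bias stored with each quantized row may now be 4-byte floats or 2-byte float16 values, and the new TORCH_CHECK only accepts the 2-byte variant when the scale/bias pair sits at the front of the row (scale_bias_last=false). Below is a minimal sketch of the row layout this implies; the RowLayout/row_layout names are illustrative and not part of this PR.

// Illustrative sketch only (not part of this PR): byte offsets within one
// fused 8-bit row under the layout flags above. Names are hypothetical.
#include <cstddef>

struct RowLayout {
  std::size_t data_offset;   // first quantized data byte in the row
  std::size_t scale_offset;  // start of the scale slot
  std::size_t bias_offset;   // start of the bias slot
};

// ncols is the full row width in bytes: quantized data plus the two
// padding slots holding scale and bias.
inline RowLayout row_layout(
    std::size_t ncols,
    bool scale_bias_last,
    bool quant_padding_float_type) {
  // 4-byte (float) or 2-byte (float16) slot per value, mirroring quant_padding_size.
  const std::size_t pad = quant_padding_float_type ? 4 : 2;
  const std::size_t data_bytes = ncols - 2 * pad;  // == output_columns above
  if (scale_bias_last) {
    // [ data ... ][ scale ][ bias ]
    return {0, data_bytes, data_bytes + pad};
  }
  // [ scale ][ bias ][ data ... ]
  return {2 * pad, 0, pad};
}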
@@ -81,7 +89,12 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
   fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<
       output_t,
       is_uint16_t_of_type_bf16>(
-      input.data_ptr<uint8_t>(), nrows, ncols, output_data);
+      input.data_ptr<uint8_t>(),
+      nrows,
+      ncols,
+      output_data,
+      scale_bias_last,
+      quant_padding_float_type);
 
   return output;
 }
@@ -218,20 +231,29 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
 ///
 Tensor& _fused8bitrowwise_to_float_cpu_out(
     Tensor& output,
-    const Tensor& input) {
-  return _fused8bitrowwise_to_float_cpu_out_t<float, false>(output, input);
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
+  return _fused8bitrowwise_to_float_cpu_out_t<float, false>(
+      output, input, scale_bias_last, quant_padding_float_type);
 }
 
-Tensor& fused8bitrowwise_to_half_cpu_out(Tensor& output, const Tensor& input) {
+Tensor& fused8bitrowwise_to_half_cpu_out(
+    Tensor& output,
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::float16, false>(
-      output, input);
+      output, input, scale_bias_last, quant_padding_float_type);
 }
 
 Tensor& _fused8bitrowwise_to_bfloat16_cpu_out(
     Tensor& output,
-    const Tensor& input) {
+    const Tensor& input,
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   return _fused8bitrowwise_to_float_cpu_out_t<fbgemm::bfloat16, true>(
-      output, input);
+      output, input, scale_bias_last, quant_padding_float_type);
 }
 
 /// @ingroup quantize-data-cpu
@@ -307,24 +329,27 @@ Tensor fused8bitrowwise_to_bfloat16_cpu(const Tensor& input) {
 Tensor fused8bitrowwise_to_float_or_half_cpu(
     const Tensor& input,
     const int64_t output_dtype,
-    [[maybe_unused]] const bool scale_bias_last,
-    [[maybe_unused]] const bool quant_padding_float_type) {
+    const bool scale_bias_last,
+    const bool quant_padding_float_type) {
   Tensor output;
   SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
   switch (output_sparse_dtype) {
     case SparseType::FP32:
       output = at::empty({0}, input.options().dtype(at::kFloat));
 
-      output = _fused8bitrowwise_to_float_cpu_out(output, input);
+      output = _fused8bitrowwise_to_float_cpu_out(
+          output, input, scale_bias_last, quant_padding_float_type);
 
       break;
     case SparseType::FP16:
       output = at::empty({0}, input.options().dtype(at::kHalf));
-      output = fused8bitrowwise_to_half_cpu_out(output, input);
+      output = fused8bitrowwise_to_half_cpu_out(
+          output, input, scale_bias_last, quant_padding_float_type);
       break;
     case SparseType::BF16:
       output = at::empty({0}, input.options().dtype(at::kBFloat16));
-      output = _fused8bitrowwise_to_bfloat16_cpu_out(output, input);
+      output = _fused8bitrowwise_to_bfloat16_cpu_out(
+          output, input, scale_bias_last, quant_padding_float_type);
       break;
     default:
       TORCH_CHECK(false);
@@ -607,7 +632,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "Fused8BitRowwiseQuantizedToFloatOrHalf(Tensor input, int output_dtype=0, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
-      "Fused8BitRowwiseQuantizedToFloatOut(Tensor output, Tensor input) -> Tensor");
+      "Fused8BitRowwiseQuantizedToFloatOut(Tensor output, Tensor input, bool scale_bias_last=True, bool quant_padding_float_type=True) -> Tensor");
   m.def(
       "Fused8BitRowwiseQuantizedToFloatMixedDim(Tensor input, Tensor D_offsets, int output_dtype) -> Tensor");
   m.def(
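Taken together, the wrapper and schema updates above thread the two new arguments through with defaults that preserve the old behavior (float32 scale/bias stored after the data), so existing callers are unaffected. The following is a hedged usage sketch, not part of this PR: the fbgemm_gpu namespace and the assumption that the out-variant allocates the output itself are mine; the function signature is the one introduced in this diff.

// Hedged usage sketch (not from the PR). Assumes the declaration of
// _fused8bitrowwise_to_float_cpu_out is visible via the fbgemm_gpu headers
// and that it lives in the fbgemm_gpu namespace.
#include <ATen/ATen.h>

// Default flags: float32 scale/bias appended to each row, as before this change.
at::Tensor dequantize_rows(const at::Tensor& quantized) {
  at::Tensor output = at::empty({0}, quantized.options().dtype(at::kFloat));
  return fbgemm_gpu::_fused8bitrowwise_to_float_cpu_out(
      output,
      quantized,
      /*scale_bias_last=*/true,
      /*quant_padding_float_type=*/true);
}

// 2-byte (float16) scale/bias padding is only accepted together with
// scale_bias_last=false, per the TORCH_CHECK added in this PR.
at::Tensor dequantize_rows_sbfront_half_padding(const at::Tensor& quantized) {
  at::Tensor output = at::empty({0}, quantized.options().dtype(at::kFloat));
  return fbgemm_gpu::_fused8bitrowwise_to_float_cpu_out(
      output,
      quantized,
      /*scale_bias_last=*/false,
      /*quant_padding_float_type=*/false);
}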