@@ -87,7 +87,8 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
 template <typename input_t>
 Tensor _float_to_fusednbitrowwise_cpu(
     const Tensor& input,
-    const int64_t bit_rate) {
+    const int64_t bit_rate,
+    const input_t* rowwise_min_max = nullptr) {
   TENSOR_ON_CPU(input);
   TENSOR_NDIM_EQUALS(input, 2);
 
@@ -109,7 +110,12 @@ Tensor _float_to_fusednbitrowwise_cpu(
       input.data_ptr()); // input.data_ptr<input_t>(); -> Yields
                          // unresolved data_ptr symbol.
   fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<input_t>(
-      bit_rate, input_data, nrows, ncols, output.data_ptr<uint8_t>());
+      bit_rate,
+      input_data,
+      nrows,
+      ncols,
+      output.data_ptr<uint8_t>(),
+      rowwise_min_max);
 
   return output;
 }
@@ -427,6 +433,37 @@ Tensor float_or_half_to_fusednbitrowwise_cpu(
   return output;
 }
 
+static Tensor float_or_half_to_fusednbitrowwise_cpu_with_rowwise_min_max(
+    const Tensor& input,
+    const int64_t bit_rate,
+    const Tensor& rowwise_min_max) {
+  TORCH_CHECK(
+      (rowwise_min_max.dim() == 2 && rowwise_min_max.size(0) == input.size(0) &&
+       rowwise_min_max.size(1) == fbgemm::kRowwiseMinMaxNumCols),
+      "'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2].");
+
+  const auto rowwise_min_max_contig = rowwise_min_max.expect_contiguous(
+      rowwise_min_max.suggest_memory_format());
+  Tensor output;
+  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+      input.scalar_type(),
+      "float_or_half_to_fusednbitrowwise_cpu_with_rowwise_min_max",
+      [&] {
+        if constexpr (std::is_same_v<scalar_t, float>) {
+          const auto rowwise_min_max_data =
+              rowwise_min_max_contig->data_ptr<float>();
+          output = _float_to_fusednbitrowwise_cpu<float>(
+              input, bit_rate, rowwise_min_max_data);
+        } else { // scalar_t = at::Half
+          const auto rowwise_min_max_data =
+              static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr());
+          output = _float_to_fusednbitrowwise_cpu<fbgemm::float16>(
+              input, bit_rate, rowwise_min_max_data);
+        }
+      });
+  return output;
+}
+
 /// @ingroup quantize-data-cpu
 ///
 void FloatToFP8Quantized_ref(
@@ -557,6 +594,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "HalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
   m.def(
       "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
+  m.def(
+      "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfWithRowwiseMinMax(Tensor input, int bit_rate, Tensor rowwise_min_max) -> Tensor");
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
   m.def(
@@ -624,6 +663,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
       fbgemm_gpu::float_or_half_to_fusednbitrowwise_cpu);
+  DISPATCH_TO_CPU(
+      "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfWithRowwiseMinMax",
+      fbgemm_gpu::float_or_half_to_fusednbitrowwise_cpu_with_rowwise_min_max);
   DISPATCH_TO_CPU(
       "FusedNBitRowwiseQuantizedSBHalfToFloat",
       fbgemm_gpu::fusednbitrowwise_to_float_cpu);
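
Not part of the diff above: a minimal C++ sketch of how the newly registered "WithRowwiseMinMax" op might be exercised through the PyTorch dispatcher, assuming the fbgemm_gpu op library has been loaded. The helper name quantize_with_precomputed_min_max and the {min, max} column order are illustrative assumptions, not taken from this change; only the schema name and the [num_rows, 2] shape check come from the diff.

// Hedged usage sketch; the min/max column order is an assumption.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor quantize_with_precomputed_min_max(
    const at::Tensor& weights, // [num_rows, num_cols], float or half, on CPU
    int64_t bit_rate) {
  // Pack per-row min and max into a [num_rows, 2] tensor, matching the
  // TORCH_CHECK in the new function (column order assumed to be {min, max}).
  auto row_min = std::get<0>(weights.min(/*dim=*/1, /*keepdim=*/true));
  auto row_max = std::get<0>(weights.max(/*dim=*/1, /*keepdim=*/true));
  auto rowwise_min_max = at::cat({row_min, row_max}, /*dim=*/1);

  // Call the op registered by this diff via its schema name.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow(
              "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfWithRowwiseMinMax",
              "")
          .typed<at::Tensor(const at::Tensor&, int64_t, const at::Tensor&)>();
  return op.call(weights, bit_rate, rowwise_min_max);
}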