Commit e699b6f

sampathvic authored and facebook-github-bot committed
Quantization with min & max bounds support - Adding CPU ops for n-bit quantization (pytorch#4860)
Summary:
Pull Request resolved: pytorch#4860
X-link: facebookresearch/FBGEMM#1883

This diff creates the quantize CPU ops for n-bit quantization with the row-wise min/max support added in D81858256.

Reviewed By: excelle08

Differential Revision: D82150550

fbshipit-source-id: 6a25632f4132e2884b314cc3d102b18d7c848695
1 parent: 3b0269b

File tree

1 file changed: +44 -2 lines


fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp

Lines changed: 44 additions & 2 deletions
@@ -87,7 +87,8 @@ Tensor& _fused8bitrowwise_to_float_cpu_out_t(
 template <typename input_t>
 Tensor _float_to_fusednbitrowwise_cpu(
     const Tensor& input,
-    const int64_t bit_rate) {
+    const int64_t bit_rate,
+    const input_t* rowwise_min_max = nullptr) {
   TENSOR_ON_CPU(input);
   TENSOR_NDIM_EQUALS(input, 2);
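Context for the new parameter: `rowwise_min_max` lets the caller pin each row's quantization range instead of having the kernel derive it from the row's data. As a rough, hypothetical reference (this is not the FBGEMM kernel, whose internals are outside this diff, and the {min, max} column order is an assumption based on the name), row-wise n-bit quantization with optional external bounds could look like:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative reference only: quantize one float row to bit_rate bits
// using either caller-supplied bounds or the row's own min/max. Unlike the
// real fused kernel, this sketch writes one quantized value per byte and
// keeps scale/bias as float, purely for clarity.
void quantize_row_nbit_ref(
    const float* row,
    int ncols,
    int bit_rate,
    const float* rowwise_min_max, // optional {min, max}; nullptr -> derive
    uint8_t* out,
    float& scale,
    float& bias) {
  float min_v, max_v;
  if (rowwise_min_max != nullptr) {
    min_v = rowwise_min_max[0];
    max_v = rowwise_min_max[1];
  } else {
    min_v = *std::min_element(row, row + ncols);
    max_v = *std::max_element(row, row + ncols);
  }
  const int qmax = (1 << bit_rate) - 1;
  scale = (max_v - min_v) / static_cast<float>(qmax);
  if (scale == 0.0f) scale = 1.0f; // degenerate row: all values equal
  bias = min_v;
  for (int j = 0; j < ncols; ++j) {
    // Clamp so out-of-bounds inputs (possible with external bounds) saturate.
    const float q = std::round((row[j] - min_v) / scale);
    out[j] = static_cast<uint8_t>(
        std::min(std::max(q, 0.0f), static_cast<float>(qmax)));
  }
}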

@@ -109,7 +110,12 @@ Tensor _float_to_fusednbitrowwise_cpu(
       input.data_ptr()); // input.data_ptr<input_t>(); -> Yields
                          // unresolved data_ptr symbol.
   fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<input_t>(
-      bit_rate, input_data, nrows, ncols, output.data_ptr<uint8_t>());
+      bit_rate,
+      input_data,
+      nrows,
+      ncols,
+      output.data_ptr<uint8_t>(),
+      rowwise_min_max);
 
   return output;
 }
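For orientation, the (unchanged) allocation code above this hunk sizes each output row for the fused format, which, going by the `SBHalf` naming, appends an fp16 scale and bias after the packed n-bit payload. A sketch of the assumed per-row width, for bit_rate in {2, 4, 8}:

#include <cstdint>

// Assumed layout sketch for one output row of the fused n-bit row-wise
// format: [packed n-bit values][fp16 scale][fp16 bias].
int64_t fused_nbit_row_bytes(int64_t ncols, int64_t bit_rate) {
  const int64_t num_elem_per_byte = 8 / bit_rate; // e.g. 4 values/byte at 2-bit
  const int64_t packed_bytes =
      (ncols + num_elem_per_byte - 1) / num_elem_per_byte; // round up
  return packed_bytes + 2 * static_cast<int64_t>(sizeof(uint16_t)); // + scale, bias
}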
@@ -427,6 +433,37 @@ Tensor float_or_half_to_fusednbitrowwise_cpu(
   return output;
 }
 
+static Tensor float_or_half_to_fusednbitrowwise_cpu_with_rowwise_min_max(
+    const Tensor& input,
+    const int64_t bit_rate,
+    const Tensor& rowwise_min_max) {
+  TORCH_CHECK(
+      (rowwise_min_max.dim() == 2 && rowwise_min_max.size(0) == input.size(0) &&
+       rowwise_min_max.size(1) == fbgemm::kRowwiseMinMaxNumCols),
+      "'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2].");
+
+  const auto rowwise_min_max_contig = rowwise_min_max.expect_contiguous(
+      rowwise_min_max.suggest_memory_format());
+  Tensor output;
+  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+      input.scalar_type(),
+      "float_or_half_to_fusednbitrowwise_cpu_with_rowwise_min_max",
+      [&] {
+        if constexpr (std::is_same_v<scalar_t, float>) {
+          const auto rowwise_min_max_data =
+              rowwise_min_max_contig->data_ptr<float>();
+          output = _float_to_fusednbitrowwise_cpu<float>(
+              input, bit_rate, rowwise_min_max_data);
+        } else { // scalar_t = at::Half
+          const auto rowwise_min_max_data =
+              static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr());
+          output = _float_to_fusednbitrowwise_cpu<fbgemm::float16>(
+              input, bit_rate, rowwise_min_max_data);
+        }
+      });
+  return output;
+}
+
 /// @ingroup quantize-data-cpu
 ///
 void FloatToFP8Quantized_ref(
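A hypothetical caller-side helper (the name is illustrative, not part of this diff) that builds a `rowwise_min_max` tensor with the [num_rows(weight), 2] shape the TORCH_CHECK above enforces, assuming column 0 holds each row's minimum and column 1 its maximum:

#include <ATen/ATen.h>
#include <tuple>

// Build per-row {min, max} bounds in the same dtype as the input weight.
// The min-then-max column order is an assumption consistent with the name.
at::Tensor make_rowwise_min_max(const at::Tensor& weight) {
  // ATen's dim-wise min/max return (values, indices); keep only the values.
  auto row_min = std::get<0>(weight.min(/*dim=*/1, /*keepdim=*/true));
  auto row_max = std::get<0>(weight.max(/*dim=*/1, /*keepdim=*/true));
  return at::cat({row_min, row_max}, /*dim=*/1); // shape: [nrows, 2]
}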
@@ -557,6 +594,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "HalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
   m.def(
       "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
+  m.def(
+      "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfWithRowwiseMinMax(Tensor input, int bit_rate, Tensor rowwise_min_max) -> Tensor");
   m.def(
       "FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
   m.def(
@@ -624,6 +663,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU(
       "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
       fbgemm_gpu::float_or_half_to_fusednbitrowwise_cpu);
+  DISPATCH_TO_CPU(
+      "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfWithRowwiseMinMax",
+      fbgemm_gpu::float_or_half_to_fusednbitrowwise_cpu_with_rowwise_min_max);
   DISPATCH_TO_CPU(
       "FusedNBitRowwiseQuantizedSBHalfToFloat",
       fbgemm_gpu::fusednbitrowwise_to_float_cpu);
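With the schema and the CPU dispatch registered, the new op is reachable through the PyTorch dispatcher from C++. A minimal sketch, assuming the fbgemm_gpu library has been loaded so the registrations above have run:

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor quantize_with_bounds(
    const at::Tensor& weight,
    int64_t bit_rate,
    const at::Tensor& rowwise_min_max) {
  // Look up the schema registered above (no overload name) and call it
  // with a statically typed signature matching the m.def declaration.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow(
              "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfWithRowwiseMinMax",
              "")
          .typed<at::Tensor(const at::Tensor&, int64_t, const at::Tensor&)>();
  return op.call(weight, bit_rate, rowwise_min_max);
}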
