Skip to content

Commit 0d30f0d

Browse files
sampathvicfacebook-github-bot
authored andcommitted
Quantization with min & max bounds support - 4-bit & 2-bit on X86-64 (#4833)
Summary: Pull Request resolved: #4833 X-link: facebookresearch/FBGEMM#1860 In D78181177 we have added support for row-wise min/max for 8-bit quantization. In this diff similar thing is done for 4-bit & 2-bit quantization as well. Reviewed By: spcyppt, excelle08 Differential Revision: D81858256 fbshipit-source-id: b6fbb5914b3a832143ebacf34dfaf2573d74fd85
1 parent c760f8f commit 0d30f0d

File tree

5 files changed

+206
-44
lines changed

5 files changed

+206
-44
lines changed

include/fbgemm/QuantUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,8 @@ FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
285285
const InputType* input,
286286
size_t input_rows,
287287
int input_columns,
288-
std::uint8_t* output);
288+
std::uint8_t* output,
289+
const InputType* rowwise_min_max = nullptr);
289290

290291
/**
291292
* Convert fused rowwise quantized inputs to float (fp32 or fp16).

include/fbgemm/QuantUtilsAvx2.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
namespace fbgemm {
1919

20+
/// Number of columns in the rowwise min/max buffer passed to the quantization
21+
/// function(s)
22+
constexpr int kRowwiseMinMaxNumCols = 2;
23+
2024
/// Struct from <a href="https://github.com/google/gemmlowp">`gemmlowp`</a>
2125
///
2226
/// A structure to hold quantization parameters `scale` and `zero_point`.
@@ -144,7 +148,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
144148
const InputType* input,
145149
size_t input_rows,
146150
int input_columns,
147-
std::uint8_t* output);
151+
std::uint8_t* output,
152+
const InputType* rowwise_min_max = nullptr);
148153

149154
template <typename InputType>
150155
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(

src/QuantUtils.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
626626
const InputType* input,
627627
size_t input_rows,
628628
int input_columns,
629-
std::uint8_t* output) {
629+
std::uint8_t* output,
630+
const InputType* rowwise_min_max) {
630631
// Currenlty we can only dequantize if the number of input columns
631632
// is a multiple of number of elements_per_byte
632633

@@ -640,15 +641,15 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
640641
switch (bit_rate) {
641642
case 2:
642643
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 2>(
643-
input, input_rows, input_columns, output);
644+
input, input_rows, input_columns, output, rowwise_min_max);
644645
break;
645646
case 4:
646647
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 4>(
647-
input, input_rows, input_columns, output);
648+
input, input_rows, input_columns, output, rowwise_min_max);
648649
break;
649650
case 8:
650651
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 8>(
651-
input, input_rows, input_columns, output);
652+
input, input_rows, input_columns, output, rowwise_min_max);
652653
break;
653654
default:
654655
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
@@ -866,7 +867,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
866867
const type* input, \
867868
size_t input_rows, \
868869
int input_columns, \
869-
std::uint8_t* output); \
870+
std::uint8_t* output, \
871+
const type* rowwise_min_max); \
870872
template FBGEMM_API void \
871873
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type, false>( \
872874
int bit_rate, \

src/QuantUtilsAvx2.cc

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,7 +1563,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
15631563
const InputType* input,
15641564
size_t input_rows,
15651565
int input_columns,
1566-
std::uint8_t* output) {
1566+
std::uint8_t* output,
1567+
const InputType* rowwise_min_max) {
15671568
static_assert(
15681569
std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
15691570
"Only float and float16 types are allowed.");
@@ -1574,6 +1575,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
15741575
2 * sizeof(std::uint16_t);
15751576

15761577
float* input_row_float_for_fp16 = nullptr;
1578+
float min_max_row_float_for_fp16[kRowwiseMinMaxNumCols];
1579+
const auto is_valid_rowwise_min_max = (rowwise_min_max != nullptr);
15771580
if constexpr (std::is_same<InputType, float16>()) {
15781581
input_row_float_for_fp16 = static_cast<float*>(
15791582
fbgemmAlignedAlloc(64, input_columns * sizeof(float)));
@@ -1591,48 +1594,72 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
15911594
input_row_float = input_row_float_for_fp16;
15921595
}
15931596

1597+
const float* min_max_row_float = nullptr;
1598+
if (is_valid_rowwise_min_max) {
1599+
const InputType* min_max_row =
1600+
rowwise_min_max + row * kRowwiseMinMaxNumCols;
1601+
1602+
if constexpr (std::is_same_v<InputType, float>) {
1603+
min_max_row_float = reinterpret_cast<const float*>(min_max_row);
1604+
} else {
1605+
min_max_row_float_for_fp16[0] = halfToFloat(min_max_row[0]);
1606+
min_max_row_float_for_fp16[1] = halfToFloat(min_max_row[1]);
1607+
min_max_row_float = min_max_row_float_for_fp16;
1608+
}
1609+
}
1610+
15941611
std::uint8_t* output_row = output + row * output_columns;
15951612
std::uint16_t* output_row_scale_bias = reinterpret_cast<std::uint16_t*>(
15961613
output_row +
15971614
(input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
15981615

15991616
float minimum_element = FLT_MAX;
16001617
float maximum_element = -FLT_MAX;
1601-
__m256 min_v = _mm256_set1_ps(minimum_element);
1602-
__m256 max_v = _mm256_set1_ps(maximum_element);
1618+
if (is_valid_rowwise_min_max) {
1619+
minimum_element = min_max_row_float[0];
1620+
maximum_element = min_max_row_float[1];
16031621

1604-
int col = 0;
1605-
for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
1606-
__m256 in_v;
1607-
if constexpr (std::is_same<InputType, float>()) {
1608-
in_v = _mm256_loadu_ps(input_row_float + col);
1609-
} else {
1610-
__m128i in_half_v =
1611-
_mm_loadu_si128(reinterpret_cast<const __m128i*>(input_row + col));
1612-
in_v = _mm256_cvtph_ps(in_half_v);
1613-
_mm256_store_ps(input_row_float_for_fp16 + col, in_v);
1622+
for (int col = 0; col < input_columns; ++col) {
1623+
if constexpr (std::is_same<InputType, float16>()) {
1624+
input_row_float_for_fp16[col] = halfToFloat(input_row[col]);
1625+
}
16141626
}
1627+
} else {
1628+
__m256 min_v = _mm256_set1_ps(minimum_element);
1629+
__m256 max_v = _mm256_set1_ps(maximum_element);
1630+
int col = 0;
1631+
for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
1632+
__m256 in_v;
1633+
if constexpr (std::is_same<InputType, float>()) {
1634+
in_v = _mm256_loadu_ps(input_row_float + col);
1635+
} else {
1636+
__m128i in_half_v = _mm_loadu_si128(
1637+
reinterpret_cast<const __m128i*>(input_row + col));
1638+
in_v = _mm256_cvtph_ps(in_half_v);
1639+
_mm256_store_ps(input_row_float_for_fp16 + col, in_v);
1640+
}
16151641

1616-
min_v = _mm256_min_ps(min_v, in_v);
1617-
max_v = _mm256_max_ps(max_v, in_v);
1618-
}
1619-
alignas(64) float min_buf[VLEN], max_buf[VLEN];
1620-
_mm256_store_ps(min_buf, min_v);
1621-
_mm256_store_ps(max_buf, max_v);
1622-
for (int i = 0; i < VLEN; ++i) {
1623-
minimum_element = std::min(minimum_element, min_buf[i]);
1624-
maximum_element = std::max(maximum_element, max_buf[i]);
1625-
}
1642+
min_v = _mm256_min_ps(min_v, in_v);
1643+
max_v = _mm256_max_ps(max_v, in_v);
1644+
}
1645+
alignas(64) float min_buf[VLEN], max_buf[VLEN];
1646+
_mm256_store_ps(min_buf, min_v);
1647+
_mm256_store_ps(max_buf, max_v);
1648+
for (int i = 0; i < VLEN; ++i) {
1649+
minimum_element = std::min(minimum_element, min_buf[i]);
1650+
maximum_element = std::max(maximum_element, max_buf[i]);
1651+
}
16261652

1627-
for (; col < input_columns; ++col) {
1628-
if constexpr (std::is_same<InputType, float>()) {
1629-
minimum_element = std::min(minimum_element, input_row_float[col]);
1630-
maximum_element = std::max(maximum_element, input_row_float[col]);
1631-
} else {
1632-
float element = halfToFloat(input_row[col]);
1633-
input_row_float_for_fp16[col] = element;
1634-
minimum_element = std::min(minimum_element, element);
1635-
maximum_element = std::max(maximum_element, element);
1653+
for (; col < input_columns; ++col) {
1654+
if constexpr (std::is_same<InputType, float>()) {
1655+
minimum_element = std::min(minimum_element, input_row_float[col]);
1656+
maximum_element = std::max(maximum_element, input_row_float[col]);
1657+
} else {
1658+
float element = halfToFloat(input_row[col]);
1659+
input_row_float_for_fp16[col] = element;
1660+
minimum_element = std::min(minimum_element, element);
1661+
maximum_element = std::max(maximum_element, element);
1662+
}
16361663
}
16371664
}
16381665

@@ -1657,12 +1684,12 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
16571684

16581685
output_row_scale_bias[0] = floatToHalf(scale);
16591686

1660-
col = 0;
1687+
int col = 0;
16611688
if constexpr (BIT_RATE == 2 || BIT_RATE == 4) {
16621689
__m256i permute_mask1_v =
16631690
_mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
16641691
__m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
1665-
min_v = _mm256_set1_ps(minimum_element);
1692+
__m256 min_v = _mm256_set1_ps(minimum_element);
16661693

16671694
for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) {
16681695
__m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
@@ -1778,7 +1805,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
17781805

17791806
const int64_t output_columns = input_columns + 2 * sizeof(float);
17801807
float* input_row_float_for_fp16 = nullptr;
1781-
float min_max_row_float_for_fp16[2];
1808+
float min_max_row_float_for_fp16[kRowwiseMinMaxNumCols];
17821809
const auto is_valid_rowwise_min_max = (rowwise_min_max != nullptr);
17831810
if constexpr (std::is_same_v<InputType, float16>) {
17841811
input_row_float_for_fp16 = static_cast<float*>(
@@ -1798,7 +1825,8 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
17981825

17991826
const float* min_max_row_float = nullptr;
18001827
if (is_valid_rowwise_min_max) {
1801-
const InputType* min_max_row = rowwise_min_max + row * 2;
1828+
const InputType* min_max_row =
1829+
rowwise_min_max + row * kRowwiseMinMaxNumCols;
18021830

18031831
if constexpr (std::is_same_v<InputType, float>) {
18041832
min_max_row_float = reinterpret_cast<const float*>(min_max_row);
@@ -2205,7 +2233,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
22052233
const type* input, \
22062234
size_t input_rows, \
22072235
int input_columns, \
2208-
std::uint8_t* output); \
2236+
std::uint8_t* output, \
2237+
const type* rowwise_min_max); \
22092238
template void \
22102239
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<type, bit_rate>( \
22112240
const std::uint8_t* input, \

test/QuantUtilsTest.cc

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,14 @@ class EmbeddingQuantizeFixedNumberTest : public testing::TestWithParam<int> {
594594
1, 1,
595595
-64, 191,
596596
};
597+
598+
float16_test_input_rowwise_min_max.resize(
599+
float_test_input_rowwise_min_max.size());
600+
std::transform(
601+
float_test_input_rowwise_min_max.begin(),
602+
float_test_input_rowwise_min_max.end(),
603+
float16_test_input_rowwise_min_max.begin(),
604+
[](float input) { return cpu_float2half_rn(input); });
597605
}
598606
// clang-format on
599607

@@ -607,6 +615,7 @@ class EmbeddingQuantizeFixedNumberTest : public testing::TestWithParam<int> {
607615
expected_output_half;
608616
std::vector<uint8_t> expected_output_float;
609617
std::vector<float> float_test_input_rowwise_min_max;
618+
std::vector<float16> float16_test_input_rowwise_min_max;
610619
};
611620

612621
INSTANTIATE_TEST_SUITE_P(
@@ -861,3 +870,119 @@ TEST_P(
861870
col));
862871
#endif
863872
}
873+
874+
TEST_P(
875+
EmbeddingQuantizeFixedNumberTest,
876+
embeddingFloatOrHalfToFusedNBitRowwiseQuantizedSBHalfTest) {
877+
const int bit_rate = GetParam();
878+
879+
// Confirm that quantization with rowwise_min_max produces expected results.
880+
vector<uint8_t> outVectFloatTest(row * out_cols_half);
881+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
882+
bit_rate,
883+
float_test_input.data(),
884+
row,
885+
col,
886+
outVectFloatTest.data(),
887+
float_test_input_rowwise_min_max.data());
888+
EXPECT_TRUE(isQEmbeddingClose<float16>(
889+
expected_output_half[bit_rate], outVectFloatTest, row, col));
890+
891+
vector<uint8_t> outVectHalfTest(row * out_cols_half);
892+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float16>(
893+
bit_rate,
894+
float16_test_input.data(),
895+
row,
896+
col,
897+
outVectHalfTest.data(),
898+
float16_test_input_rowwise_min_max.data());
899+
EXPECT_TRUE(isQEmbeddingClose<float16>(
900+
expected_output_half[bit_rate], outVectHalfTest, row, col));
901+
902+
// Confirm that quantization with and without rowwise_min_max produces
903+
// similar results.
904+
vector<uint8_t> outVecFloatTestNoRowwiseMinMax(row * out_cols_half);
905+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
906+
bit_rate,
907+
float_test_input.data(),
908+
row,
909+
col,
910+
outVecFloatTestNoRowwiseMinMax.data());
911+
EXPECT_TRUE(isQEmbeddingClose<float16>(
912+
expected_output_half[bit_rate],
913+
outVecFloatTestNoRowwiseMinMax,
914+
row,
915+
col));
916+
EXPECT_TRUE(isQEmbeddingClose<float16>(
917+
outVectFloatTest, outVecFloatTestNoRowwiseMinMax, row, col));
918+
919+
vector<uint8_t> outVecHalfTestNoRowwiseMinMax(row * out_cols_half);
920+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float16>(
921+
bit_rate,
922+
float16_test_input.data(),
923+
row,
924+
col,
925+
outVecHalfTestNoRowwiseMinMax.data());
926+
EXPECT_TRUE(isQEmbeddingClose<float16>(
927+
expected_output_half[bit_rate], outVecHalfTestNoRowwiseMinMax, row, col));
928+
EXPECT_TRUE(isQEmbeddingClose<float16>(
929+
outVectHalfTest, outVecHalfTestNoRowwiseMinMax, row, col));
930+
931+
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
932+
// Confirm that incorrect min and max values for each row in the input
933+
// of rowwise_min_max produces different results.
934+
// Since Windows & ARM are not yet supported, only run this test on x86_64.
935+
std::vector<float> float_test_input_incorrect_rowwise_min_max = {
936+
10,
937+
15,
938+
-14,
939+
1,
940+
};
941+
std::vector<float16> float16_test_input_incorrect_rowwise_min_max;
942+
float16_test_input_incorrect_rowwise_min_max.resize(
943+
float_test_input_incorrect_rowwise_min_max.size());
944+
std::transform(
945+
float_test_input_incorrect_rowwise_min_max.begin(),
946+
float_test_input_incorrect_rowwise_min_max.end(),
947+
float16_test_input_incorrect_rowwise_min_max.begin(),
948+
[](float input) { return cpu_float2half_rn(input); });
949+
950+
vector<uint8_t> outVecFloatTestIncorrectRowwiseMinMax(row * out_cols_half);
951+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
952+
bit_rate,
953+
float_test_input.data(),
954+
row,
955+
col,
956+
outVecFloatTestIncorrectRowwiseMinMax.data(),
957+
float_test_input_incorrect_rowwise_min_max.data());
958+
EXPECT_FALSE(isQEmbeddingClose<float16>(
959+
expected_output_half[bit_rate],
960+
outVecFloatTestIncorrectRowwiseMinMax,
961+
row,
962+
col));
963+
EXPECT_FALSE(isQEmbeddingClose<float16>(
964+
outVecFloatTestIncorrectRowwiseMinMax,
965+
outVecFloatTestNoRowwiseMinMax,
966+
row,
967+
col));
968+
969+
vector<uint8_t> outVecHalfTestIncorrectRowwiseMinMax(row * out_cols_half);
970+
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float16>(
971+
bit_rate,
972+
float16_test_input.data(),
973+
row,
974+
col,
975+
outVecHalfTestIncorrectRowwiseMinMax.data(),
976+
float16_test_input_incorrect_rowwise_min_max.data());
977+
EXPECT_FALSE(isQEmbeddingClose<float16>(
978+
expected_output_half[bit_rate],
979+
outVecHalfTestIncorrectRowwiseMinMax,
980+
row,
981+
col));
982+
EXPECT_FALSE(isQEmbeddingClose<float16>(
983+
outVecHalfTestIncorrectRowwiseMinMax,
984+
outVecHalfTestNoRowwiseMinMax,
985+
row,
986+
col));
987+
#endif
988+
}

0 commit comments

Comments
 (0)