@@ -1563,7 +1563,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
1563
1563
const InputType* input,
1564
1564
size_t input_rows,
1565
1565
int input_columns,
1566
- std::uint8_t * output) {
1566
+ std::uint8_t * output,
1567
+ const InputType* rowwise_min_max) {
1567
1568
static_assert (
1568
1569
std::is_same<InputType, float >() || std::is_same<InputType, float16>(),
1569
1570
" Only float and float16 types are allowed." );
@@ -1574,6 +1575,8 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
1574
1575
2 * sizeof (std::uint16_t );
1575
1576
1576
1577
float * input_row_float_for_fp16 = nullptr ;
1578
+ float min_max_row_float_for_fp16[kRowwiseMinMaxNumCols ];
1579
+ const auto is_valid_rowwise_min_max = (rowwise_min_max != nullptr );
1577
1580
if constexpr (std::is_same<InputType, float16>()) {
1578
1581
input_row_float_for_fp16 = static_cast <float *>(
1579
1582
fbgemmAlignedAlloc (64 , input_columns * sizeof (float )));
@@ -1591,48 +1594,72 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
1591
1594
input_row_float = input_row_float_for_fp16;
1592
1595
}
1593
1596
1597
+ const float * min_max_row_float = nullptr ;
1598
+ if (is_valid_rowwise_min_max) {
1599
+ const InputType* min_max_row =
1600
+ rowwise_min_max + row * kRowwiseMinMaxNumCols ;
1601
+
1602
+ if constexpr (std::is_same_v<InputType, float >) {
1603
+ min_max_row_float = reinterpret_cast <const float *>(min_max_row);
1604
+ } else {
1605
+ min_max_row_float_for_fp16[0 ] = halfToFloat (min_max_row[0 ]);
1606
+ min_max_row_float_for_fp16[1 ] = halfToFloat (min_max_row[1 ]);
1607
+ min_max_row_float = min_max_row_float_for_fp16;
1608
+ }
1609
+ }
1610
+
1594
1611
std::uint8_t * output_row = output + row * output_columns;
1595
1612
std::uint16_t * output_row_scale_bias = reinterpret_cast <std::uint16_t *>(
1596
1613
output_row +
1597
1614
(input_columns + NUM_ELEM_PER_BYTE - 1 ) / NUM_ELEM_PER_BYTE);
1598
1615
1599
1616
float minimum_element = FLT_MAX;
1600
1617
float maximum_element = -FLT_MAX;
1601
- __m256 min_v = _mm256_set1_ps (minimum_element);
1602
- __m256 max_v = _mm256_set1_ps (maximum_element);
1618
+ if (is_valid_rowwise_min_max) {
1619
+ minimum_element = min_max_row_float[0 ];
1620
+ maximum_element = min_max_row_float[1 ];
1603
1621
1604
- int col = 0 ;
1605
- for (col = 0 ; col < input_columns / VLEN * VLEN; col += VLEN) {
1606
- __m256 in_v;
1607
- if constexpr (std::is_same<InputType, float >()) {
1608
- in_v = _mm256_loadu_ps (input_row_float + col);
1609
- } else {
1610
- __m128i in_half_v =
1611
- _mm_loadu_si128 (reinterpret_cast <const __m128i*>(input_row + col));
1612
- in_v = _mm256_cvtph_ps (in_half_v);
1613
- _mm256_store_ps (input_row_float_for_fp16 + col, in_v);
1622
+ for (int col = 0 ; col < input_columns; ++col) {
1623
+ if constexpr (std::is_same<InputType, float16>()) {
1624
+ input_row_float_for_fp16[col] = halfToFloat (input_row[col]);
1625
+ }
1614
1626
}
1627
+ } else {
1628
+ __m256 min_v = _mm256_set1_ps (minimum_element);
1629
+ __m256 max_v = _mm256_set1_ps (maximum_element);
1630
+ int col = 0 ;
1631
+ for (col = 0 ; col < input_columns / VLEN * VLEN; col += VLEN) {
1632
+ __m256 in_v;
1633
+ if constexpr (std::is_same<InputType, float >()) {
1634
+ in_v = _mm256_loadu_ps (input_row_float + col);
1635
+ } else {
1636
+ __m128i in_half_v = _mm_loadu_si128 (
1637
+ reinterpret_cast <const __m128i*>(input_row + col));
1638
+ in_v = _mm256_cvtph_ps (in_half_v);
1639
+ _mm256_store_ps (input_row_float_for_fp16 + col, in_v);
1640
+ }
1615
1641
1616
- min_v = _mm256_min_ps (min_v, in_v);
1617
- max_v = _mm256_max_ps (max_v, in_v);
1618
- }
1619
- alignas (64 ) float min_buf[VLEN], max_buf[VLEN];
1620
- _mm256_store_ps (min_buf, min_v);
1621
- _mm256_store_ps (max_buf, max_v);
1622
- for (int i = 0 ; i < VLEN; ++i) {
1623
- minimum_element = std::min (minimum_element, min_buf[i]);
1624
- maximum_element = std::max (maximum_element, max_buf[i]);
1625
- }
1642
+ min_v = _mm256_min_ps (min_v, in_v);
1643
+ max_v = _mm256_max_ps (max_v, in_v);
1644
+ }
1645
+ alignas (64 ) float min_buf[VLEN], max_buf[VLEN];
1646
+ _mm256_store_ps (min_buf, min_v);
1647
+ _mm256_store_ps (max_buf, max_v);
1648
+ for (int i = 0 ; i < VLEN; ++i) {
1649
+ minimum_element = std::min (minimum_element, min_buf[i]);
1650
+ maximum_element = std::max (maximum_element, max_buf[i]);
1651
+ }
1626
1652
1627
- for (; col < input_columns; ++col) {
1628
- if constexpr (std::is_same<InputType, float >()) {
1629
- minimum_element = std::min (minimum_element, input_row_float[col]);
1630
- maximum_element = std::max (maximum_element, input_row_float[col]);
1631
- } else {
1632
- float element = halfToFloat (input_row[col]);
1633
- input_row_float_for_fp16[col] = element;
1634
- minimum_element = std::min (minimum_element, element);
1635
- maximum_element = std::max (maximum_element, element);
1653
+ for (; col < input_columns; ++col) {
1654
+ if constexpr (std::is_same<InputType, float >()) {
1655
+ minimum_element = std::min (minimum_element, input_row_float[col]);
1656
+ maximum_element = std::max (maximum_element, input_row_float[col]);
1657
+ } else {
1658
+ float element = halfToFloat (input_row[col]);
1659
+ input_row_float_for_fp16[col] = element;
1660
+ minimum_element = std::min (minimum_element, element);
1661
+ maximum_element = std::max (maximum_element, element);
1662
+ }
1636
1663
}
1637
1664
}
1638
1665
@@ -1657,12 +1684,12 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
1657
1684
1658
1685
output_row_scale_bias[0 ] = floatToHalf (scale);
1659
1686
1660
- col = 0 ;
1687
+ int col = 0 ;
1661
1688
if constexpr (BIT_RATE == 2 || BIT_RATE == 4 ) {
1662
1689
__m256i permute_mask1_v =
1663
1690
_mm256_set_epi32 (0x07 , 0x03 , 0x06 , 0x02 , 0x05 , 0x01 , 0x04 , 0x00 );
1664
1691
__m256 inverse_scale_v = _mm256_set1_ps (inverse_scale);
1665
- min_v = _mm256_set1_ps (minimum_element);
1692
+ __m256 min_v = _mm256_set1_ps (minimum_element);
1666
1693
1667
1694
for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) {
1668
1695
__m256i x_rounded_v = _mm256_cvtps_epi32 (_mm256_mul_ps (
@@ -1778,7 +1805,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
1778
1805
1779
1806
const int64_t output_columns = input_columns + 2 * sizeof (float );
1780
1807
float * input_row_float_for_fp16 = nullptr ;
1781
- float min_max_row_float_for_fp16[2 ];
1808
+ float min_max_row_float_for_fp16[kRowwiseMinMaxNumCols ];
1782
1809
const auto is_valid_rowwise_min_max = (rowwise_min_max != nullptr );
1783
1810
if constexpr (std::is_same_v<InputType, float16>) {
1784
1811
input_row_float_for_fp16 = static_cast <float *>(
@@ -1798,7 +1825,8 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
1798
1825
1799
1826
const float * min_max_row_float = nullptr ;
1800
1827
if (is_valid_rowwise_min_max) {
1801
- const InputType* min_max_row = rowwise_min_max + row * 2 ;
1828
+ const InputType* min_max_row =
1829
+ rowwise_min_max + row * kRowwiseMinMaxNumCols ;
1802
1830
1803
1831
if constexpr (std::is_same_v<InputType, float >) {
1804
1832
min_max_row_float = reinterpret_cast <const float *>(min_max_row);
@@ -2205,7 +2233,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
2205
2233
const type* input, \
2206
2234
size_t input_rows, \
2207
2235
int input_columns, \
2208
- std::uint8_t * output); \
2236
+ std::uint8_t * output, \
2237
+ const type* rowwise_min_max); \
2209
2238
template void \
2210
2239
FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<type, bit_rate>( \
2211
2240
const std::uint8_t * input, \
0 commit comments