161 changes: 75 additions & 86 deletions src/FbgemmI8DepthwiseAvx2-inl.h
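The diff below makes two recurring simplifications to the group-granularity (`K_PER_G == 2`) requantization path. First, the `_mm256_moveldup_ps` that followed every `_mm256_permutevar8x32_ps` broadcast is folded into the permute itself by pre-duplicating the mask indices; the old interleave mask survives under the new name `permute_mask_v_clamp` for the final output shuffle. Second, the upper clamp `_mm256_min_epu8(..., max_v)` against an all-255 vector is removed as a no-op on unsigned bytes, which also lets `max_v` be deleted. Short equivalence sketches are interleaved after the relevant hunks below.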
@@ -59,8 +59,7 @@ static ALWAYS_INLINE void requantize_(
     B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
   }
 
-  __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0));
-  __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255));
+  __m256i min_v = _mm256_set1_epi8(0);
 
   if constexpr (A_SYMMETRIC) {
     assert(A_zero_point == 0 || col_offsets == nullptr);
@@ -70,6 +69,9 @@ static ALWAYS_INLINE void requantize_(
   __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
 
   __m256i permute_mask_v =
+      _mm256_set_epi32(0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x00, 0x00);
+
+  __m256i permute_mask_v_clamp =
       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
 
   constexpr int VLEN = 8;
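Why the mask fold works: `_mm256_moveldup_ps` keeps the even lanes and duplicates them (`t0,t0,t2,t2,t4,t4,t6,t6`), so applying it after a cross-lane gather of `x0,x4,x1,x5,x2,x6,x3,x7` yields `x0,x0,x1,x1,x2,x2,x3,x3` — the same result as one `_mm256_permutevar8x32_ps` whose index vector already repeats each group index twice. A minimal standalone sketch (not part of the diff; compile with `-mavx2`) that checks the equivalence:

```cpp
// Sketch: one permute with duplicated indices equals the old permute +
// moveldup pair. Hypothetical standalone file, not FBGEMM code.
#include <immintrin.h>
#include <cstdio>
#include <cstring>

int main() {
  float src[4] = {10.f, 20.f, 30.f, 40.f}; // four per-group values

  // Old sequence: cross-lane interleave, then duplicate even lanes.
  __m256i old_mask =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
  __m256 old_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
      _mm256_castps128_ps256(_mm_loadu_ps(src)), old_mask));

  // New sequence: single permute whose indices are already 0,0,1,1,2,2,3,3.
  __m256i new_mask =
      _mm256_set_epi32(0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x00, 0x00);
  __m256 new_v = _mm256_permutevar8x32_ps(
      _mm256_castps128_ps256(_mm_loadu_ps(src)), new_mask);

  float a[8], b[8];
  _mm256_storeu_ps(a, old_v);
  _mm256_storeu_ps(b, new_v);
  std::printf("equal: %d\n", std::memcmp(a, b, sizeof(a)) == 0); // prints 1
}
```

Both `vpermd` and `vmovsldup` execute on the shuffle port of typical Intel AVX2 cores, so the fold saves one shuffle uop per broadcast; the interleave pattern is still needed once per store, which is why the diff keeps it as `permute_mask_v_clamp`.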
@@ -92,11 +94,10 @@ static ALWAYS_INLINE void requantize_(
       } else {
         static_assert(K_PER_G == 2);
         // Load row_offsets for 4 groups and broadcast by 2 times.
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(row_offsets + j / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(row_offsets + j / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
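In scalar terms, the `K_PER_G == 2` branch above loads one 32-bit row offset per group for four groups and fans each out to the two adjacent output lanes of that group. A sketch of the intended lane mapping (illustrative standalone code, not from the repo):

```cpp
// Scalar model of the GROUP-granularity broadcast: four per-group row
// offsets are expanded so each of the 8 int32 lanes gets its group's offset.
#include <cstdint>

void broadcast_row_offsets_2x(const std::int32_t* row_offsets_4groups,
                              std::int32_t out_lanes[8]) {
  for (int lane = 0; lane < 8; ++lane) {
    // Lanes 0,1 -> group 0; lanes 2,3 -> group 1; ... — indices
    // 0,0,1,1,2,2,3,3, exactly the new permute_mask_v pattern.
    out_lanes[lane] = row_offsets_4groups[lane / 2];
  }
}
```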
@@ -105,11 +106,10 @@
             reinterpret_cast<const __m256i*>(B_zero_point + j));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
         static_assert(K_PER_G == 2);
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(B_zero_point + j / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(B_zero_point + j / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       x_v = _mm256_sub_epi32(x_v, row_offset_v);
@@ -128,25 +128,21 @@
         row_offset_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
       } else {
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        row_offsets + (j + VLEN) / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(row_offsets + (j + VLEN) / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
           (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
         B_zero_point_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(B_zero_point + j + VLEN));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        B_zero_point + (j + VLEN) / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(B_zero_point + (j + VLEN) / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       y_v = _mm256_sub_epi32(y_v, row_offset_v);
@@ -164,25 +160,23 @@
         row_offset_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
       } else {
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        row_offsets + (j + 2 * VLEN) / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    row_offsets + (j + 2 * VLEN) / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
           (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
         B_zero_point_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(B_zero_point + j + 2 * VLEN));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        B_zero_point + (j + 2 * VLEN) / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    B_zero_point + (j + 2 * VLEN) / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       z_v = _mm256_sub_epi32(z_v, row_offset_v);
@@ -200,25 +194,23 @@
         row_offset_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
       } else {
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        row_offsets + (j + 3 * VLEN) / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    row_offsets + (j + 3 * VLEN) / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
           (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
         B_zero_point_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(B_zero_point + j + 3 * VLEN));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        B_zero_point + (j + 3 * VLEN) / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    B_zero_point + (j + 3 * VLEN) / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       w_v = _mm256_sub_epi32(w_v, row_offset_v);
@@ -260,31 +252,31 @@
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
-             _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+             _mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + j / 2)),
-                 permute_mask_v)));
+                 permute_mask_v));
          y_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
-             _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+             _mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + VLEN) / 2)),
-                 permute_mask_v)));
+                 permute_mask_v));
          z_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
-             _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+             _mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + 2 * VLEN) / 2)),
-                 permute_mask_v)));
+                 permute_mask_v));
          w_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
-             _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+             _mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + 3 * VLEN) / 2)),
-                 permute_mask_v)));
+                 permute_mask_v));
        } else {
          x_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
Expand Down Expand Up @@ -341,41 +333,41 @@ static ALWAYS_INLINE void requantize_(
(Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
multiplier_v = _mm256_permutevar8x32_ps(
_mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
permute_mask_v));
permute_mask_v);
}
__m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
if constexpr (
Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
(Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
multiplier_v = _mm256_permutevar8x32_ps(
_mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + (j + VLEN) / 2)),
permute_mask_v));
permute_mask_v);
}
__m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
if constexpr (
Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
(Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
multiplier_v = _mm256_permutevar8x32_ps(
_mm256_castps128_ps256(
_mm_loadu_ps(C_multiplier + (j + 2 * VLEN) / 2)),
permute_mask_v));
permute_mask_v);
}
__m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
if constexpr (
Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
(Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
multiplier_v = _mm256_permutevar8x32_ps(
_mm256_castps128_ps256(
_mm_loadu_ps(C_multiplier + (j + 3 * VLEN) / 2)),
permute_mask_v));
permute_mask_v);
}
__m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);

@@ -389,12 +381,11 @@
       __m256i zw_packed_v = _mm256_adds_epi16(
           _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
       __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
-      __m256i xyzw_clamped_v = _mm256_max_epu8(
-          FUSE_RELU ? C_zero_point_epi8_v : min_v,
-          _mm256_min_epu8(xyzw_packed_v, max_v));
+      __m256i xyzw_clamped_v =
+          _mm256_max_epu8(FUSE_RELU ? C_zero_point_epi8_v : min_v, xyzw_packed_v);
 
       xyzw_clamped_v =
-          _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+          _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v_clamp);
 
       _mm256_storeu_si256(
           reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
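The dropped `_mm256_min_epu8(xyzw_packed_v, max_v)` was an identity: `max_v` is all bytes 255, and an unsigned byte can never exceed 255, so only the lower clamp (zero, or the output zero point under FUSE_RELU) has any effect. A small standalone check (hypothetical snippet, not from the repo; compile with `-mavx2`):

```cpp
// min_epu8 against an all-255 vector never changes an unsigned byte, so
// max_epu8(lo, min_epu8(x, 255)) == max_epu8(lo, x).
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  std::uint8_t buf[32];
  for (int i = 0; i < 32; ++i) buf[i] = static_cast<std::uint8_t>(i * 9 + 7);
  __m256i x = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buf));
  __m256i max_v = _mm256_set1_epi8(static_cast<char>(255));
  __m256i lo = _mm256_set1_epi8(16); // stand-in for min_v / C_zero_point_epi8_v

  __m256i with_min = _mm256_max_epu8(lo, _mm256_min_epu8(x, max_v));
  __m256i without = _mm256_max_epu8(lo, x);
  int same = _mm256_movemask_epi8(_mm256_cmpeq_epi8(with_min, without));
  std::printf("identical: %d\n", same == -1); // prints 1
}
```

Note that the cross-lane fixup after the pack still needs the old interleave pattern, so the store path switches to `permute_mask_v_clamp` rather than the new duplicated-index mask.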
@@ -412,11 +403,10 @@
     } else {
       static_assert(K_PER_G == 2);
       // Load row_offsets for 4 groups and broadcast by 2 times.
-      row_offset_v =
-          _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-              _mm256_castps128_ps256(_mm_loadu_ps(
-                  reinterpret_cast<const float*>(row_offsets + j / 2))),
-              permute_mask_v)));
+      row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+          _mm256_castps128_ps256(_mm_loadu_ps(
+              reinterpret_cast<const float*>(row_offsets + j / 2))),
+          permute_mask_v));
     }
     if constexpr (
         Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
@@ -425,11 +415,10 @@
           reinterpret_cast<const __m256i*>(B_zero_point + j));
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
       static_assert(K_PER_G == 2);
-      B_zero_point_v =
-          _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-              _mm256_castps128_ps256(_mm_loadu_ps(
-                  reinterpret_cast<const float*>(B_zero_point + j / 2))),
-              permute_mask_v)));
+      B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+          _mm256_castps128_ps256(_mm_loadu_ps(
+              reinterpret_cast<const float*>(B_zero_point + j / 2))),
+          permute_mask_v));
     }
     row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
     x_v = _mm256_sub_epi32(x_v, row_offset_v);
@@ -456,10 +445,10 @@
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
       x_bias_v = _mm256_div_ps(
           _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
-          _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+          _mm256_permutevar8x32_ps(
               _mm256_castps128_ps256(
                   _mm_loadu_ps(act_times_w_scale + j / 2)),
-              permute_mask_v)));
+              permute_mask_v));
     } else {
       x_bias_v = _mm256_mul_ps(
           _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
@@ -481,9 +470,9 @@
         (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
       multiplier_v = _mm256_loadu_ps(C_multiplier + j);
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+      multiplier_v = _mm256_permutevar8x32_ps(
           _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
-          permute_mask_v));
+          permute_mask_v);
     }
     __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
     __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -492,11 +481,11 @@
         _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
         C_zero_point_epi16_v);
     x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
-    __m256i x_clamped_v = _mm256_max_epu8(
-        FUSE_RELU ? C_zero_point_epi8_v : min_v,
-        _mm256_min_epu8(x_packed_v, max_v));
+    __m256i x_clamped_v =
+        _mm256_max_epu8(FUSE_RELU ? C_zero_point_epi8_v : min_v, x_packed_v);
 
-    x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+    x_clamped_v =
+        _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v_clamp);
 
     _mm_storel_epi64(
         reinterpret_cast<__m128i*>(C_uint8 + j),