diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h
index 4701f07c30..29b3583bea 100644
--- a/src/FbgemmI8DepthwiseAvx2-inl.h
+++ b/src/FbgemmI8DepthwiseAvx2-inl.h
@@ -59,8 +59,7 @@ static ALWAYS_INLINE void requantize_(
     B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
   }
 
-  __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
-  __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
+  __m256i min_v = _mm256_set1_epi8(0);
 
   if constexpr (A_SYMMETRIC) {
     assert(A_zero_point == 0 || col_offsets == nullptr);
@@ -70,6 +69,9 @@ static ALWAYS_INLINE void requantize_(
   __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
 
   __m256i permute_mask_v =
+      _mm256_set_epi32(0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x00, 0x00);
+
+  __m256i permute_mask_v_clamp =
       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
 
   constexpr int VLEN = 8;
@@ -92,11 +94,10 @@ static ALWAYS_INLINE void requantize_(
       } else {
         static_assert(K_PER_G == 2);
        // Load row_offsets for 4 groups and broadcast by 2 times.
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(row_offsets + j / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(row_offsets + j / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
@@ -105,11 +106,10 @@ static ALWAYS_INLINE void requantize_(
             reinterpret_cast<const __m256i*>(B_zero_point + j));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
         static_assert(K_PER_G == 2);
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(B_zero_point + j / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(B_zero_point + j / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       x_v = _mm256_sub_epi32(x_v, row_offset_v);
@@ -128,12 +128,10 @@ static ALWAYS_INLINE void requantize_(
         row_offset_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
       } else {
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        row_offsets + (j + VLEN) / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(row_offsets + (j + VLEN) / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
@@ -141,12 +139,10 @@ static ALWAYS_INLINE void requantize_(
         B_zero_point_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(B_zero_point + j + VLEN));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        B_zero_point + (j + VLEN) / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(B_zero_point + (j + VLEN) / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       y_v = _mm256_sub_epi32(y_v, row_offset_v);
@@ -164,12 +160,11 @@ static ALWAYS_INLINE void requantize_(
         row_offset_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
       } else {
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        row_offsets + (j + 2 * VLEN) / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    row_offsets + (j + 2 * VLEN) / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
@@ -177,12 +172,11 @@ static ALWAYS_INLINE void requantize_(
         B_zero_point_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(B_zero_point + j + 2 * VLEN));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        B_zero_point + (j + 2 * VLEN) / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    B_zero_point + (j + 2 * VLEN) / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       z_v = _mm256_sub_epi32(z_v, row_offset_v);
@@ -200,12 +194,11 @@ static ALWAYS_INLINE void requantize_(
         row_offset_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
       } else {
-        row_offset_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        row_offsets + (j + 3 * VLEN) / 2))),
-                permute_mask_v)));
+        row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    row_offsets + (j + 3 * VLEN) / 2))),
+            permute_mask_v));
       }
       if constexpr (
           Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
@@ -213,12 +206,11 @@ static ALWAYS_INLINE void requantize_(
         B_zero_point_v = _mm256_loadu_si256(
             reinterpret_cast<const __m256i*>(B_zero_point + j + 3 * VLEN));
       } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-        B_zero_point_v =
-            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-                _mm256_castps128_ps256(_mm_loadu_ps(
-                    reinterpret_cast<const float*>(
-                        B_zero_point + (j + 3 * VLEN) / 2))),
-                permute_mask_v)));
+        B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+            _mm256_castps128_ps256(_mm_loadu_ps(
+                reinterpret_cast<const float*>(
+                    B_zero_point + (j + 3 * VLEN) / 2))),
+            permute_mask_v));
       }
       row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
       w_v = _mm256_sub_epi32(w_v, row_offset_v);
@@ -260,31 +252,31 @@ static ALWAYS_INLINE void requantize_(
         x_bias_v = _mm256_div_ps(
             _mm256_loadu_ps(
                 reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
-            _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+            _mm256_permutevar8x32_ps(
                 _mm256_castps128_ps256(
                     _mm_loadu_ps(act_times_w_scale + j / 2)),
-                permute_mask_v)));
+                permute_mask_v));
         y_bias_v = _mm256_div_ps(
             _mm256_loadu_ps(
                 reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
-            _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+            _mm256_permutevar8x32_ps(
                 _mm256_castps128_ps256(
                     _mm_loadu_ps(act_times_w_scale + (j + VLEN) / 2)),
-                permute_mask_v)));
+                permute_mask_v));
         z_bias_v = _mm256_div_ps(
             _mm256_loadu_ps(
                 reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
-            _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+            _mm256_permutevar8x32_ps(
                 _mm256_castps128_ps256(
                     _mm_loadu_ps(act_times_w_scale + (j + 2 * VLEN) / 2)),
-                permute_mask_v)));
+                permute_mask_v));
         w_bias_v = _mm256_div_ps(
             _mm256_loadu_ps(
                 reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
-            _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+            _mm256_permutevar8x32_ps(
                 _mm256_castps128_ps256(
                     _mm_loadu_ps(act_times_w_scale + (j + 3 * VLEN) / 2)),
-                permute_mask_v)));
+                permute_mask_v));
       } else {
         x_bias_v = _mm256_mul_ps(
             _mm256_loadu_ps(
@@ -341,9 +333,9 @@ static ALWAYS_INLINE void requantize_(
         (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
       multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+      multiplier_v = _mm256_permutevar8x32_ps(
           _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
-          permute_mask_v));
+          permute_mask_v);
     }
     __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
     if constexpr (
@@ -351,9 +343,9 @@ static ALWAYS_INLINE void requantize_(
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
       multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+      multiplier_v = _mm256_permutevar8x32_ps(
           _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + (j + VLEN) / 2)),
-          permute_mask_v));
+          permute_mask_v);
     }
     __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
     if constexpr (
@@ -361,10 +353,10 @@ static ALWAYS_INLINE void requantize_(
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
       multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+      multiplier_v = _mm256_permutevar8x32_ps(
           _mm256_castps128_ps256(
               _mm_loadu_ps(C_multiplier + (j + 2 * VLEN) / 2)),
-          permute_mask_v));
+          permute_mask_v);
     }
     __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
     if constexpr (
@@ -372,10 +364,10 @@ static ALWAYS_INLINE void requantize_(
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
       multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+      multiplier_v = _mm256_permutevar8x32_ps(
           _mm256_castps128_ps256(
               _mm_loadu_ps(C_multiplier + (j + 3 * VLEN) / 2)),
-          permute_mask_v));
+          permute_mask_v);
     }
     __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
 
@@ -389,12 +381,11 @@ static ALWAYS_INLINE void requantize_(
     __m256i zw_packed_v = _mm256_adds_epi16(
         _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
     __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
-    __m256i xyzw_clamped_v = _mm256_max_epu8(
-        FUSE_RELU ? C_zero_point_epi8_v : min_v,
-        _mm256_min_epu8(xyzw_packed_v, max_v));
+    __m256i xyzw_clamped_v =
+        _mm256_max_epu8(FUSE_RELU ? C_zero_point_epi8_v : min_v, xyzw_packed_v);
 
     xyzw_clamped_v =
-        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v_clamp);
 
     _mm256_storeu_si256(
         reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
@@ -412,11 +403,10 @@ static ALWAYS_INLINE void requantize_(
     } else {
       static_assert(K_PER_G == 2);
      // Load row_offsets for 4 groups and broadcast by 2 times.
-      row_offset_v =
-          _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-              _mm256_castps128_ps256(_mm_loadu_ps(
-                  reinterpret_cast<const float*>(row_offsets + j / 2))),
-              permute_mask_v)));
+      row_offset_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+          _mm256_castps128_ps256(_mm_loadu_ps(
+              reinterpret_cast<const float*>(row_offsets + j / 2))),
+          permute_mask_v));
     }
     if constexpr (
         Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
@@ -425,11 +415,10 @@ static ALWAYS_INLINE void requantize_(
           reinterpret_cast<const __m256i*>(B_zero_point + j));
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
       static_assert(K_PER_G == 2);
-      B_zero_point_v =
-          _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
-              _mm256_castps128_ps256(_mm_loadu_ps(
-                  reinterpret_cast<const float*>(B_zero_point + j / 2))),
-              permute_mask_v)));
+      B_zero_point_v = _mm256_castps_si256(_mm256_permutevar8x32_ps(
+          _mm256_castps128_ps256(_mm_loadu_ps(
+              reinterpret_cast<const float*>(B_zero_point + j / 2))),
+          permute_mask_v));
    }
     row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
     x_v = _mm256_sub_epi32(x_v, row_offset_v);
@@ -456,10 +445,10 @@ static ALWAYS_INLINE void requantize_(
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
       x_bias_v = _mm256_div_ps(
           _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
-          _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+          _mm256_permutevar8x32_ps(
              _mm256_castps128_ps256(
                   _mm_loadu_ps(act_times_w_scale + j / 2)),
-              permute_mask_v)));
+              permute_mask_v));
     } else {
       x_bias_v = _mm256_mul_ps(
           _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
@@ -481,9 +470,9 @@ static ALWAYS_INLINE void requantize_(
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
       multiplier_v = _mm256_loadu_ps(C_multiplier + j);
     } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
-      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
+      multiplier_v = _mm256_permutevar8x32_ps(
           _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
-          permute_mask_v));
+          permute_mask_v);
     }
     __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
     __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -492,11 +481,11 @@ static ALWAYS_INLINE void requantize_(
         _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
         C_zero_point_epi16_v);
     x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
-    __m256i x_clamped_v = _mm256_max_epu8(
-        FUSE_RELU ? C_zero_point_epi8_v : min_v,
-        _mm256_min_epu8(x_packed_v, max_v));
+    __m256i x_clamped_v =
+        _mm256_max_epu8(FUSE_RELU ? C_zero_point_epi8_v : min_v, x_packed_v);
 
-    x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+    x_clamped_v =
+        _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v_clamp);
 
     _mm_storel_epi64(
         reinterpret_cast<__m128i*>(C_uint8 + j),
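The recurring change in these hunks folds the `_mm256_moveldup_ps` into the permute itself: the old code moved four per-group values to even lanes with the interleaving mask {0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00} and then duplicated each even lane, while the new `permute_mask_v` {0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x00, 0x00} produces the same pairwise broadcast in a single `_mm256_permutevar8x32_ps`. The interleaving mask survives only as `permute_mask_v_clamp`, used to reorder the packed bytes before the store. Below is a minimal standalone sketch of the equivalence; the `main` harness is illustrative only and not part of the patch (compile with -mavx2):

#include <immintrin.h>
#include <cstdio>

int main() {
  const float src[4] = {10.f, 11.f, 12.f, 13.f};
  // Mirror the patch: load 4 floats into the low half of a 256-bit vector.
  __m256 v = _mm256_castps128_ps256(_mm_loadu_ps(src));

  // Old sequence: interleave values to even lanes, then duplicate even lanes.
  __m256i old_mask =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
  __m256 old_r = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(v, old_mask));

  // New sequence: one permute whose mask already does the pairwise broadcast.
  __m256i new_mask =
      _mm256_set_epi32(0x03, 0x03, 0x02, 0x02, 0x01, 0x01, 0x00, 0x00);
  __m256 new_r = _mm256_permutevar8x32_ps(v, new_mask);

  float a[8], b[8];
  _mm256_storeu_ps(a, old_r);
  _mm256_storeu_ps(b, new_r);
  // Both columns print 10 10 11 11 12 12 13 13.
  for (int i = 0; i < 8; ++i) {
    printf("%.0f %.0f\n", a[i], b[i]);
  }
  return 0;
}

One subtlety the sketch glosses over: after `_mm256_castps128_ps256` the upper four lanes are unspecified, but neither sequence lets them reach the result. The old mask routes them only into odd lanes that `_mm256_moveldup_ps` immediately overwrites, and the new mask never selects them at all, which is why dropping the moveldup is safe here.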
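Dropping `max_v` rests on a separate observation: `_mm256_packus_epi16` already saturates each signed 16-bit input into [0, 255], so the old `_mm256_min_epu8(xyzw_packed_v, max_v)` could never change a byte; only the lower clamp (`min_v`, or `C_zero_point_epi8_v` when fusing ReLU) still does work. A quick standalone check of that assumption (again an illustrative harness, not patch code):

#include <immintrin.h>
#include <cstdio>

int main() {
  // 16-bit values that fall below and above the uint8 range.
  __m256i lo = _mm256_set1_epi16(-123);  // packus saturates these to 0
  __m256i hi = _mm256_set1_epi16(300);   // packus saturates these to 255
  __m256i packed = _mm256_packus_epi16(lo, hi);

  // The clamp the patch removes: min against an all-255 vector.
  __m256i max_v = _mm256_set1_epi8(static_cast<char>(255));
  __m256i clamped = _mm256_min_epu8(packed, max_v);

  // All 32 byte comparisons match, so the movemask is all ones (-1).
  int identical =
      _mm256_movemask_epi8(_mm256_cmpeq_epi8(packed, clamped)) == -1;
  printf("min against 255 is a no-op: %d\n", identical);
  return 0;
}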