Skip to content

Commit 6ef35c3

Browse files
authored
Revert "[libc] Implement branchless head-tail comparison for bcmp (#107540)"
This reverts commit 66a0329.
1 parent 2992d3d commit 6ef35c3

File tree

2 files changed

+41
-77
lines changed

2 files changed

+41
-77
lines changed

libc/src/string/memory_utils/op_x86.h

Lines changed: 20 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,6 @@ struct Memcpy {
7373
namespace LIBC_NAMESPACE_DECL {
7474
namespace generic {
7575

76-
// Not equals: returns non-zero iff values at head or tail differ.
77-
// This function typically loads more data than necessary when the two buffer
78-
// differs.
79-
template <typename T>
80-
LIBC_INLINE uint32_t branchless_head_tail_neq(CPtr p1, CPtr p2, size_t count) {
81-
static_assert(cpp::is_integral_v<T>);
82-
return neq<T>(p1, p2, 0) | neq<T>(p1, p2, count - sizeof(T));
83-
}
84-
8576
///////////////////////////////////////////////////////////////////////////////
8677
// Specializations for uint16_t
8778
template <> struct cmp_is_expensive<uint16_t> : public cpp::false_type {};
@@ -154,11 +145,6 @@ LIBC_INLINE MemcmpReturnType cmp_neq<uint64_t>(CPtr p1, CPtr p2,
154145
#if defined(__SSE4_1__)
155146
template <> struct is_vector<__m128i> : cpp::true_type {};
156147
template <> struct cmp_is_expensive<__m128i> : cpp::true_type {};
157-
LIBC_INLINE __m128i load_and_xor_m128i(CPtr p1, CPtr p2, size_t offset) {
158-
const auto a = load<__m128i>(p1, offset);
159-
const auto b = load<__m128i>(p2, offset);
160-
return _mm_xor_si128(a, b);
161-
}
162148
LIBC_INLINE __m128i bytewise_max(__m128i a, __m128i b) {
163149
return _mm_max_epu8(a, b);
164150
}
@@ -170,21 +156,17 @@ LIBC_INLINE uint16_t big_endian_cmp_mask(__m128i max, __m128i value) {
170156
return static_cast<uint16_t>(
171157
_mm_movemask_epi8(bytewise_reverse(_mm_cmpeq_epi8(max, value))));
172158
}
173-
LIBC_INLINE bool is_zero(__m128i value) {
174-
return _mm_testz_si128(value, value) == 1;
175-
}
176159
template <> LIBC_INLINE bool eq<__m128i>(CPtr p1, CPtr p2, size_t offset) {
177-
return is_zero(load_and_xor_m128i(p1, p2, offset));
160+
const auto a = load<__m128i>(p1, offset);
161+
const auto b = load<__m128i>(p2, offset);
162+
const auto xored = _mm_xor_si128(a, b);
163+
return _mm_testz_si128(xored, xored) == 1; // 1 iff xored == 0
178164
}
179165
template <> LIBC_INLINE uint32_t neq<__m128i>(CPtr p1, CPtr p2, size_t offset) {
180-
return !is_zero(load_and_xor_m128i(p1, p2, offset));
181-
}
182-
template <>
183-
LIBC_INLINE uint32_t branchless_head_tail_neq<__m128i>(CPtr p1, CPtr p2,
184-
size_t count) {
185-
const __m128i head = load_and_xor_m128i(p1, p2, 0);
186-
const __m128i tail = load_and_xor_m128i(p1, p2, count - sizeof(__m128i));
187-
return !is_zero(_mm_or_si128(head, tail));
166+
const auto a = load<__m128i>(p1, offset);
167+
const auto b = load<__m128i>(p2, offset);
168+
const auto xored = _mm_xor_si128(a, b);
169+
return _mm_testz_si128(xored, xored) == 0; // 0 iff xored != 0
188170
}
189171
template <>
190172
LIBC_INLINE MemcmpReturnType cmp_neq<__m128i>(CPtr p1, CPtr p2, size_t offset) {
@@ -203,34 +185,19 @@ LIBC_INLINE MemcmpReturnType cmp_neq<__m128i>(CPtr p1, CPtr p2, size_t offset) {
203185
#if defined(__AVX__)
204186
template <> struct is_vector<__m256i> : cpp::true_type {};
205187
template <> struct cmp_is_expensive<__m256i> : cpp::true_type {};
206-
LIBC_INLINE __m256i xor_m256i(__m256i a, __m256i b) {
207-
return _mm256_castps_si256(
208-
_mm256_xor_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b)));
209-
}
210-
LIBC_INLINE __m256i or_m256i(__m256i a, __m256i b) {
211-
return _mm256_castps_si256(
212-
_mm256_or_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b)));
213-
}
214-
LIBC_INLINE __m256i load_and_xor_m256i(CPtr p1, CPtr p2, size_t offset) {
188+
template <> LIBC_INLINE bool eq<__m256i>(CPtr p1, CPtr p2, size_t offset) {
215189
const auto a = load<__m256i>(p1, offset);
216190
const auto b = load<__m256i>(p2, offset);
217-
return xor_m256i(a, b);
218-
}
219-
LIBC_INLINE bool is_zero(__m256i value) {
220-
return _mm256_testz_si256(value, value) == 1;
221-
}
222-
template <> LIBC_INLINE bool eq<__m256i>(CPtr p1, CPtr p2, size_t offset) {
223-
return is_zero(load_and_xor_m256i(p1, p2, offset));
191+
const auto xored = _mm256_castps_si256(
192+
_mm256_xor_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b)));
193+
return _mm256_testz_si256(xored, xored) == 1; // 1 iff xored == 0
224194
}
225195
template <> LIBC_INLINE uint32_t neq<__m256i>(CPtr p1, CPtr p2, size_t offset) {
226-
return !is_zero(load_and_xor_m256i(p1, p2, offset));
227-
}
228-
template <>
229-
LIBC_INLINE uint32_t branchless_head_tail_neq<__m256i>(CPtr p1, CPtr p2,
230-
size_t count) {
231-
const __m256i head = load_and_xor_m256i(p1, p2, 0);
232-
const __m256i tail = load_and_xor_m256i(p1, p2, count - sizeof(__m256i));
233-
return !is_zero(or_m256i(head, tail));
196+
const auto a = load<__m256i>(p1, offset);
197+
const auto b = load<__m256i>(p2, offset);
198+
const auto xored = _mm256_castps_si256(
199+
_mm256_xor_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b)));
200+
return _mm256_testz_si256(xored, xored) == 0; // 0 iff xored != 0
234201
}
235202
#endif // __AVX__
236203

@@ -345,22 +312,9 @@ template <> LIBC_INLINE bool eq<__m512i>(CPtr p1, CPtr p2, size_t offset) {
345312
template <> LIBC_INLINE uint32_t neq<__m512i>(CPtr p1, CPtr p2, size_t offset) {
346313
const auto a = load<__m512i>(p1, offset);
347314
const auto b = load<__m512i>(p2, offset);
348-
return _mm512_cmpneq_epi8_mask(a, b) != 0;
349-
}
350-
LIBC_INLINE __m512i load_and_xor_m512i(CPtr p1, CPtr p2, size_t offset) {
351-
const auto a = load<__m512i>(p1, offset);
352-
const auto b = load<__m512i>(p2, offset);
353-
return _mm512_xor_epi64(a, b);
354-
}
355-
LIBC_INLINE bool is_zero(__m512i value) {
356-
return _mm512_test_epi32_mask(value, value) == 0;
357-
}
358-
template <>
359-
LIBC_INLINE uint32_t branchless_head_tail_neq<__m512i>(CPtr p1, CPtr p2,
360-
size_t count) {
361-
const __m512i head = load_and_xor_m512i(p1, p2, 0);
362-
const __m512i tail = load_and_xor_m512i(p1, p2, count - sizeof(__m512i));
363-
return !is_zero(_mm512_or_epi64(head, tail));
315+
const uint64_t xored = _mm512_cmpneq_epi8_mask(a, b);
316+
return static_cast<uint32_t>(xored >> 32) |
317+
static_cast<uint32_t>(xored & 0xFFFFFFFF);
364318
}
365319
template <>
366320
LIBC_INLINE MemcmpReturnType cmp_neq<__m512i>(CPtr p1, CPtr p2, size_t offset) {

libc/src/string/memory_utils/x86_64/inline_bcmp.h

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ inline_bcmp_generic_gt16(CPtr p1, CPtr p2, size_t count) {
2727
[[maybe_unused]] LIBC_INLINE BcmpReturnType
2828
inline_bcmp_x86_sse41_gt16(CPtr p1, CPtr p2, size_t count) {
2929
if (count <= 32)
30-
return generic::branchless_head_tail_neq<__m128i>(p1, p2, count);
30+
return generic::Bcmp<__m128i>::head_tail(p1, p2, count);
3131
return generic::Bcmp<__m128i>::loop_and_tail_align_above(256, p1, p2, count);
3232
}
3333
#endif // __SSE4_1__
@@ -36,9 +36,9 @@ inline_bcmp_x86_sse41_gt16(CPtr p1, CPtr p2, size_t count) {
3636
[[maybe_unused]] LIBC_INLINE BcmpReturnType
3737
inline_bcmp_x86_avx_gt16(CPtr p1, CPtr p2, size_t count) {
3838
if (count <= 32)
39-
return generic::branchless_head_tail_neq<__m128i>(p1, p2, count);
39+
return generic::Bcmp<__m128i>::head_tail(p1, p2, count);
4040
if (count <= 64)
41-
return generic::branchless_head_tail_neq<__m256i>(p1, p2, count);
41+
return generic::Bcmp<__m256i>::head_tail(p1, p2, count);
4242
return generic::Bcmp<__m256i>::loop_and_tail_align_above(256, p1, p2, count);
4343
}
4444
#endif // __AVX__
@@ -47,11 +47,11 @@ inline_bcmp_x86_avx_gt16(CPtr p1, CPtr p2, size_t count) {
4747
[[maybe_unused]] LIBC_INLINE BcmpReturnType
4848
inline_bcmp_x86_avx512bw_gt16(CPtr p1, CPtr p2, size_t count) {
4949
if (count <= 32)
50-
return generic::branchless_head_tail_neq<__m128i>(p1, p2, count);
50+
return generic::Bcmp<__m128i>::head_tail(p1, p2, count);
5151
if (count <= 64)
52-
return generic::branchless_head_tail_neq<__m256i>(p1, p2, count);
52+
return generic::Bcmp<__m256i>::head_tail(p1, p2, count);
5353
if (count <= 128)
54-
return generic::branchless_head_tail_neq<__m512i>(p1, p2, count);
54+
return generic::Bcmp<__m512i>::head_tail(p1, p2, count);
5555
return generic::Bcmp<__m512i>::loop_and_tail_align_above(256, p1, p2, count);
5656
}
5757
#endif // __AVX512BW__
@@ -62,12 +62,22 @@ inline_bcmp_x86_avx512bw_gt16(CPtr p1, CPtr p2, size_t count) {
6262
return BcmpReturnType::zero();
6363
if (count == 1)
6464
return generic::Bcmp<uint8_t>::block(p1, p2);
65-
if (count <= 4)
66-
return generic::branchless_head_tail_neq<uint16_t>(p1, p2, count);
67-
if (count <= 8)
68-
return generic::branchless_head_tail_neq<uint32_t>(p1, p2, count);
65+
if (count == 2)
66+
return generic::Bcmp<uint16_t>::block(p1, p2);
67+
if (count == 3)
68+
return generic::BcmpSequence<uint16_t, uint8_t>::block(p1, p2);
69+
if (count == 4)
70+
return generic::Bcmp<uint32_t>::block(p1, p2);
71+
if (count == 5)
72+
return generic::BcmpSequence<uint32_t, uint8_t>::block(p1, p2);
73+
if (count == 6)
74+
return generic::BcmpSequence<uint32_t, uint16_t>::block(p1, p2);
75+
if (count == 7)
76+
return generic::BcmpSequence<uint32_t, uint16_t, uint8_t>::block(p1, p2);
77+
if (count == 8)
78+
return generic::Bcmp<uint64_t>::block(p1, p2);
6979
if (count <= 16)
70-
return generic::branchless_head_tail_neq<uint64_t>(p1, p2, count);
80+
return generic::Bcmp<uint64_t>::head_tail(p1, p2, count);
7181
#if defined(__AVX512BW__)
7282
return inline_bcmp_x86_avx512bw_gt16(p1, p2, count);
7383
#elif defined(__AVX__)

0 commit comments

Comments
 (0)