Skip to content

Commit d616a95

Browse files
Refactor to template-based implementation
1 parent 61e0181 commit d616a95

File tree

1 file changed

+47
-60
lines changed

1 file changed

+47
-60
lines changed

libc/src/string/memory_utils/x86_64/inline_strlen.h

Lines changed: 47 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,97 +15,84 @@
1515

1616
namespace LIBC_NAMESPACE_DECL {
1717

18-
namespace sse2 {
19-
[[maybe_unused]] LIBC_INLINE size_t string_length(const char *src) {
20-
using Vector __attribute__((may_alias)) = __m128i;
18+
namespace string_length_internal {
19+
// Return a bit-mask with the nth bit set if the nth-byte in block_ptr is zero.
20+
template<typename Vector, typename Mask> Mask CompareAndMask(const Vector *block_ptr);
2121

22-
Vector z = _mm_setzero_si128();
22+
template <typename Vector, typename Mask, decltype(CompareAndMask<Vector, Mask>)>
23+
size_t string_length_vector(const char *src) {
2324
uintptr_t misalign_bytes = reinterpret_cast<uintptr_t>(src) % sizeof(Vector);
25+
2426
const Vector *block_ptr =
2527
reinterpret_cast<const Vector *>(src - misalign_bytes);
26-
Vector v = _mm_load_si128(block_ptr);
27-
Vector vcmp = _mm_cmpeq_epi8(z, v);
28-
// shift away results in irrelevant bytes.
29-
uint32_t cmp = _mm_movemask_epi8(vcmp) >> misalign_bytes;
28+
auto cmp = CompareAndMask<Vector, Mask>(block_ptr) >> misalign_bytes;
3029
if (cmp)
3130
return cpp::countr_zero(cmp);
3231

3332
while (true) {
3433
block_ptr++;
35-
v = _mm_load_si128(block_ptr);
36-
vcmp = _mm_cmpeq_epi8(z, v);
37-
cmp = _mm_movemask_epi8(vcmp);
34+
cmp = CompareAndMask<Vector, Mask>(block_ptr);
3835
if (cmp)
3936
return static_cast<size_t>(reinterpret_cast<uintptr_t>(block_ptr) -
4037
reinterpret_cast<uintptr_t>(src) +
4138
cpp::countr_zero(cmp));
4239
}
4340
}
41+
42+
template <>
43+
uint32_t CompareAndMask<__m128i, uint32_t>(const __m128i *block_ptr) {
44+
__m128i v = _mm_load_si128(block_ptr);
45+
__m128i z = _mm_setzero_si128();
46+
__m128i c = _mm_cmpeq_epi8(z, v);
47+
return _mm_movemask_epi8(c);
48+
}
49+
50+
namespace sse2 {
51+
size_t string_length(const char *src) {
52+
return string_length_vector<__m128i, uint32_t, CompareAndMask<__m128i, uint32_t>>(
53+
src);
54+
}
4455
} // namespace sse2
4556

4657
#if defined(__AVX2__)
47-
namespace avx2 {
48-
[[maybe_unused]] LIBC_INLINE size_t string_length(const char *src) {
49-
using Vector __attribute__((may_alias)) = __mm256i;
50-
51-
Vector z = _mm256_setzero_si256();
52-
uintptr_t misalign_bytes = reinterpret_cast<uintptr_t>(src) % sizeof(Vector);
53-
const Vector *block_ptr =
54-
reinterpret_cast<const Vector *>(src - misalign_bytes);
55-
Vector v = _mm256_load_si256(block_ptr);
56-
Vector vcmp = _mm256_cmpeq_epi8(z, v);
57-
// shift away results in irrelevant bytes.
58-
int cmp = _mm256_movemask_epi8(vcmp) >> misalign_bytes;
59-
if (cmp)
60-
return cpp::countr_zero(cmp);
58+
template <>
59+
uint32_t CompareAndMask<__m256i, uint32_t>(const __m256i *block_ptr) {
60+
__m256i v = _mm256_load_si256(block_ptr);
61+
__m256i z = _mm256_setzero_si256();
62+
__m256i c = _mm256_cmpeq_epi8(z, v);
63+
return _mm256_movemask_epi8(c);
64+
}
6165

62-
while (true) {
63-
block_ptr++;
64-
v = _mm256_load_si256(block_ptr);
65-
vcmp = _mm256_cmpeq_epi8(z, v);
66-
cmp = _mm256_movemask_epi8(vcmp);
67-
if (cmp)
68-
return static_cast<size_t>(reinterpret_cast<uintptr_t>(block_ptr) -
69-
reinterpret_cast<uintptr_t>(src) +
70-
cpp::countr_zero(cmp));
71-
}
66+
namespace avx2 {
67+
size_t string_length(const char* src) {
68+
return string_length_vector<__m256i, uint32_t, CompareAndMask<__m256i, uint32_t>>(src);
7269
}
7370
} // namespace avx2
7471
#endif
7572

7673
#if defined(__AVX512F__)
74+
template <>
75+
__mmask64
76+
CompareAndMask<__m512i, __mmask64> (const __m512i *block_ptr)
77+
{
78+
__m512i v = _mm512_load_si512(block_ptr);
79+
__m512i z = _mm512_setzero_si512();
80+
return _mm512_cmp_epu8_mask(z, v, _MM_CMPINT_EQ);
81+
}
7782
namespace avx512 {
78-
[[maybe_unused]] LIBC_INLINE size_t string_length(const char *src) {
79-
using Vector __attribute__((may_alias)) = __mm512i;
80-
81-
Vector z = _mm512_setzero_si512();
82-
uintptr_t misalign_bytes = reinterpret_cast<uintptr_t>(src) % sizeof(Vector);
83-
const Vector *block_ptr =
84-
reinterpret_cast<const Vector *>(src - misalign_bytes);
85-
Vector v = _mm512_load_si512(block_ptr);
86-
__mmask64 cmp = _mm512_cmp_epu8_mask(z, v, _MM_CMPINT_EQ) >> misalign_bytes;
87-
if (cmp)
88-
return cpp::countr_zero(cmp);
89-
90-
while (true) {
91-
block_ptr++;
92-
Vector v = _mm512_load_si512(block_ptr);
93-
__mmask64 cmp = _mm512_cmp_epu8_mask(z, v, _MM_CMPINT_EQ);
94-
if (cmp)
95-
return static_cast<size_t>(reinterpret_cast<uintptr_t>(block_ptr) -
96-
reinterpret_cast<uintptr_t>(src) +
97-
cpp::countr_zero(cmp));
98-
}
83+
size_t string_length(const char* src) {
84+
return string_length_vector<__m512i, __mmask64, CompareAndMask<__m512i, __mmask64>>(src);
9985
}
10086
} // namespace avx512
10187
#endif
88+
} // string_length_internal
10289

10390
#if defined(__AVX512F__)
104-
namespace string_length_impl = avx512;
105-
#elif defined(__AVX2__)
106-
namespace string_length_impl = avx2;
91+
namespace string_length_impl = string_length_internal::avx512;
92+
#elif defined (__AVX2__)
93+
namespace string_length_impl = string_length_internal::avx2;
10794
#else
108-
namespace string_length_impl = sse2;
95+
namespace string_length_impl = string_length_internal::sse2;
10996
#endif
11097

11198
} // namespace LIBC_NAMESPACE_DECL

0 commit comments

Comments
 (0)