Skip to content

Commit 41edcb9

Browse files
authored
Optional SIMD memcmp (#603)
This the last PR I'm putting forward for #580. `memcmp` has the advantage that we know the lengths and can access the entire buffers (no undefined behavior). It has the difficulty that the buffers may not share an alingment, so it uses `wasm_v128_load`. It uses SIMD if there are 16 or more bytes to read, otherwise it fallbacks to scalar. If the number of bytes is larger than 16, but not a multiple of 16, the second iteration retests some already tested bytes, to "align" the remaining length to a multiple of 16. Making the first (rather than the last) iteration special unnecessarily "wastes" these comparisons, but helps the compiler partially unroll the loop.
1 parent f27cafb commit 41edcb9

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

libc-top-half/musl/src/string/memcmp.c

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,42 @@
11
#include <string.h>
22

3+
#ifdef __wasm_simd128__
4+
#include <wasm_simd128.h>
5+
#endif
6+
37
int memcmp(const void *vl, const void *vr, size_t n)
48
{
9+
#if defined(__wasm_simd128__) && defined(__wasilibc_simd_string)
10+
if (n >= sizeof(v128_t)) {
11+
// memcmp is allowed to read up to n bytes from each object.
12+
// Find the first different character in the objects.
13+
// Unaligned loads handle the case where the objects
14+
// have mismatching alignments.
15+
const v128_t *v1 = (v128_t *)vl;
16+
const v128_t *v2 = (v128_t *)vr;
17+
while (n) {
18+
const v128_t cmp = wasm_i8x16_eq(wasm_v128_load(v1), wasm_v128_load(v2));
19+
// Bitmask is slow on AArch64, all_true is much faster.
20+
if (!wasm_i8x16_all_true(cmp)) {
21+
// Find the offset of the first zero bit (little-endian).
22+
size_t ctz = __builtin_ctz(~wasm_i8x16_bitmask(cmp));
23+
const unsigned char *u1 = (unsigned char *)v1 + ctz;
24+
const unsigned char *u2 = (unsigned char *)v2 + ctz;
25+
// This may help the compiler if the function is inlined.
26+
__builtin_assume(*u1 - *u2 != 0);
27+
return *u1 - *u2;
28+
}
29+
// This makes n a multiple of sizeof(v128_t)
30+
// for every iteration except the first.
31+
size_t align = (n - 1) % sizeof(v128_t) + 1;
32+
v1 = (v128_t *)((char *)v1 + align);
33+
v2 = (v128_t *)((char *)v2 + align);
34+
n -= align;
35+
}
36+
return 0;
37+
}
38+
#endif
39+
540
const unsigned char *l=vl, *r=vr;
641
for (; n && *l == *r; n--, l++, r++);
742
return n ? *l-*r : 0;

test/src/misc/memcmp.c

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
//! add-flags.py(LDFLAGS): -Wl,--stack-first -Wl,--initial-memory=327680
2+
3+
#include <__macro_PAGESIZE.h>
4+
#include <stddef.h>
5+
#include <stdio.h>
6+
#include <string.h>
7+
8+
int sign(int val) {
9+
return (0 < val) - (val < 0);
10+
}
11+
void test(char *ptr1, char *ptr2, size_t length, int want) {
12+
int got = memcmp(ptr1, ptr2, length);
13+
if (sign(got) != sign(want)) {
14+
printf("memcmp(%p, %p, %lu) = %d, want %d\n", ptr1, ptr2, length, got,
15+
want);
16+
}
17+
}
18+
19+
int main(void) {
20+
char *const LIMIT = (char *)(__builtin_wasm_memory_size(0) * PAGESIZE);
21+
22+
for (ptrdiff_t length = 0; length < 64; length++) {
23+
for (ptrdiff_t alignment = 0; alignment < 24; alignment++) {
24+
for (ptrdiff_t pos = -2; pos < length + 2; pos++) {
25+
// Create a buffer with the given length, at a pointer with the given
26+
// alignment. Using the offset LIMIT - PAGESIZE - 8 means many buffers
27+
// will straddle a (Wasm, and likely OS) page boundary.
28+
// The second buffer has a fixed address, which means it won't
29+
// always share alignment with first buffer.
30+
// Place the difference to find at every position in the buffers,
31+
// including just prior to it and after its end.
32+
char *ptr1 = LIMIT - PAGESIZE - 8 + alignment;
33+
char *ptr2 = LIMIT - PAGESIZE / 2;
34+
memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
35+
memset(ptr1, 5, length);
36+
memset(ptr2, 5, length);
37+
38+
ptr1[pos] = 7;
39+
ptr2[pos] = 3;
40+
41+
test(ptr1, ptr2, length,
42+
0 <= pos && pos < length ? ptr1[pos] - ptr2[pos] : 0);
43+
test(ptr2, ptr1, length,
44+
0 <= pos && pos < length ? ptr2[pos] - ptr1[pos] : 0);
45+
}
46+
}
47+
}
48+
49+
return 0;
50+
}

0 commit comments

Comments
 (0)