|
1 | 1 | #![allow(dead_code)] |
2 | 2 |
|
| 3 | +use std::slice; |
| 4 | + |
3 | 5 | #[inline] |
4 | | -pub unsafe fn memcmp0(_: &[u8], _: &[u8]) -> bool { |
| 6 | +pub unsafe fn memcmp0(_: *const u8, _: *const u8, n: usize) -> bool { |
| 7 | + debug_assert_eq!(n, 0); |
5 | 8 | true |
6 | 9 | } |
7 | 10 |
|
8 | 11 | #[inline] |
9 | | -pub unsafe fn memcmp1(left: &[u8], right: &[u8]) -> bool { |
| 12 | +pub unsafe fn memcmp1(left: *const u8, right: *const u8, n: usize) -> bool { |
| 13 | + debug_assert_eq!(n, 1); |
10 | 14 | *left == *right |
11 | 15 | } |
12 | 16 |
|
13 | 17 | #[inline] |
14 | | -pub unsafe fn memcmp2(left: &[u8], right: &[u8]) -> bool { |
15 | | - let left = left.as_ptr().cast::<u16>(); |
16 | | - let right = right.as_ptr().cast::<u16>(); |
| 18 | +pub unsafe fn memcmp2(left: *const u8, right: *const u8, n: usize) -> bool { |
| 19 | + debug_assert_eq!(n, 2); |
| 20 | + let left = left.cast::<u16>(); |
| 21 | + let right = right.cast::<u16>(); |
17 | 22 | *left == *right |
18 | 23 | } |
19 | 24 |
|
20 | 25 | #[inline] |
21 | | -pub unsafe fn memcmp3(left: &[u8], right: &[u8]) -> bool { |
22 | | - let left = left.as_ptr().cast::<u32>(); |
23 | | - let right = right.as_ptr().cast::<u32>(); |
24 | | - (*left & 0x00ffffff) == (*right & 0x00ffffff) |
| 26 | +pub unsafe fn memcmp3(left: *const u8, right: *const u8, n: usize) -> bool { |
| 27 | + debug_assert_eq!(n, 3); |
| 28 | + memcmp2(left, right, 2) && memcmp1(left.add(2), right.add(2), 1) |
25 | 29 | } |
26 | 30 |
|
27 | 31 | #[inline] |
28 | | -pub unsafe fn memcmp4(left: &[u8], right: &[u8]) -> bool { |
29 | | - let left = left.as_ptr().cast::<u32>(); |
30 | | - let right = right.as_ptr().cast::<u32>(); |
| 32 | +pub unsafe fn memcmp4(left: *const u8, right: *const u8, n: usize) -> bool { |
| 33 | + debug_assert_eq!(n, 4); |
| 34 | + let left = left.cast::<u32>(); |
| 35 | + let right = right.cast::<u32>(); |
31 | 36 | *left == *right |
32 | 37 | } |
33 | 38 |
|
34 | 39 | #[inline] |
35 | | -pub unsafe fn memcmp5(left: &[u8], right: &[u8]) -> bool { |
36 | | - let left = left.as_ptr().cast::<u64>(); |
37 | | - let right = right.as_ptr().cast::<u64>(); |
38 | | - (*left ^ *right).trailing_zeros() >= 40 |
| 40 | +pub unsafe fn memcmp5(left: *const u8, right: *const u8, n: usize) -> bool { |
| 41 | + debug_assert_eq!(n, 5); |
| 42 | + memcmp4(left, right, 4) && memcmp1(left.add(4), right.add(4), 1) |
39 | 43 | } |
40 | 44 |
|
41 | 45 | #[inline] |
42 | | -pub unsafe fn memcmp6(left: &[u8], right: &[u8]) -> bool { |
43 | | - let left = left.as_ptr().cast::<u64>(); |
44 | | - let right = right.as_ptr().cast::<u64>(); |
45 | | - (*left ^ *right).trailing_zeros() >= 48 |
| 46 | +pub unsafe fn memcmp6(left: *const u8, right: *const u8, n: usize) -> bool { |
| 47 | + debug_assert_eq!(n, 6); |
| 48 | + memcmp4(left, right, 4) && memcmp2(left.add(4), right.add(4), 2) |
46 | 49 | } |
47 | 50 |
|
48 | | -#[allow(dead_code)] |
49 | 51 | #[inline] |
50 | | -pub unsafe fn memcmp7(left: &[u8], right: &[u8]) -> bool { |
51 | | - let left = left.as_ptr().cast::<u64>(); |
52 | | - let right = right.as_ptr().cast::<u64>(); |
53 | | - (*left ^ *right).trailing_zeros() >= 56 |
| 52 | +pub unsafe fn memcmp7(left: *const u8, right: *const u8, n: usize) -> bool { |
| 53 | + debug_assert_eq!(n, 7); |
| 54 | + memcmp4(left, right, 4) && memcmp3(left.add(4), right.add(4), 3) |
54 | 55 | } |
55 | 56 |
|
56 | 57 | #[inline] |
57 | | -pub unsafe fn memcmp8(left: &[u8], right: &[u8]) -> bool { |
58 | | - let left = left.as_ptr().cast::<u64>(); |
59 | | - let right = right.as_ptr().cast::<u64>(); |
| 58 | +pub unsafe fn memcmp8(left: *const u8, right: *const u8, n: usize) -> bool { |
| 59 | + debug_assert_eq!(n, 8); |
| 60 | + let left = left.cast::<u64>(); |
| 61 | + let right = right.cast::<u64>(); |
60 | 62 | *left == *right |
61 | 63 | } |
62 | 64 |
|
63 | 65 | #[inline] |
64 | | -pub unsafe fn memcmp9(left: &[u8], right: &[u8]) -> bool { |
65 | | - let left_first = left.as_ptr().cast::<u64>(); |
66 | | - let right_first = right.as_ptr().cast::<u64>(); |
67 | | - *left_first == *right_first && *left.as_ptr().add(8) == *right.as_ptr().add(8) |
| 66 | +pub unsafe fn memcmp9(left: *const u8, right: *const u8, n: usize) -> bool { |
| 67 | + debug_assert_eq!(n, 9); |
| 68 | + memcmp8(left, right, 8) && memcmp1(left.add(8), right.add(8), 1) |
68 | 69 | } |
69 | 70 |
|
70 | 71 | #[inline] |
71 | | -pub unsafe fn memcmp10(left: &[u8], right: &[u8]) -> bool { |
72 | | - let left_first = left.as_ptr().cast::<u64>(); |
73 | | - let right_first = right.as_ptr().cast::<u64>(); |
74 | | - let left_second = left.as_ptr().add(8).cast::<u16>(); |
75 | | - let right_second = right.as_ptr().add(8).cast::<u16>(); |
76 | | - *left_first == *right_first && *left_second == *right_second |
| 72 | +pub unsafe fn memcmp10(left: *const u8, right: *const u8, n: usize) -> bool { |
| 73 | + debug_assert_eq!(n, 10); |
| 74 | + memcmp8(left, right, 8) && memcmp2(left.add(8), right.add(8), 2) |
77 | 75 | } |
78 | 76 |
|
79 | 77 | #[inline] |
80 | | -pub unsafe fn memcmp11(left: &[u8], right: &[u8]) -> bool { |
81 | | - let left_first = left.as_ptr().cast::<u64>(); |
82 | | - let right_first = right.as_ptr().cast::<u64>(); |
83 | | - let left_second = left.as_ptr().add(8).cast::<u32>(); |
84 | | - let right_second = right.as_ptr().add(8).cast::<u32>(); |
85 | | - *left_first == *right_first && (*left_second & 0x00ffffff) == (*right_second & 0x00ffffff) |
| 78 | +pub unsafe fn memcmp11(left: *const u8, right: *const u8, n: usize) -> bool { |
| 79 | + debug_assert_eq!(n, 11); |
| 80 | + memcmp8(left, right, 8) && memcmp3(left.add(8), right.add(8), 3) |
86 | 81 | } |
87 | 82 |
|
88 | 83 | #[inline] |
89 | | -pub unsafe fn memcmp12(left: &[u8], right: &[u8]) -> bool { |
90 | | - let left_first = left.as_ptr().cast::<u64>(); |
91 | | - let right_first = right.as_ptr().cast::<u64>(); |
92 | | - let left_second = left.as_ptr().add(8).cast::<u32>(); |
93 | | - let right_second = right.as_ptr().add(8).cast::<u32>(); |
94 | | - *left_first == *right_first && *left_second == *right_second |
| 84 | +pub unsafe fn memcmp12(left: *const u8, right: *const u8, n: usize) -> bool { |
| 85 | + debug_assert_eq!(n, 12); |
| 86 | + memcmp8(left, right, 8) && memcmp4(left.add(8), right.add(8), 4) |
95 | 87 | } |
96 | 88 |
|
97 | 89 | #[inline] |
98 | | -pub unsafe fn memcmp(left: &[u8], right: &[u8]) -> bool { |
99 | | - left == right |
| 90 | +pub unsafe fn memcmp(left: *const u8, right: *const u8, n: usize) -> bool { |
| 91 | + slice::from_raw_parts(left, n) == slice::from_raw_parts(right, n) |
| 92 | +} |
| 93 | + |
| 94 | +#[cfg(test)] |
| 95 | +mod tests { |
| 96 | + fn memcmp(f: unsafe fn(*const u8, *const u8, usize) -> bool, n: usize) { |
| 97 | + let left = vec![b'0'; n]; |
| 98 | + unsafe { assert!(f(left.as_ptr(), left.as_ptr(), n)) }; |
| 99 | + unsafe { assert!(super::memcmp(left.as_ptr(), left.as_ptr(), n)) }; |
| 100 | + |
| 101 | + for i in 0..n { |
| 102 | + let mut right = left.clone(); |
| 103 | + right[i] = b'1'; |
| 104 | + unsafe { assert!(!f(left.as_ptr(), right.as_ptr(), n)) }; |
| 105 | + unsafe { assert!(!super::memcmp(left.as_ptr(), right.as_ptr(), n)) }; |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + #[test] |
| 110 | + fn memcmp0() { |
| 111 | + memcmp(super::memcmp0, 0); |
| 112 | + } |
| 113 | + |
| 114 | + #[test] |
| 115 | + fn memcmp1() { |
| 116 | + memcmp(super::memcmp1, 1); |
| 117 | + } |
| 118 | + |
| 119 | + #[test] |
| 120 | + fn memcmp2() { |
| 121 | + memcmp(super::memcmp2, 2); |
| 122 | + } |
| 123 | + |
| 124 | + #[test] |
| 125 | + fn memcmp3() { |
| 126 | + memcmp(super::memcmp3, 3); |
| 127 | + } |
| 128 | + |
| 129 | + #[test] |
| 130 | + fn memcmp4() { |
| 131 | + memcmp(super::memcmp4, 4); |
| 132 | + } |
| 133 | + |
| 134 | + #[test] |
| 135 | + fn memcmp5() { |
| 136 | + memcmp(super::memcmp5, 5); |
| 137 | + } |
| 138 | + |
| 139 | + #[test] |
| 140 | + fn memcmp6() { |
| 141 | + memcmp(super::memcmp6, 6); |
| 142 | + } |
| 143 | + |
| 144 | + #[test] |
| 145 | + fn memcmp7() { |
| 146 | + memcmp(super::memcmp7, 7); |
| 147 | + } |
| 148 | + |
| 149 | + #[test] |
| 150 | + fn memcmp8() { |
| 151 | + memcmp(super::memcmp8, 8); |
| 152 | + } |
| 153 | + |
| 154 | + #[test] |
| 155 | + fn memcmp9() { |
| 156 | + memcmp(super::memcmp9, 9); |
| 157 | + } |
| 158 | + |
| 159 | + #[test] |
| 160 | + fn memcmp10() { |
| 161 | + memcmp(super::memcmp10, 10); |
| 162 | + } |
| 163 | + |
| 164 | + #[test] |
| 165 | + fn memcmp11() { |
| 166 | + memcmp(super::memcmp11, 11); |
| 167 | + } |
| 168 | + |
| 169 | + #[test] |
| 170 | + fn memcmp12() { |
| 171 | + memcmp(super::memcmp12, 12); |
| 172 | + } |
100 | 173 | } |
0 commit comments