Skip to content

Commit bb23ba4

Browse files
committed
Fix memory overflow bug in memcmp implementation
1 parent 6c3abf5 commit bb23ba4

File tree

4 files changed

+140
-70
lines changed

4 files changed

+140
-70
lines changed

src/memcmp.rs

Lines changed: 125 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,173 @@
11
#![allow(dead_code)]
22

3+
use std::slice;
4+
35
#[inline]
4-
pub unsafe fn memcmp0(_: &[u8], _: &[u8]) -> bool {
6+
pub unsafe fn memcmp0(_: *const u8, _: *const u8, n: usize) -> bool {
7+
debug_assert_eq!(n, 0);
58
true
69
}
710

811
#[inline]
9-
pub unsafe fn memcmp1(left: &[u8], right: &[u8]) -> bool {
12+
pub unsafe fn memcmp1(left: *const u8, right: *const u8, n: usize) -> bool {
13+
debug_assert_eq!(n, 1);
1014
*left == *right
1115
}
1216

1317
#[inline]
14-
pub unsafe fn memcmp2(left: &[u8], right: &[u8]) -> bool {
15-
let left = left.as_ptr().cast::<u16>();
16-
let right = right.as_ptr().cast::<u16>();
18+
pub unsafe fn memcmp2(left: *const u8, right: *const u8, n: usize) -> bool {
19+
debug_assert_eq!(n, 2);
20+
let left = left.cast::<u16>();
21+
let right = right.cast::<u16>();
1722
*left == *right
1823
}
1924

2025
#[inline]
21-
pub unsafe fn memcmp3(left: &[u8], right: &[u8]) -> bool {
22-
let left = left.as_ptr().cast::<u32>();
23-
let right = right.as_ptr().cast::<u32>();
24-
(*left & 0x00ffffff) == (*right & 0x00ffffff)
26+
pub unsafe fn memcmp3(left: *const u8, right: *const u8, n: usize) -> bool {
27+
debug_assert_eq!(n, 3);
28+
memcmp2(left, right, 2) && memcmp1(left.add(2), right.add(2), 1)
2529
}
2630

2731
#[inline]
28-
pub unsafe fn memcmp4(left: &[u8], right: &[u8]) -> bool {
29-
let left = left.as_ptr().cast::<u32>();
30-
let right = right.as_ptr().cast::<u32>();
32+
pub unsafe fn memcmp4(left: *const u8, right: *const u8, n: usize) -> bool {
33+
debug_assert_eq!(n, 4);
34+
let left = left.cast::<u32>();
35+
let right = right.cast::<u32>();
3136
*left == *right
3237
}
3338

3439
#[inline]
35-
pub unsafe fn memcmp5(left: &[u8], right: &[u8]) -> bool {
36-
let left = left.as_ptr().cast::<u64>();
37-
let right = right.as_ptr().cast::<u64>();
38-
(*left ^ *right).trailing_zeros() >= 40
40+
pub unsafe fn memcmp5(left: *const u8, right: *const u8, n: usize) -> bool {
41+
debug_assert_eq!(n, 5);
42+
memcmp4(left, right, 4) && memcmp1(left.add(4), right.add(4), 1)
3943
}
4044

4145
#[inline]
42-
pub unsafe fn memcmp6(left: &[u8], right: &[u8]) -> bool {
43-
let left = left.as_ptr().cast::<u64>();
44-
let right = right.as_ptr().cast::<u64>();
45-
(*left ^ *right).trailing_zeros() >= 48
46+
pub unsafe fn memcmp6(left: *const u8, right: *const u8, n: usize) -> bool {
47+
debug_assert_eq!(n, 6);
48+
memcmp4(left, right, 4) && memcmp2(left.add(4), right.add(4), 2)
4649
}
4750

48-
#[allow(dead_code)]
4951
#[inline]
50-
pub unsafe fn memcmp7(left: &[u8], right: &[u8]) -> bool {
51-
let left = left.as_ptr().cast::<u64>();
52-
let right = right.as_ptr().cast::<u64>();
53-
(*left ^ *right).trailing_zeros() >= 56
52+
pub unsafe fn memcmp7(left: *const u8, right: *const u8, n: usize) -> bool {
53+
debug_assert_eq!(n, 7);
54+
memcmp4(left, right, 4) && memcmp3(left.add(4), right.add(4), 3)
5455
}
5556

5657
#[inline]
57-
pub unsafe fn memcmp8(left: &[u8], right: &[u8]) -> bool {
58-
let left = left.as_ptr().cast::<u64>();
59-
let right = right.as_ptr().cast::<u64>();
58+
pub unsafe fn memcmp8(left: *const u8, right: *const u8, n: usize) -> bool {
59+
debug_assert_eq!(n, 8);
60+
let left = left.cast::<u64>();
61+
let right = right.cast::<u64>();
6062
*left == *right
6163
}
6264

6365
#[inline]
64-
pub unsafe fn memcmp9(left: &[u8], right: &[u8]) -> bool {
65-
let left_first = left.as_ptr().cast::<u64>();
66-
let right_first = right.as_ptr().cast::<u64>();
67-
*left_first == *right_first && *left.as_ptr().add(8) == *right.as_ptr().add(8)
66+
pub unsafe fn memcmp9(left: *const u8, right: *const u8, n: usize) -> bool {
67+
debug_assert_eq!(n, 9);
68+
memcmp8(left, right, 8) && memcmp1(left.add(8), right.add(8), 1)
6869
}
6970

7071
#[inline]
71-
pub unsafe fn memcmp10(left: &[u8], right: &[u8]) -> bool {
72-
let left_first = left.as_ptr().cast::<u64>();
73-
let right_first = right.as_ptr().cast::<u64>();
74-
let left_second = left.as_ptr().add(8).cast::<u16>();
75-
let right_second = right.as_ptr().add(8).cast::<u16>();
76-
*left_first == *right_first && *left_second == *right_second
72+
pub unsafe fn memcmp10(left: *const u8, right: *const u8, n: usize) -> bool {
73+
debug_assert_eq!(n, 10);
74+
memcmp8(left, right, 8) && memcmp2(left.add(8), right.add(8), 2)
7775
}
7876

7977
#[inline]
80-
pub unsafe fn memcmp11(left: &[u8], right: &[u8]) -> bool {
81-
let left_first = left.as_ptr().cast::<u64>();
82-
let right_first = right.as_ptr().cast::<u64>();
83-
let left_second = left.as_ptr().add(8).cast::<u32>();
84-
let right_second = right.as_ptr().add(8).cast::<u32>();
85-
*left_first == *right_first && (*left_second & 0x00ffffff) == (*right_second & 0x00ffffff)
78+
pub unsafe fn memcmp11(left: *const u8, right: *const u8, n: usize) -> bool {
79+
debug_assert_eq!(n, 11);
80+
memcmp8(left, right, 8) && memcmp3(left.add(8), right.add(8), 3)
8681
}
8782

8883
#[inline]
89-
pub unsafe fn memcmp12(left: &[u8], right: &[u8]) -> bool {
90-
let left_first = left.as_ptr().cast::<u64>();
91-
let right_first = right.as_ptr().cast::<u64>();
92-
let left_second = left.as_ptr().add(8).cast::<u32>();
93-
let right_second = right.as_ptr().add(8).cast::<u32>();
94-
*left_first == *right_first && *left_second == *right_second
84+
pub unsafe fn memcmp12(left: *const u8, right: *const u8, n: usize) -> bool {
85+
debug_assert_eq!(n, 12);
86+
memcmp8(left, right, 8) && memcmp4(left.add(8), right.add(8), 4)
9587
}
9688

9789
#[inline]
98-
pub unsafe fn memcmp(left: &[u8], right: &[u8]) -> bool {
99-
left == right
90+
pub unsafe fn memcmp(left: *const u8, right: *const u8, n: usize) -> bool {
91+
slice::from_raw_parts(left, n) == slice::from_raw_parts(right, n)
92+
}
93+
94+
#[cfg(test)]
95+
mod tests {
96+
fn memcmp(f: unsafe fn(*const u8, *const u8, usize) -> bool, n: usize) {
97+
let left = vec![b'0'; n];
98+
unsafe { assert!(f(left.as_ptr(), left.as_ptr(), n)) };
99+
unsafe { assert!(super::memcmp(left.as_ptr(), left.as_ptr(), n)) };
100+
101+
for i in 0..n {
102+
let mut right = left.clone();
103+
right[i] = b'1';
104+
unsafe { assert!(!f(left.as_ptr(), right.as_ptr(), n)) };
105+
unsafe { assert!(!super::memcmp(left.as_ptr(), right.as_ptr(), n)) };
106+
}
107+
}
108+
109+
#[test]
110+
fn memcmp0() {
111+
memcmp(super::memcmp0, 0);
112+
}
113+
114+
#[test]
115+
fn memcmp1() {
116+
memcmp(super::memcmp1, 1);
117+
}
118+
119+
#[test]
120+
fn memcmp2() {
121+
memcmp(super::memcmp2, 2);
122+
}
123+
124+
#[test]
125+
fn memcmp3() {
126+
memcmp(super::memcmp3, 3);
127+
}
128+
129+
#[test]
130+
fn memcmp4() {
131+
memcmp(super::memcmp4, 4);
132+
}
133+
134+
#[test]
135+
fn memcmp5() {
136+
memcmp(super::memcmp5, 5);
137+
}
138+
139+
#[test]
140+
fn memcmp6() {
141+
memcmp(super::memcmp6, 6);
142+
}
143+
144+
#[test]
145+
fn memcmp7() {
146+
memcmp(super::memcmp7, 7);
147+
}
148+
149+
#[test]
150+
fn memcmp8() {
151+
memcmp(super::memcmp8, 8);
152+
}
153+
154+
#[test]
155+
fn memcmp9() {
156+
memcmp(super::memcmp9, 9);
157+
}
158+
159+
#[test]
160+
fn memcmp10() {
161+
memcmp(super::memcmp10, 10);
162+
}
163+
164+
#[test]
165+
fn memcmp11() {
166+
memcmp(super::memcmp11, 11);
167+
}
168+
169+
#[test]
170+
fn memcmp12() {
171+
memcmp(super::memcmp12, 12);
172+
}
100173
}

src/x86/avx2/deprecated/original.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use crate::{bits::clear_leftmost_set, memcmp::*};
33
use std::arch::x86::*;
44
#[cfg(target_arch = "x86_64")]
55
use std::arch::x86_64::*;
6-
use std::slice::from_raw_parts;
76

87
#[inline]
98
#[target_feature(enable = "avx2")]
@@ -12,7 +11,7 @@ unsafe fn strstr_avx2_original_memcmp(
1211
n: usize,
1312
needle: *const u8,
1413
k: usize,
15-
memcmp: unsafe fn(&[u8], &[u8]) -> bool,
14+
memcmp: unsafe fn(*const u8, *const u8, usize) -> bool,
1615
) -> Option<usize> {
1716
let first = _mm256_set1_epi8(*needle as i8);
1817
let last = _mm256_set1_epi8(*needle.add(k - 1) as i8);
@@ -32,10 +31,7 @@ unsafe fn strstr_avx2_original_memcmp(
3231
while mask != 0 {
3332
let bitpos = mask.trailing_zeros() as usize;
3433
let startpos = i + bitpos + 1;
35-
if memcmp(
36-
from_raw_parts(haystack.add(startpos), k - 2),
37-
from_raw_parts(needle.add(1), k - 2),
38-
) {
34+
if memcmp(haystack.add(startpos), needle.add(1), k - 2) {
3935
return Some(i + bitpos);
4036
}
4137
mask = clear_leftmost_set(mask);
@@ -78,14 +74,12 @@ pub unsafe fn strstr_avx2_original(haystack: &[u8], needle: &[u8]) -> bool {
7874
needle.len(),
7975
memcmp2,
8076
),
81-
// Note: use memcmp4 rather than memcmp3, as the last character of needle is already proven to be
82-
// equal
8377
5 => strstr_avx2_original_memcmp(
8478
haystack.as_ptr(),
8579
haystack.len(),
8680
needle.as_ptr(),
8781
needle.len(),
88-
memcmp4,
82+
memcmp3,
8983
),
9084
6 => strstr_avx2_original_memcmp(
9185
haystack.as_ptr(),
@@ -108,13 +102,12 @@ pub unsafe fn strstr_avx2_original(haystack: &[u8], needle: &[u8]) -> bool {
108102
needle.len(),
109103
memcmp6,
110104
),
111-
// Note: use memcmp8 rather than memcmp7 for the same reason as above.
112105
9 => strstr_avx2_original_memcmp(
113106
haystack.as_ptr(),
114107
haystack.len(),
115108
needle.as_ptr(),
116109
needle.len(),
117-
memcmp8,
110+
memcmp7,
118111
),
119112
10 => strstr_avx2_original_memcmp(
120113
haystack.as_ptr(),

src/x86/avx2/deprecated/rust.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::arch::x86_64::*;
1111
unsafe fn strstr_avx2_rust_memcmp(
1212
haystack: &[u8],
1313
needle: &[u8],
14-
memcmp: unsafe fn(&[u8], &[u8]) -> bool,
14+
memcmp: unsafe fn(*const u8, *const u8, usize) -> bool,
1515
) -> bool {
1616
if haystack.len() < 32 {
1717
return strstr_rabin_karp(haystack, needle);
@@ -35,8 +35,9 @@ unsafe fn strstr_avx2_rust_memcmp(
3535
let startpos = i + bitpos;
3636
if startpos + needle.len() <= haystack.len()
3737
&& memcmp(
38-
&haystack[(startpos + 1)..(startpos + needle.len() - 1)],
39-
&needle[1..needle.len() - 1],
38+
haystack.as_ptr().add(startpos + 1),
39+
needle.as_ptr().add(1),
40+
needle.len() - 2,
4041
)
4142
{
4243
return true;
@@ -90,11 +91,11 @@ pub unsafe fn strstr_avx2_rust(haystack: &[u8], needle: &[u8]) -> bool {
9091
2 => strstr_avx2_rust_memcmp(haystack, needle, memcmp0),
9192
3 => strstr_avx2_rust_memcmp(haystack, needle, memcmp1),
9293
4 => strstr_avx2_rust_memcmp(haystack, needle, memcmp2),
93-
5 => strstr_avx2_rust_memcmp(haystack, needle, memcmp4),
94+
5 => strstr_avx2_rust_memcmp(haystack, needle, memcmp3),
9495
6 => strstr_avx2_rust_memcmp(haystack, needle, memcmp4),
9596
7 => strstr_avx2_rust_memcmp(haystack, needle, memcmp5),
9697
8 => strstr_avx2_rust_memcmp(haystack, needle, memcmp6),
97-
9 => strstr_avx2_rust_memcmp(haystack, needle, memcmp8),
98+
9 => strstr_avx2_rust_memcmp(haystack, needle, memcmp7),
9899
10 => strstr_avx2_rust_memcmp(haystack, needle, memcmp8),
99100
11 => strstr_avx2_rust_memcmp(haystack, needle, memcmp9),
100101
12 => strstr_avx2_rust_memcmp(haystack, needle, memcmp10),

src/x86/avx2/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,12 @@ macro_rules! avx2_searcher {
246246
let mut eq = (Vector::movemask_epi8(eq) & mask) as u32;
247247

248248
let start = start as usize - haystack.as_ptr() as usize;
249+
let chunk = haystack.as_ptr().add(start + 1);
250+
let needle = self.needle.as_ptr().add(1);
251+
249252
while eq != 0 {
250-
let chunk = &haystack[start + eq.trailing_zeros() as usize..];
251-
if $memcmp(&chunk[1..self.size()], &self.needle[1..]) {
253+
let chunk = chunk.add(eq.trailing_zeros() as usize);
254+
if $memcmp(chunk, needle, self.size() - 1) {
252255
return true;
253256
}
254257

0 commit comments

Comments
 (0)