Skip to content

Commit 5c5bc36

Browse files
author
Amrit kumar Mahto
committed
fix: add NULL checks to Rust FFI exports
1 parent 477307e commit 5c5bc36

File tree

2 files changed

+76
-28
lines changed

2 files changed

+76
-28
lines changed

src/rust/src/libccxr_exports/bitstream.rs

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ pub unsafe extern "C" fn ccxr_free_bitstream(bs: *mut BitStreamRust<'static>) {
120120
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_internal_state_from_rust_to_c`
121121
#[no_mangle]
122122
pub unsafe extern "C" fn ccxr_next_bits(bs: *mut bitstream, bnum: u32) -> u64 {
123+
if bs.is_null() {
124+
return 0;
125+
}
123126
let mut rust_bs = copy_bitstream_c_to_rust(bs);
124127
let val = match rust_bs.next_bits(bnum) {
125128
Ok(val) => val,
@@ -136,6 +139,9 @@ pub unsafe extern "C" fn ccxr_next_bits(bs: *mut bitstream, bnum: u32) -> u64 {
136139
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
137140
#[no_mangle]
138141
pub unsafe extern "C" fn ccxr_read_bits(bs: *mut bitstream, bnum: u32) -> u64 {
142+
if bs.is_null() {
143+
return 0;
144+
}
139145
let mut rust_bs = copy_bitstream_c_to_rust(bs);
140146
let val = match rust_bs.read_bits(bnum) {
141147
Ok(val) => val,
@@ -151,6 +157,9 @@ pub unsafe extern "C" fn ccxr_read_bits(bs: *mut bitstream, bnum: u32) -> u64 {
151157
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
152158
#[no_mangle]
153159
pub unsafe extern "C" fn ccxr_skip_bits(bs: *mut bitstream, bnum: u32) -> i32 {
160+
if bs.is_null() {
161+
return 0;
162+
}
154163
let mut rust_bs = copy_bitstream_c_to_rust(bs);
155164
let val = match rust_bs.skip_bits(bnum) {
156165
Ok(val) => val,
@@ -170,6 +179,9 @@ pub unsafe extern "C" fn ccxr_skip_bits(bs: *mut bitstream, bnum: u32) -> i32 {
170179
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
171180
#[no_mangle]
172181
pub unsafe extern "C" fn ccxr_is_byte_aligned(bs: *mut bitstream) -> i32 {
182+
if bs.is_null() {
183+
return 0;
184+
}
173185
let rust_bs = copy_bitstream_c_to_rust(bs);
174186
match rust_bs.is_byte_aligned() {
175187
Ok(val) => {
@@ -189,6 +201,9 @@ pub unsafe extern "C" fn ccxr_is_byte_aligned(bs: *mut bitstream) -> i32 {
189201
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
190202
#[no_mangle]
191203
pub unsafe extern "C" fn ccxr_make_byte_aligned(bs: *mut bitstream) {
204+
if bs.is_null() {
205+
return;
206+
}
192207
let mut rust_bs = copy_bitstream_c_to_rust(bs);
193208
if rust_bs.make_byte_aligned().is_ok() {
194209
copy_bitstream_from_rust_to_c(bs, &rust_bs);
@@ -203,6 +218,9 @@ pub unsafe extern "C" fn ccxr_make_byte_aligned(bs: *mut bitstream) {
203218
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_internal_state_from_rust_to_c`
204219
#[no_mangle]
205220
pub unsafe extern "C" fn ccxr_next_bytes(bs: *mut bitstream, bynum: usize) -> *const u8 {
221+
if bs.is_null() {
222+
return std::ptr::null();
223+
}
206224
let mut rust_bs = copy_bitstream_c_to_rust(bs);
207225
match rust_bs.next_bytes(bynum) {
208226
Ok(slice) => {
@@ -220,6 +238,9 @@ pub unsafe extern "C" fn ccxr_next_bytes(bs: *mut bitstream, bynum: usize) -> *c
220238
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
221239
#[no_mangle]
222240
pub unsafe extern "C" fn ccxr_read_bytes(bs: *mut bitstream, bynum: usize) -> *const u8 {
241+
if bs.is_null() {
242+
return std::ptr::null();
243+
}
223244
let mut rust_bs = copy_bitstream_c_to_rust(bs);
224245
match rust_bs.read_bytes(bynum) {
225246
Ok(slice) => {
@@ -239,6 +260,9 @@ pub unsafe extern "C" fn ccxr_bitstream_get_num(
239260
bytes: usize,
240261
advance: i32,
241262
) -> u64 {
263+
if bs.is_null() {
264+
return 0;
265+
}
242266
let mut rust_bs = copy_bitstream_c_to_rust(bs);
243267
let result = rust_bs.bitstream_get_num(bytes, advance != 0).unwrap_or(0);
244268
copy_bitstream_from_rust_to_c(bs, &rust_bs);
@@ -251,6 +275,9 @@ pub unsafe extern "C" fn ccxr_bitstream_get_num(
251275
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
252276
#[no_mangle]
253277
pub unsafe extern "C" fn ccxr_read_exp_golomb_unsigned(bs: *mut bitstream) -> u64 {
278+
if bs.is_null() {
279+
return 0;
280+
}
254281
let mut rust_bs = copy_bitstream_c_to_rust(bs);
255282
let result = rust_bs.read_exp_golomb_unsigned().unwrap_or(0);
256283
copy_bitstream_from_rust_to_c(bs, &rust_bs);
@@ -263,6 +290,9 @@ pub unsafe extern "C" fn ccxr_read_exp_golomb_unsigned(bs: *mut bitstream) -> u6
263290
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
264291
#[no_mangle]
265292
pub unsafe extern "C" fn ccxr_read_exp_golomb(bs: *mut bitstream) -> i64 {
293+
if bs.is_null() {
294+
return 0;
295+
}
266296
let mut rust_bs = copy_bitstream_c_to_rust(bs);
267297
let result = rust_bs.read_exp_golomb().unwrap_or(0);
268298
copy_bitstream_from_rust_to_c(bs, &rust_bs);
@@ -274,6 +304,9 @@ pub unsafe extern "C" fn ccxr_read_exp_golomb(bs: *mut bitstream) -> i64 {
274304
/// This function is unsafe because it calls unsafe functions `copy_bitstream_c_to_rust` and `copy_bitstream_from_rust_to_c`
275305
#[no_mangle]
276306
pub unsafe extern "C" fn ccxr_read_int(bs: *mut bitstream, bnum: u32) -> i64 {
307+
if bs.is_null() {
308+
return 0;
309+
}
277310
let mut rust_bs = copy_bitstream_c_to_rust(bs);
278311
let result = rust_bs.read_int(bnum).unwrap_or(0);
279312
copy_bitstream_from_rust_to_c(bs, &rust_bs);
@@ -804,35 +837,20 @@ mod bitstream_copying_tests {
804837
}
805838

806839
#[test]
807-
fn test_memory_safety() {
808-
let buffer = create_test_buffer(1000);
809-
let rust_stream = BitStreamRust {
810-
data: &buffer,
811-
pos: 0,
812-
bpos: 0,
813-
bits_left: 8000,
814-
error: false,
815-
_i_pos: 500,
816-
_i_bpos: 0,
817-
};
818-
840+
fn test_ffi_safety() {
819841
unsafe {
820-
let c_s = Box::into_raw(Box::new(bitstream::default()));
821-
copy_bitstream_from_rust_to_c(c_s, &rust_stream);
822-
let c_stream = &mut *c_s;
823-
824-
// Verify all pointers are within bounds
825-
assert!(verify_pointer_bounds(c_stream));
826-
827-
// Verify we can safely access the boundaries
828-
let first_byte = *c_stream.pos;
829-
let last_byte = *c_stream.end.sub(1);
830-
let internal_byte = *c_stream._i_pos;
831-
832-
// These should not panic and should match our buffer
833-
assert_eq!(first_byte, 0);
834-
assert_eq!(last_byte, (999 % 256) as u8);
835-
assert_eq!(internal_byte, (500 % 256) as u8);
842+
// Test NULL pointer safety
843+
assert_eq!(super::ccxr_next_bits(std::ptr::null_mut(), 8), 0);
844+
assert_eq!(super::ccxr_read_bits(std::ptr::null_mut(), 8), 0);
845+
assert_eq!(super::ccxr_skip_bits(std::ptr::null_mut(), 8), 0);
846+
assert_eq!(super::ccxr_is_byte_aligned(std::ptr::null_mut()), 0);
847+
super::ccxr_make_byte_aligned(std::ptr::null_mut()); // Should not panic
848+
assert!(super::ccxr_next_bytes(std::ptr::null_mut(), 1).is_null());
849+
assert!(super::ccxr_read_bytes(std::ptr::null_mut(), 1).is_null());
850+
assert_eq!(super::ccxr_bitstream_get_num(std::ptr::null_mut(), 1, 0), 0);
851+
assert_eq!(super::ccxr_read_exp_golomb_unsigned(std::ptr::null_mut()), 0);
852+
assert_eq!(super::ccxr_read_exp_golomb(std::ptr::null_mut()), 0);
853+
assert_eq!(super::ccxr_read_int(std::ptr::null_mut(), 8), 0);
836854
}
837855
}
838856
}

src/rust/src/libccxr_exports/mod.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ pub unsafe extern "C" fn ccxr_update_logger_target() {
7777
/// or less than `len`.
7878
#[no_mangle]
7979
pub unsafe extern "C" fn ccxr_verify_crc32(buf: *const u8, len: c_int) -> c_int {
80+
if buf.is_null() || len < 0 {
81+
return 0;
82+
}
8083
let buf = std::slice::from_raw_parts(buf, len as usize);
8184
if verify_crc32(buf) {
8285
1
@@ -97,6 +100,9 @@ pub unsafe extern "C" fn ccxr_levenshtein_dist(
97100
s1len: c_uint,
98101
s2len: c_uint,
99102
) -> c_int {
103+
if s1.is_null() || s2.is_null() {
104+
return 0;
105+
}
100106
let s1 = std::slice::from_raw_parts(s1, s1len as usize);
101107
let s2 = std::slice::from_raw_parts(s2, s2len as usize);
102108

@@ -118,10 +124,34 @@ pub unsafe extern "C" fn ccxr_levenshtein_dist_char(
118124
s1len: c_uint,
119125
s2len: c_uint,
120126
) -> c_int {
127+
if s1.is_null() || s2.is_null() {
128+
return 0;
129+
}
121130
let s1 = std::slice::from_raw_parts(s1, s1len as usize);
122131
let s2 = std::slice::from_raw_parts(s2, s2len as usize);
123132

124133
let ans = levenshtein_dist_char(s1, s2);
125134

126135
ans.min(c_int::MAX as usize) as c_int
127136
}
137+
138+
#[cfg(test)]
139+
mod tests {
140+
use super::*;
141+
use std::ptr;
142+
143+
#[test]
144+
fn test_ffi_safety() {
145+
unsafe {
146+
// Test NULL pointer and negative length safety for CRC32
147+
assert_eq!(ccxr_verify_crc32(ptr::null(), 10), 0);
148+
assert_eq!(ccxr_verify_crc32(ptr::null(), -1), 0);
149+
let buf = [1, 2, 3];
150+
assert_eq!(ccxr_verify_crc32(buf.as_ptr(), -1), 0);
151+
152+
// Test NULL pointer safety for Levenshtein
153+
assert_eq!(ccxr_levenshtein_dist(ptr::null(), ptr::null(), 0, 0), 0);
154+
assert_eq!(ccxr_levenshtein_dist_char(ptr::null(), ptr::null(), 0, 0), 0);
155+
}
156+
}
157+
}

0 commit comments

Comments
 (0)