/// Returns the index of the first `0x00` byte in `slice`, or `None` when the
/// slice contains no zero byte.
fn find_next_zero(slice: &[u8]) -> Option<usize> {
    slice.iter().position(|&b| b == 0x00)
}

/// Find the first NAL start code (`0x00 0x00 0x01` or `0x00 0x00 0x00 0x01`) in a buffer.
///
/// Returns the position of the `0x01` byte if found, or `None` if not found.
/// Buffers shorter than three bytes cannot hold a start code and yield `None`.
fn find_nal_start_code(buf: &[u8]) -> Option<usize> {
    // A 4-byte start code is just a zero byte followed by the 3-byte start
    // code, so locating the first `00 00 01` triple covers both forms and
    // lands on the same `0x01` byte in either case.
    // `windows(3)` produces no windows for inputs shorter than three bytes,
    // so no explicit length guard is needed.
    buf.windows(3)
        .position(|window| window == [0x00, 0x00, 0x01])
        .map(|triple_start| triple_start + 2)
}
+ // This can happen with: + // - HLS/Twitch stream segments that start mid-stream + // - Streams with garbage data at the beginning + // - Buffer accumulation issues after previous errors + let start_offset = if avcbuf[0] == 0x00 && avcbuf[1] == 0x00 { + // Normal case: buffer starts with zeros + 0 + } else { + // Try to find the first NAL start code + if let Some(nal_pos) = find_nal_start_code(avcbuf) { + // Found a NAL start code, skip to the position before it (the zeros) + // The position returned is the 0x01, so we need to go back to find the zeros + let zeros_start = if nal_pos >= 3 && avcbuf[nal_pos - 3] == 0x00 { + nal_pos - 3 // 4-byte start code + } else { + nal_pos - 2 // 3-byte start code + }; + debug!(msg_type = DebugMessageFlag::VERBOSE; + "Skipped {} bytes of garbage before first NAL start code", zeros_start); + zeros_start + } else { + // No NAL start code found - return full buffer length to clear it + debug!(msg_type = DebugMessageFlag::VERBOSE; + "No NAL start code found in buffer of {} bytes, clearing", avcbuflen); + return Ok(avcbuflen); + } + }; + + // Work with the buffer starting from start_offset + let working_buf = &avcbuf[start_offset..]; + let working_len = working_buf.len(); + + if working_len <= 5 { + // Not enough data after skipping garbage + return Ok(avcbuflen); } let mut buffer_position = 2usize; - let mut firstloop = true; // Loop over NAL units - while buffer_position < avcbuflen.saturating_sub(2) { + while buffer_position < working_len.saturating_sub(2) { let mut zeropad = 0; // Find next NAL_start - while buffer_position < avcbuflen { - if avcbuf[buffer_position] == 0x01 { + while buffer_position < working_len { + if working_buf[buffer_position] == 0x01 { break; - } else if firstloop && avcbuf[buffer_position] != 0x00 { - return Err(AvcError::BrokenStream( - "Leading bytes are non-zero".to_string(), - )); + } else if working_buf[buffer_position] != 0x00 { + // Non-zero byte found where we expected zeros - skip to next 
potential start code + if let Some(next_nal) = find_nal_start_code(&working_buf[buffer_position..]) { + buffer_position += next_nal - 1; // -1 because we'll increment at end of loop + zeropad = 0; + } else { + // No more NAL units found + return Ok(avcbuflen); + } } buffer_position += 1; zeropad += 1; } - firstloop = false; - - if buffer_position >= avcbuflen { + if buffer_position >= working_len { break; } let nal_start_pos = buffer_position + 1; - let mut nal_stop_pos = avcbuflen; + let mut nal_stop_pos = working_len; buffer_position += 1; - let restlen = avcbuflen.saturating_sub(buffer_position + 2); + let restlen = working_len.saturating_sub(buffer_position + 2); // Use optimized zero search if restlen > 0 { if let Some(zero_offset) = - find_next_zero(&avcbuf[buffer_position..buffer_position + restlen]) + find_next_zero(&working_buf[buffer_position..buffer_position + restlen]) { let zero_pos = buffer_position + zero_offset; - if zero_pos + 2 < avcbuflen { - if avcbuf[zero_pos + 1] == 0x00 && (avcbuf[zero_pos + 2] | 0x01) == 0x01 { + if zero_pos + 2 < working_len { + if working_buf[zero_pos + 1] == 0x00 + && (working_buf[zero_pos + 2] | 0x01) == 0x01 + { nal_stop_pos = zero_pos; buffer_position = zero_pos + 2; } else { // Continue searching from after this zero buffer_position = zero_pos + 1; // Recursive search for next start code - while buffer_position < avcbuflen.saturating_sub(2) { + while buffer_position < working_len.saturating_sub(2) { if let Some(next_zero_offset) = find_next_zero( - &avcbuf[buffer_position..avcbuflen.saturating_sub(2)], + &working_buf[buffer_position..working_len.saturating_sub(2)], ) { let next_zero_pos = buffer_position + next_zero_offset; - if next_zero_pos + 2 < avcbuflen { - if avcbuf[next_zero_pos + 1] == 0x00 - && (avcbuf[next_zero_pos + 2] | 0x01) == 0x01 + if next_zero_pos + 2 < working_len { + if working_buf[next_zero_pos + 1] == 0x00 + && (working_buf[next_zero_pos + 2] | 0x01) == 0x01 { nal_stop_pos = next_zero_pos; 
buffer_position = next_zero_pos + 2; break; } } else { - nal_stop_pos = avcbuflen; - buffer_position = avcbuflen; + nal_stop_pos = working_len; + buffer_position = working_len; break; } buffer_position = next_zero_pos + 1; } else { - nal_stop_pos = avcbuflen; - buffer_position = avcbuflen; + nal_stop_pos = working_len; + buffer_position = working_len; break; } } } } else { - nal_stop_pos = avcbuflen; - buffer_position = avcbuflen; + nal_stop_pos = working_len; + buffer_position = working_len; } } else { - nal_stop_pos = avcbuflen; - buffer_position = avcbuflen; + nal_stop_pos = working_len; + buffer_position = working_len; } } else { - nal_stop_pos = avcbuflen; - buffer_position = avcbuflen; + nal_stop_pos = working_len; + buffer_position = working_len; } - if nal_start_pos >= avcbuflen { + if nal_start_pos >= working_len { break; } - if (avcbuf[nal_start_pos] & 0x80) != 0 { + if (working_buf[nal_start_pos] & 0x80) != 0 { let dump_start = nal_start_pos.saturating_sub(4); - let dump_len = std::cmp::min(10, avcbuflen - dump_start); - dump(avcbuf[dump_start..].as_ptr(), dump_len as i32, 0, 0); - - return Err(AvcError::ForbiddenZeroBit( - "forbidden_zero_bit not zero".to_string(), - )); + let dump_len = std::cmp::min(10, working_len - dump_start); + dump(working_buf[dump_start..].as_ptr(), dump_len as i32, 0, 0); + + // Don't return an error - just skip this NAL and continue + // This allows processing to continue even with some corrupt data + debug!(msg_type = DebugMessageFlag::VERBOSE; + "Skipping NAL with forbidden_zero_bit set"); + continue; } - (*dec_ctx.avc_ctx).nal_ref_idc = (avcbuf[nal_start_pos] >> 5) as u32; + (*dec_ctx.avc_ctx).nal_ref_idc = (working_buf[nal_start_pos] >> 5) as u32; debug!(msg_type = DebugMessageFlag::VIDEO_STREAM; "process_avc: zeropad {}", zeropad); let nal_length = (nal_stop_pos - nal_start_pos) as i64; - let mut nal_slice = avcbuf[nal_start_pos..nal_stop_pos].to_vec(); + let mut nal_slice = 
#[cfg(test)]
mod tests {
    use super::*;

    /// A short (3-byte) start code opening the buffer puts `0x01` at index 2.
    #[test]
    fn test_find_nal_start_code_3byte() {
        let data: &[u8] = &[0x00, 0x00, 0x01, 0x65, 0x88];
        let found = find_nal_start_code(data);
        assert_eq!(found, Some(2));
    }

    /// A long (4-byte) start code opening the buffer puts `0x01` at index 3.
    #[test]
    fn test_find_nal_start_code_4byte() {
        let data: &[u8] = &[0x00, 0x00, 0x00, 0x01, 0x67, 0x64];
        let found = find_nal_start_code(data);
        assert_eq!(found, Some(3));
    }

    /// Three junk bytes precede a 3-byte start code, so `0x01` is at index 5.
    #[test]
    fn test_find_nal_start_code_with_garbage() {
        let data: &[u8] = &[0xFF, 0xAB, 0xCD, 0x00, 0x00, 0x01, 0x09, 0xF0];
        let found = find_nal_start_code(data);
        assert_eq!(found, Some(5));
    }

    /// A buffer containing no start code at all reports `None`.
    #[test]
    fn test_find_nal_start_code_no_start_code() {
        let data: &[u8] = &[0xFF, 0xAB, 0xCD, 0xEF];
        let found = find_nal_start_code(data);
        assert_eq!(found, None);
    }

    /// Two bytes are fewer than the three a start code needs.
    #[test]
    fn test_find_nal_start_code_too_short() {
        let data: &[u8] = &[0x00, 0x00];
        let found = find_nal_start_code(data);
        assert_eq!(found, None);
    }

    /// The empty buffer reports `None`.
    #[test]
    fn test_find_nal_start_code_empty() {
        let data: [u8; 0] = [];
        let found = find_nal_start_code(&data);
        assert_eq!(found, None);
    }

    /// `0x00 0x00 0x02` is not a start code; the real one follows and ends at index 5.
    #[test]
    fn test_find_nal_start_code_partial_match() {
        let data: &[u8] = &[0x00, 0x00, 0x02, 0x00, 0x00, 0x01, 0x65];
        let found = find_nal_start_code(data);
        assert_eq!(found, Some(5));
    }
}