Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlx-mysql/src/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl MySqlConnection {
// otherwise, this first packet is the start of the result-set metadata,
*self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row;

let num_columns = packet.get_uint_lenenc(); // column count
let num_columns = packet.get_uint_lenenc()?; // column count
let num_columns = usize::try_from(num_columns)
.map_err(|_| err_protocol!("column count overflows usize: {num_columns}"))?;

Expand Down
2 changes: 1 addition & 1 deletion sqlx-mysql/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl<S: Socket> MySqlStream<S> {
}

async fn skip_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
let num_columns: u64 = packet.get_uint_lenenc(); // column count
let num_columns: u64 = packet.get_uint_lenenc()?; // column count

for _ in 0..num_columns {
let _ = self.recv_packet().await?;
Expand Down
44 changes: 36 additions & 8 deletions sqlx-mysql/src/io/buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub trait MySqlBufExt: Buf {
// NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL.
// NOTE: 0xff is only returned during a result set to indicate ERR.
// <https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger>
fn get_uint_lenenc(&mut self) -> u64;
fn get_uint_lenenc(&mut self) -> Result<u64, Error>;

// Read a length-encoded string.
#[allow(dead_code)]
Expand All @@ -19,26 +19,54 @@ pub trait MySqlBufExt: Buf {
}

impl MySqlBufExt for Bytes {
fn get_uint_lenenc(&mut self) -> u64 {
fn get_uint_lenenc(&mut self) -> Result<u64, Error> {
if self.remaining() < 1 {
return Err(err_protocol!("lenenc int: no bytes remaining"));
}

match self.get_u8() {
0xfc => u64::from(self.get_u16_le()),
0xfd => self.get_uint_le(3),
0xfe => self.get_u64_le(),
0xfc => {
if self.remaining() < 2 {
return Err(err_protocol!(
"lenenc int: need 2 more bytes, have {}",
self.remaining()
));
}
Ok(u64::from(self.get_u16_le()))
}
0xfd => {
if self.remaining() < 3 {
return Err(err_protocol!(
"lenenc int: need 3 more bytes, have {}",
self.remaining()
));
}
Ok(self.get_uint_le(3))
}
0xfe => {
if self.remaining() < 8 {
return Err(err_protocol!(
"lenenc int: need 8 more bytes, have {}",
self.remaining()
));
}
Ok(self.get_u64_le())
}

v => u64::from(v),
v => Ok(u64::from(v)),
}
}

fn get_str_lenenc(&mut self) -> Result<String, Error> {
let size = self.get_uint_lenenc();
let size = self.get_uint_lenenc()?;
let size = usize::try_from(size)
.map_err(|_| err_protocol!("string length overflows usize: {size}"))?;

self.get_str(size)
}

fn get_bytes_lenenc(&mut self) -> Result<Bytes, Error> {
let size = self.get_uint_lenenc();
let size = self.get_uint_lenenc()?;
let size = usize::try_from(size)
.map_err(|_| err_protocol!("string length overflows usize: {size}"))?;

Expand Down
20 changes: 18 additions & 2 deletions sqlx-mysql/src/protocol/response/ok.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@ impl ProtocolDecode<'_> for OkPacket {
));
}

let affected_rows = buf.get_uint_lenenc();
let last_insert_id = buf.get_uint_lenenc();
let affected_rows = buf.get_uint_lenenc()?;
let last_insert_id = buf.get_uint_lenenc()?;

if buf.remaining() < 4 {
return Err(err_protocol!(
"OK_Packet too short: expected at least 4 more bytes for status+warnings, got {}",
buf.remaining()
));
}

let status = Status::from_bits_truncate(buf.get_u16_le());
let warnings = buf.get_u16_le();

Expand Down Expand Up @@ -76,3 +84,11 @@ fn test_decode_ok_packet_with_extended_info() {
assert_eq!(p.warnings, 1);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
}

#[test]
fn test_decode_ok_packet_truncated() {
const DATA: &[u8] = b"\x00\x00\x00\x01";

let err = OkPacket::decode(DATA.into()).unwrap_err();
assert!(matches!(err, Error::Protocol(_)), "{err}");
}
2 changes: 1 addition & 1 deletion sqlx-mysql/src/protocol/statement/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow {
| ColumnType::Decimal
| ColumnType::Json
| ColumnType::NewDecimal => {
let size = buf.get_uint_lenenc();
let size = buf.get_uint_lenenc()?;
usize::try_from(size)
.map_err(|_| err_protocol!("BLOB length out of range: {size}"))?
}
Expand Down
2 changes: 1 addition & 1 deletion sqlx-mysql/src/protocol/text/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl ProtocolDecode<'_, Capabilities> for ColumnDefinition {
let table = buf.get_bytes_lenenc()?;
let alias = buf.get_bytes_lenenc()?;
let name = buf.get_bytes_lenenc()?;
let _next_len = buf.get_uint_lenenc(); // always 0x0c
let _next_len = buf.get_uint_lenenc()?; // always 0x0c
let collation = buf.get_u16_le();
let max_size = buf.get_u32_le();
let type_id = buf.get_u8();
Expand Down
2 changes: 1 addition & 1 deletion sqlx-mysql/src/protocol/text/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for TextRow {
values.push(None);
buf.advance(1);
} else {
let size = buf.get_uint_lenenc();
let size = buf.get_uint_lenenc()?;
if (buf.remaining() as u64) < size {
return Err(err_protocol!(
"buffer exhausted when reading data for column {:?}; decoded length is {}, but only {} bytes remain in buffer. Malformed packet or protocol error?",
Expand Down