diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 95eb2623bd..e331acf013 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -212,7 +212,7 @@ impl MySqlConnection { // otherwise, this first packet is the start of the result-set metadata, *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row; - let num_columns = packet.get_uint_lenenc(); // column count + let num_columns = packet.get_uint_lenenc()?; // column count let num_columns = usize::try_from(num_columns) .map_err(|_| err_protocol!("column count overflows usize: {num_columns}"))?; diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index ff931b2f46..e6aa8b48c8 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -194,7 +194,7 @@ impl MySqlStream { } async fn skip_result_metadata(&mut self, mut packet: Packet) -> Result<(), Error> { - let num_columns: u64 = packet.get_uint_lenenc(); // column count + let num_columns: u64 = packet.get_uint_lenenc()?; // column count for _ in 0..num_columns { let _ = self.recv_packet().await?; diff --git a/sqlx-mysql/src/io/buf.rs b/sqlx-mysql/src/io/buf.rs index 685d5bfda7..6b3c11b3f5 100644 --- a/sqlx-mysql/src/io/buf.rs +++ b/sqlx-mysql/src/io/buf.rs @@ -8,7 +8,7 @@ pub trait MySqlBufExt: Buf { // NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL. // NOTE: 0xff is only returned during a result set to indicate ERR. // - fn get_uint_lenenc(&mut self) -> u64; + fn get_uint_lenenc(&mut self) -> Result; // Read a length-encoded string. #[allow(dead_code)] @@ -19,18 +19,46 @@ pub trait MySqlBufExt: Buf { } impl MySqlBufExt for Bytes { - fn get_uint_lenenc(&mut self) -> u64 { + fn get_uint_lenenc(&mut self) -> Result { + if self.remaining() < 1 { + return Err(err_protocol!("lenenc int: no bytes remaining")); + } + match self.get_u8() { - 0xfc => u64::from(self.get_u16_le()), - 0xfd => self.get_uint_le(3), - 0xfe => self.get_u64_le(), + 0xfc => { + if self.remaining() < 2 { + return Err(err_protocol!( + "lenenc int: need 2 more bytes, have {}", + self.remaining() + )); + } + Ok(u64::from(self.get_u16_le())) + } + 0xfd => { + if self.remaining() < 3 { + return Err(err_protocol!( + "lenenc int: need 3 more bytes, have {}", + self.remaining() + )); + } + Ok(self.get_uint_le(3)) + } + 0xfe => { + if self.remaining() < 8 { + return Err(err_protocol!( + "lenenc int: need 8 more bytes, have {}", + self.remaining() + )); + } + Ok(self.get_u64_le()) + } - v => u64::from(v), + v => Ok(u64::from(v)), } } fn get_str_lenenc(&mut self) -> Result { - let size = self.get_uint_lenenc(); + let size = self.get_uint_lenenc()?; let size = usize::try_from(size) .map_err(|_| err_protocol!("string length overflows usize: {size}"))?; @@ -38,7 +66,7 @@ impl MySqlBufExt for Bytes { } fn get_bytes_lenenc(&mut self) -> Result { - let size = self.get_uint_lenenc(); + let size = self.get_uint_lenenc()?; let size = usize::try_from(size) .map_err(|_| err_protocol!("string length overflows usize: {size}"))?; diff --git a/sqlx-mysql/src/protocol/response/ok.rs b/sqlx-mysql/src/protocol/response/ok.rs index 74c4abded7..86fea9b4cf 100644 --- a/sqlx-mysql/src/protocol/response/ok.rs +++ b/sqlx-mysql/src/protocol/response/ok.rs @@ -24,8 +24,16 @@ impl ProtocolDecode<'_> for OkPacket { )); } - let affected_rows = buf.get_uint_lenenc(); - let last_insert_id = buf.get_uint_lenenc(); + let affected_rows = buf.get_uint_lenenc()?; + let last_insert_id = buf.get_uint_lenenc()?; + + if buf.remaining() < 4 { + return Err(err_protocol!( + "OK_Packet too short: expected at least 4 more bytes for status+warnings, got {}", + buf.remaining() + )); + } + let status = Status::from_bits_truncate(buf.get_u16_le()); let warnings = buf.get_u16_le(); @@ -76,3 +84,11 @@ fn test_decode_ok_packet_with_extended_info() { assert_eq!(p.warnings, 1); assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); } + +#[test] +fn test_decode_ok_packet_truncated() { + const DATA: &[u8] = b"\x00\x00\x00\x01"; + + let err = OkPacket::decode(DATA.into()).unwrap_err(); + assert!(matches!(err, Error::Protocol(_)), "{err}"); +} diff --git a/sqlx-mysql/src/protocol/statement/row.rs b/sqlx-mysql/src/protocol/statement/row.rs index a55701fe19..4e9cc2bb9c 100644 --- a/sqlx-mysql/src/protocol/statement/row.rs +++ b/sqlx-mysql/src/protocol/statement/row.rs @@ -76,7 +76,7 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow { | ColumnType::Decimal | ColumnType::Json | ColumnType::NewDecimal => { - let size = buf.get_uint_lenenc(); + let size = buf.get_uint_lenenc()?; usize::try_from(size) .map_err(|_| err_protocol!("BLOB length out of range: {size}"))? } diff --git a/sqlx-mysql/src/protocol/text/column.rs b/sqlx-mysql/src/protocol/text/column.rs index 6e33713880..b7c9c7e639 100644 --- a/sqlx-mysql/src/protocol/text/column.rs +++ b/sqlx-mysql/src/protocol/text/column.rs @@ -147,7 +147,7 @@ impl ProtocolDecode<'_, Capabilities> for ColumnDefinition { let table = buf.get_bytes_lenenc()?; let alias = buf.get_bytes_lenenc()?; let name = buf.get_bytes_lenenc()?; - let _next_len = buf.get_uint_lenenc(); // always 0x0c + let _next_len = buf.get_uint_lenenc()?; // always 0x0c let collation = buf.get_u16_le(); let max_size = buf.get_u32_le(); let type_id = buf.get_u8(); diff --git a/sqlx-mysql/src/protocol/text/row.rs b/sqlx-mysql/src/protocol/text/row.rs index e5f820c653..53f7c472ee 100644 --- a/sqlx-mysql/src/protocol/text/row.rs +++ b/sqlx-mysql/src/protocol/text/row.rs @@ -22,7 +22,7 @@ impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for TextRow { values.push(None); buf.advance(1); } else { - let size = buf.get_uint_lenenc(); + let size = buf.get_uint_lenenc()?; if (buf.remaining() as u64) < size { return Err(err_protocol!( "buffer exhausted when reading data for column {:?}; decoded length is {}, but only {} bytes remain in buffer. Malformed packet or protocol error?",