Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sqlx-mysql/src/protocol/response/ok.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ impl ProtocolDecode<'_> for OkPacket {

let affected_rows = buf.get_uint_lenenc();
let last_insert_id = buf.get_uint_lenenc();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to move the goalposts too much, but it seems like these fields should at least be covered by the length check too.

In the long run we should probably ban use of bytes::Buf methods and instead write our own that return Result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking get_uint_lenenc calls is a bit trickier since the method advances a variable number of bytes depending on the prefix. I've changed its signature to return Result and added remaining-bytes checks inside the method, propagating with ? at all call sites.

If this approach is good, I could follow up with a separate PR that wraps the other panicking bytes::Buf methods in the same way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was mostly thinking that we should just check that the packet is nonempty, but this is a great improvement, thanks!

If this approach is good, I could follow up with a separate PR that wraps the other panicking bytes::Buf methods in the same way.

I think that would be a good idea, there's a lot of shortcuts that we took early on that are coming back to bite us and using things like this that panic instead of returning errors is a big one.


if buf.remaining() < 4 {
return Err(err_protocol!(
"OK_Packet too short: expected at least 4 more bytes for status+warnings, got {}",
buf.remaining()
));
}

let status = Status::from_bits_truncate(buf.get_u16_le());
let warnings = buf.get_u16_le();

Expand Down Expand Up @@ -76,3 +84,11 @@ fn test_decode_ok_packet_with_extended_info() {
assert_eq!(p.warnings, 1);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
}

#[test]
fn test_decode_ok_packet_truncated() {
const DATA: &[u8] = b"\x00\x00\x00\x01";

let err = OkPacket::decode(DATA.into()).unwrap_err();
assert!(matches!(err, Error::Protocol(_)), "{err}");
}
Loading