Skip to content

Commit 3ca5b7a

Browse files
JBYoshiJonathanWoollett-Light
authored andcommitted
Change packet lengths to sized types
IPv4 packet header sizes fit in a u8 value, and all IP/Ethernet-related packet total sizes fit in a u16 value. This changes those sizes to use smaller types that don't ever need to be converted to smaller types. Signed-off-by: Jonathan Browne <[email protected]>
1 parent fd5662b commit 3ca5b7a

File tree

6 files changed

+85
-71
lines changed

6 files changed

+85
-71
lines changed

src/vmm/src/dumbo/pdu/ipv4.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ const PROTOCOL_OFFSET: usize = 9;
2525
const HEADER_CHECKSUM_OFFSET: usize = 10;
2626
const SOURCE_ADDRESS_OFFSET: usize = 12;
2727
const DESTINATION_ADDRESS_OFFSET: usize = 16;
28-
const OPTIONS_OFFSET: usize = 20;
28+
const OPTIONS_OFFSET: u8 = 20;
2929

3030
/// Indicates version 4 of the IP protocol
3131
pub const IPV4_VERSION: u8 = 0x04;
@@ -82,7 +82,7 @@ impl<'a, T: NetworkBytes + Debug> IPv4Packet<'a, T> {
8282
pub fn from_bytes(bytes: T, verify_checksum: bool) -> Result<Self, Error> {
8383
let bytes_len = bytes.len();
8484

85-
if bytes_len < OPTIONS_OFFSET {
85+
if bytes_len < usize::from(OPTIONS_OFFSET) {
8686
return Err(Error::SliceTooShort);
8787
}
8888

@@ -96,7 +96,7 @@ impl<'a, T: NetworkBytes + Debug> IPv4Packet<'a, T> {
9696

9797
let total_len = packet.total_len() as usize;
9898

99-
if total_len < header_len {
99+
if total_len < header_len.into() {
100100
return Err(Error::InvalidTotalLen);
101101
}
102102

@@ -111,7 +111,7 @@ impl<'a, T: NetworkBytes + Debug> IPv4Packet<'a, T> {
111111
// We ignore the TTL field since only routers should care about it. An end host has no
112112
// reason really to discard an otherwise valid packet.
113113

114-
if verify_checksum && packet.compute_checksum_unchecked(header_len) != 0 {
114+
if verify_checksum && packet.compute_checksum_unchecked(header_len.into()) != 0 {
115115
return Err(Error::Checksum);
116116
}
117117

@@ -123,16 +123,16 @@ impl<'a, T: NetworkBytes + Debug> IPv4Packet<'a, T> {
123123
/// This method returns the actual length (in bytes) of the header, and not the value of the
124124
/// `ihl` header field).
125125
#[inline]
126-
pub fn version_and_header_len(&self) -> (u8, usize) {
126+
pub fn version_and_header_len(&self) -> (u8, u8) {
127127
let x = self.bytes[VERSION_AND_IHL_OFFSET];
128128
let ihl = x & 0x0f;
129-
let header_len = (ihl << 2) as usize;
129+
let header_len = ihl << 2;
130130
(x >> 4, header_len)
131131
}
132132

133133
/// Returns the packet header length (in bytes).
134134
#[inline]
135-
pub fn header_len(&self) -> usize {
135+
pub fn header_len(&self) -> u8 {
136136
let (_, header_len) = self.version_and_header_len();
137137
header_len
138138
}
@@ -207,7 +207,7 @@ impl<'a, T: NetworkBytes + Debug> IPv4Packet<'a, T> {
207207
/// Returns a byte slice that contains the payload of the packet.
208208
#[inline]
209209
pub fn payload(&self) -> &[u8] {
210-
self.payload_unchecked(self.header_len())
210+
self.payload_unchecked(self.header_len().into())
211211
}
212212

213213
/// Returns the length of the inner byte sequence.
@@ -245,7 +245,7 @@ impl<'a, T: NetworkBytes + Debug> IPv4Packet<'a, T> {
245245
/// Computes and returns the packet header checksum.
246246
#[inline]
247247
pub fn compute_checksum(&self) -> u16 {
248-
self.compute_checksum_unchecked(self.header_len())
248+
self.compute_checksum_unchecked(self.header_len().into())
249249
}
250250
}
251251

@@ -263,7 +263,7 @@ impl<'a, T: NetworkBytesMut + Debug> IPv4Packet<'a, T> {
263263
src_addr: Ipv4Addr,
264264
dst_addr: Ipv4Addr,
265265
) -> Result<Incomplete<Self>, Error> {
266-
if buf.len() < OPTIONS_OFFSET {
266+
if buf.len() < usize::from(OPTIONS_OFFSET) {
267267
return Err(Error::SliceTooShort);
268268
}
269269
let mut packet = IPv4Packet::from_bytes_unchecked(buf);
@@ -283,9 +283,9 @@ impl<'a, T: NetworkBytesMut + Debug> IPv4Packet<'a, T> {
283283
/// Sets the values of the `version` and `ihl` header fields (the latter is computed from the
284284
/// value of `header_len`).
285285
#[inline]
286-
pub fn set_version_and_header_len(&mut self, version: u8, header_len: usize) -> &mut Self {
286+
pub fn set_version_and_header_len(&mut self, version: u8, header_len: u8) -> &mut Self {
287287
let version = version << 4;
288-
let ihl = ((header_len as u8) >> 2) & 0xf;
288+
let ihl = (header_len >> 2) & 0xf;
289289
self.bytes[VERSION_AND_IHL_OFFSET] = version | ihl;
290290
self
291291
}
@@ -374,7 +374,7 @@ impl<'a, T: NetworkBytesMut + Debug> IPv4Packet<'a, T> {
374374
// Can't use self.header_len() as a fn parameter on the following line, because
375375
// the borrow checker complains. This may change when it becomes smarter.
376376
let header_len = self.header_len();
377-
self.payload_mut_unchecked(header_len)
377+
self.payload_mut_unchecked(header_len.into())
378378
}
379379
}
380380

@@ -394,24 +394,24 @@ impl<'a, T: NetworkBytesMut + Debug> Incomplete<IPv4Packet<'a, T>> {
394394
#[inline]
395395
pub fn with_header_and_payload_len_unchecked(
396396
mut self,
397-
header_len: usize,
398-
payload_len: usize,
397+
header_len: u8,
398+
payload_len: u16,
399399
compute_checksum: bool,
400400
) -> IPv4Packet<'a, T> {
401-
let total_len = header_len + payload_len;
401+
let total_len = u16::from(header_len) + payload_len;
402402
{
403403
let packet = &mut self.inner;
404404

405405
// This unchecked is fine as long as total_len is smaller than the length of the
406406
// original slice, which should be the case if our code is not wrong.
407-
packet.bytes.shrink_unchecked(total_len);
407+
packet.bytes.shrink_unchecked(total_len.into());
408408
// Set the total_len.
409-
packet.set_total_len(total_len as u16);
409+
packet.set_total_len(total_len);
410410
if compute_checksum {
411411
// Ensure this is set to 0 first.
412412
packet.set_header_checksum(0);
413413
// Now compute the actual checksum.
414-
let checksum = packet.compute_checksum_unchecked(header_len);
414+
let checksum = packet.compute_checksum_unchecked(header_len.into());
415415
packet.set_header_checksum(checksum);
416416
}
417417
}
@@ -427,8 +427,8 @@ impl<'a, T: NetworkBytesMut + Debug> Incomplete<IPv4Packet<'a, T>> {
427427
#[inline]
428428
pub fn with_options_and_payload_len_unchecked(
429429
self,
430-
options_len: usize,
431-
payload_len: usize,
430+
options_len: u8,
431+
payload_len: u16,
432432
compute_checksum: bool,
433433
) -> IPv4Packet<'a, T> {
434434
let header_len = OPTIONS_OFFSET + options_len;
@@ -444,7 +444,7 @@ impl<'a, T: NetworkBytesMut + Debug> Incomplete<IPv4Packet<'a, T>> {
444444
#[inline]
445445
pub fn with_payload_len_unchecked(
446446
self,
447-
payload_len: usize,
447+
payload_len: u16,
448448
compute_checksum: bool,
449449
) -> IPv4Packet<'a, T> {
450450
let header_len = self.inner().header_len();
@@ -457,7 +457,7 @@ impl<'a, T: NetworkBytesMut + Debug> Incomplete<IPv4Packet<'a, T>> {
457457
#[inline]
458458
pub fn test_speculative_dst_addr(buf: &[u8], addr: Ipv4Addr) -> bool {
459459
// The unchecked methods are safe because we actually check the buffer length beforehand.
460-
if buf.len() >= ethernet::PAYLOAD_OFFSET + OPTIONS_OFFSET {
460+
if buf.len() >= ethernet::PAYLOAD_OFFSET + usize::from(OPTIONS_OFFSET) {
461461
let bytes = &buf[ethernet::PAYLOAD_OFFSET..];
462462
if IPv4Packet::from_bytes_unchecked(bytes).destination_address() == addr {
463463
return true;

src/vmm/src/dumbo/pdu/mod.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,28 +79,26 @@ fn compute_checksum<T: NetworkBytes + Debug>(
7979
dst_addr: Ipv4Addr,
8080
protocol: ChecksumProto,
8181
) -> u16 {
82-
// TODO: Is u32 enough to prevent overflow for the code in this function? I think so, but it
83-
// would be nice to double-check.
84-
let mut sum = 0u32;
82+
let mut sum = 0usize;
8583

86-
let a = u32::from(src_addr);
84+
let a = u32::from(src_addr) as usize;
8785
sum += a & 0xffff;
8886
sum += a >> 16;
8987

90-
let b = u32::from(dst_addr);
88+
let b = u32::from(dst_addr) as usize;
9189
sum += b & 0xffff;
9290
sum += b >> 16;
9391

9492
let len = bytes.len();
95-
sum += protocol as u32;
96-
sum += len as u32;
93+
sum += protocol as usize;
94+
sum += len;
9795

9896
for i in 0..len / 2 {
99-
sum += u32::from(bytes.ntohs_unchecked(i * 2));
97+
sum += usize::from(bytes.ntohs_unchecked(i * 2));
10098
}
10199

102100
if len % 2 != 0 {
103-
sum += u32::from(bytes[len - 1]) << 8;
101+
sum += usize::from(bytes[len - 1]) << 8;
104102
}
105103

106104
while sum >> 16 != 0 {

src/vmm/src/dumbo/pdu/tcp.rs

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ const WINDOW_SIZE_OFFSET: usize = 14;
3030
const CHECKSUM_OFFSET: usize = 16;
3131
const URG_POINTER_OFFSET: usize = 18;
3232

33-
const OPTIONS_OFFSET: usize = 20;
33+
const OPTIONS_OFFSET: u8 = 20;
3434

35-
const MAX_HEADER_LEN: usize = 60;
35+
const MAX_HEADER_LEN: u8 = 60;
3636

3737
const OPTION_KIND_EOL: u8 = 0x00;
3838
const OPTION_KIND_NOP: u8 = 0x01;
@@ -127,18 +127,18 @@ impl<'a, T: NetworkBytes + Debug> TcpSegment<'a, T> {
127127
/// Returns the header length, the value of the reserved bits, and whether the `NS` flag
128128
/// is set or not.
129129
#[inline]
130-
pub fn header_len_rsvd_ns(&self) -> (usize, u8, bool) {
130+
pub fn header_len_rsvd_ns(&self) -> (u8, u8, bool) {
131131
let value = self.bytes[DATAOFF_RSVD_NS_OFFSET];
132132
let data_offset = value >> 4;
133-
let header_len = data_offset as usize * 4;
133+
let header_len = data_offset * 4;
134134
let rsvd = value & 0x0e;
135135
let ns = (value & 1) != 0;
136136
(header_len, rsvd, ns)
137137
}
138138

139139
/// Returns the length of the header.
140140
#[inline]
141-
pub fn header_len(&self) -> usize {
141+
pub fn header_len(&self) -> u8 {
142142
self.header_len_rsvd_ns().0
143143
}
144144

@@ -174,7 +174,7 @@ impl<'a, T: NetworkBytes + Debug> TcpSegment<'a, T> {
174174
/// This method may panic if the value of `header_len` is invalid.
175175
#[inline]
176176
pub fn options_unchecked(&self, header_len: usize) -> &[u8] {
177-
&self.bytes[OPTIONS_OFFSET..header_len]
177+
&self.bytes[usize::from(OPTIONS_OFFSET)..header_len]
178178
}
179179

180180
/// Returns a slice which contains the payload of the segment. May panic if the value of
@@ -190,20 +190,23 @@ impl<'a, T: NetworkBytes + Debug> TcpSegment<'a, T> {
190190

191191
/// Returns the length of the segment.
192192
#[inline]
193-
pub fn len(&self) -> usize {
194-
self.bytes.len()
193+
pub fn len(&self) -> u16 {
194+
// NOTE: This appears to be a safe conversion in all current cases.
195+
// Packets are always set up in the context of an Ipv4Packet, which is
196+
// capped at a u16 size. However, I'd rather be safe here.
197+
u16::try_from(self.bytes.len()).unwrap_or(u16::MAX)
195198
}
196199

197200
/// Returns a slice which contains the payload of the segment.
198201
#[inline]
199202
pub fn payload(&self) -> &[u8] {
200-
self.payload_unchecked(self.header_len())
203+
self.payload_unchecked(self.header_len().into())
201204
}
202205

203206
/// Returns the length of the payload.
204207
#[inline]
205-
pub fn payload_len(&self) -> usize {
206-
self.len() - self.header_len()
208+
pub fn payload_len(&self) -> u16 {
209+
self.len() - u16::from(self.header_len())
207210
}
208211

209212
/// Computes the TCP checksum of the segment. More details about TCP checksum computation can
@@ -285,7 +288,7 @@ impl<'a, T: NetworkBytes + Debug> TcpSegment<'a, T> {
285288
bytes: T,
286289
verify_checksum: Option<(Ipv4Addr, Ipv4Addr)>,
287290
) -> Result<Self, Error> {
288-
if bytes.len() < OPTIONS_OFFSET {
291+
if bytes.len() < usize::from(OPTIONS_OFFSET) {
289292
return Err(Error::SliceTooShort);
290293
}
291294

@@ -295,7 +298,9 @@ impl<'a, T: NetworkBytes + Debug> TcpSegment<'a, T> {
295298

296299
let header_len = segment.header_len();
297300

298-
if header_len < OPTIONS_OFFSET || header_len > min(MAX_HEADER_LEN, segment.len()) {
301+
if header_len < OPTIONS_OFFSET
302+
|| u16::from(header_len) > min(u16::from(MAX_HEADER_LEN), segment.len())
303+
{
299304
return Err(Error::HeaderLen);
300305
}
301306

@@ -342,8 +347,8 @@ impl<'a, T: NetworkBytesMut + Debug> TcpSegment<'a, T> {
342347
/// of 4), clears the reserved bits, and sets the `NS` flag according to the last parameter.
343348
// TODO: Check that header_len | 0b11 == 0 and the resulting data_offset is valid?
344349
#[inline]
345-
pub fn set_header_len_rsvd_ns(&mut self, header_len: usize, ns: bool) -> &mut Self {
346-
let mut value = (header_len as u8) << 2;
350+
pub fn set_header_len_rsvd_ns(&mut self, header_len: u8, ns: bool) -> &mut Self {
351+
let mut value = header_len << 2;
347352
if ns {
348353
value |= 1;
349354
}
@@ -393,7 +398,7 @@ impl<'a, T: NetworkBytesMut + Debug> TcpSegment<'a, T> {
393398
#[inline]
394399
pub fn payload_mut(&mut self) -> &mut [u8] {
395400
let header_len = self.header_len();
396-
self.payload_mut_unchecked(header_len)
401+
self.payload_mut_unchecked(header_len.into())
397402
}
398403

399404
/// Writes a complete TCP segment.
@@ -479,24 +484,24 @@ impl<'a, T: NetworkBytesMut + Debug> TcpSegment<'a, T> {
479484
mss_remaining: u16,
480485
payload: Option<(&R, usize)>,
481486
) -> Result<Incomplete<Self>, Error> {
482-
let mut mss_left = mss_remaining as usize;
487+
let mut mss_left = mss_remaining;
483488

484489
// We're going to need at least this many bytes.
485-
let mut segment_len = OPTIONS_OFFSET;
490+
let mut segment_len = u16::from(OPTIONS_OFFSET);
486491

487492
// The TCP options will require this much more bytes.
488493
let options_len = if mss_option.is_some() {
489494
mss_left = mss_left
490495
.checked_sub(OPTION_LEN_MSS.into())
491496
.ok_or(Error::MssRemaining)?;
492-
usize::from(OPTION_LEN_MSS)
497+
OPTION_LEN_MSS
493498
} else {
494499
0
495500
};
496501

497-
segment_len += options_len;
502+
segment_len += u16::from(options_len);
498503

499-
if buf.len() < segment_len {
504+
if buf.len() < usize::from(segment_len) {
500505
return Err(Error::SliceTooShort);
501506
}
502507

@@ -513,9 +518,11 @@ impl<'a, T: NetworkBytesMut + Debug> TcpSegment<'a, T> {
513518

514519
// Let's write the MSS option if we have to.
515520
if let Some(value) = mss_option {
516-
segment.bytes[OPTIONS_OFFSET] = OPTION_KIND_MSS;
517-
segment.bytes[OPTIONS_OFFSET + 1] = OPTION_LEN_MSS;
518-
segment.bytes.htons_unchecked(OPTIONS_OFFSET + 2, value);
521+
segment.bytes[usize::from(OPTIONS_OFFSET)] = OPTION_KIND_MSS;
522+
segment.bytes[usize::from(OPTIONS_OFFSET) + 1] = OPTION_LEN_MSS;
523+
segment
524+
.bytes
525+
.htons_unchecked(usize::from(OPTIONS_OFFSET) + 2, value);
519526
}
520527

521528
let payload_bytes_count = if let Some((payload_buf, max_payload_bytes)) = payload {
@@ -524,7 +531,9 @@ impl<'a, T: NetworkBytesMut + Debug> TcpSegment<'a, T> {
524531
// The subtraction makes sense because we previously checked that
525532
// buf.len() >= segment_len.
526533
let mut room_for_payload = min(segment.len() - segment_len, mss_left);
527-
room_for_payload = min(room_for_payload, left_to_read);
534+
// The unwrap is safe because room_for_payload is a u16.
535+
room_for_payload =
536+
u16::try_from(min(usize::from(room_for_payload), left_to_read)).unwrap();
528537

529538
if room_for_payload == 0 {
530539
return Err(Error::EmptyPayload);
@@ -535,7 +544,8 @@ impl<'a, T: NetworkBytesMut + Debug> TcpSegment<'a, T> {
535544
// `offset + room_for_payload <= payload_buf.len()`.
536545
payload_buf.read_to_slice(
537546
0,
538-
&mut segment.bytes[segment_len..segment_len + room_for_payload],
547+
&mut segment.bytes
548+
[usize::from(segment_len)..usize::from(segment_len + room_for_payload)],
539549
);
540550
room_for_payload
541551
} else {
@@ -544,7 +554,7 @@ impl<'a, T: NetworkBytesMut + Debug> TcpSegment<'a, T> {
544554
segment_len += payload_bytes_count;
545555

546556
// This is ok because segment_len <= buf.len().
547-
segment.bytes.shrink_unchecked(segment_len);
557+
segment.bytes.shrink_unchecked(segment_len.into());
548558

549559
// Shrink the resulting segment to a slice of exact size, so using self.len() makes sense.
550560
Ok(Incomplete::new(segment))

0 commit comments

Comments
 (0)