Skip to content

Commit 8bef238

Browse files
committed
Remove possible panic paths
1 parent 85814c0 commit 8bef238

File tree

9 files changed

+120
-103
lines changed

9 files changed

+120
-103
lines changed

src/asynchronous/fw_update.rs

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,11 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
233233
controllers: &mut [&mut T],
234234
data: &[u8],
235235
) -> Result<(), Error<T::BusError>> {
236-
if controllers.is_empty() {
237-
return Err(PdError::InvalidParams.into());
238-
}
239-
240236
trace!("Controllers: Sending burst write");
241237
let update_args = self.update_args.ok_or(Error::Pd(PdError::InvalidParams))?;
242-
if let Err(e) = controllers[0]
238+
if let Err(e) = controllers
239+
.get_mut(0)
240+
.ok_or(PdError::InvalidParams)?
243241
.fw_update_burst_write(update_args.broadcast_u16_address as u8, data)
244242
.await
245243
{
@@ -434,7 +432,10 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
434432
read_result.read_data,
435433
);
436434

437-
self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data);
435+
self.args_buffer
436+
.get_mut(current..current + read_len)
437+
.ok_or(PdError::InvalidParams)?
438+
.copy_from_slice(read_result.read_data);
438439

439440
if read_result.is_complete() {
440441
// We have the full header metadata
@@ -498,7 +499,10 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
498499
self.block_args = None;
499500
let current = read_result.read_state.current;
500501
let read_len = read_result.read_data.len();
501-
self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data);
502+
self.args_buffer
503+
.get_mut(current..current + read_len)
504+
.ok_or(PdError::InvalidParams)?
505+
.copy_from_slice(read_result.read_data);
502506

503507
if read_result.is_complete() {
504508
// We have the full header metadata
@@ -573,7 +577,10 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
573577
) -> Result<Option<SeekOperation>, Error<T::BusError>> {
574578
let current = read_result.read_state.current;
575579
let read_len = read_result.read_data.len();
576-
self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data);
580+
self.args_buffer
581+
.get_mut(current..current + read_len)
582+
.ok_or(PdError::InvalidParams)?
583+
.copy_from_slice(read_result.read_data);
577584
self.fw_update_burst_write(controllers, read_result.read_data).await?;
578585
if read_result.is_complete() {
579586
// We have the full image size
@@ -702,14 +709,22 @@ pub async fn perform_fw_update_borrowed<T: UpdateTarget>(
702709

703710
// Disable all interrupts while we're entering FW update mode
704711
// These go in the second half of the interrupt_guards array so they get dropped last
705-
disable_all_interrupts(controllers, &mut interrupt_guards[half..]).await?;
712+
disable_all_interrupts(
713+
controllers,
714+
interrupt_guards.get_mut(half..).ok_or(PdError::InvalidParams)?,
715+
)
716+
.await?;
706717
info!("Starting update");
707718
let result = updater.start_fw_update(controllers, delay).await;
708719
info!("Update started");
709720

710721
// Re-enable interrupts on port 0 only
711722
// These go in the first half of the interrupt_guards array so they get dropped first
712-
enable_port0_interrupts(controllers, &mut interrupt_guards[0..half]).await?;
723+
enable_port0_interrupts(
724+
controllers,
725+
interrupt_guards.get_mut(0..half).ok_or(PdError::InvalidParams)?,
726+
)
727+
.await?;
713728

714729
match result {
715730
Err(e) => {

src/asynchronous/internal/command.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,14 @@ impl<B: I2c> Tps6699x<B> {
9696
debug!("read_command_result: ret: {:?}", ret);
9797
// Overwrite return value
9898
if let Some(data) = data {
99-
data.copy_from_slice(&buf[1..=data.len()]);
99+
data.copy_from_slice(buf.get(1..=data.len()).ok_or(PdError::InvalidParams)?);
100100
}
101101
Ok(ret)
102102
} else {
103103
// No return value to check
104104
debug!("read_command_result: Done");
105105
if let Some(data) = data {
106-
data.copy_from_slice(&buf[..data.len()]);
106+
data.copy_from_slice(buf.get(..data.len()).ok_or(PdError::InvalidParams)?);
107107
}
108108
Ok(ReturnValue::Success)
109109
}

src/asynchronous/internal/mod.rs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
use device_driver::AsyncRegisterInterface;
33
use embedded_hal_async::i2c::I2c;
44
use embedded_usb_pd::pdinfo::AltMode;
5-
use embedded_usb_pd::pdo::{self, sink, source, ExpectedPdo};
5+
use embedded_usb_pd::pdo::{self, sink, source};
66
use embedded_usb_pd::{Error, LocalPortId, PdError};
77

8-
use crate::registers::rx_caps::EPR_PDO_START_INDEX;
8+
use crate::registers::rx_caps::{RxCapsError, EPR_PDO_START_INDEX};
99
use crate::{
1010
registers, warn, DeviceError, Mode, MAX_SUPPORTED_PORTS, PORT0, PORT1, TPS66993_NUM_PORTS, TPS66994_NUM_PORTS,
1111
};
@@ -45,10 +45,13 @@ impl<B: I2c> device_driver::AsyncRegisterInterface for Port<'_, B> {
4545

4646
buf[0] = address;
4747
buf[1] = data.len() as u8;
48-
let _ = &buf[2..data.len() + 2].copy_from_slice(data);
48+
let _ = &buf
49+
.get_mut(2..data.len() + 2)
50+
.ok_or(PdError::InvalidParams)?
51+
.copy_from_slice(data);
4952

5053
self.bus
51-
.write(self.addr, &buf[..data.len() + 2])
54+
.write(self.addr, buf.get(..data.len() + 2).ok_or(PdError::InvalidParams)?)
5255
.await
5356
.map_err(Error::Bus)
5457
}
@@ -69,7 +72,7 @@ impl<B: I2c> device_driver::AsyncRegisterInterface for Port<'_, B> {
6972
}
7073

7174
self.bus
72-
.write_read(self.addr, &reg, &mut buf[..full_len])
75+
.write_read(self.addr, &reg, buf.get_mut(..full_len).ok_or(PdError::InvalidParams)?)
7376
.await
7477
.map_err(Error::Bus)?;
7578

@@ -88,7 +91,7 @@ impl<B: I2c> device_driver::AsyncRegisterInterface for Port<'_, B> {
8891
// Controller is busy and can't respond
8992
PdError::Busy.into()
9093
} else {
91-
data.copy_from_slice(&buf[1..data.len() + 1]);
94+
data.copy_from_slice(buf.get(1..data.len() + 1).ok_or(PdError::InvalidParams)?);
9295
Ok(())
9396
}
9497
}
@@ -117,11 +120,7 @@ impl<B: I2c> Tps6699x<B> {
117120

118121
/// Get the I2C address for a port
119122
fn port_addr(&self, port: LocalPortId) -> Result<u8, Error<B::Error>> {
120-
if port.0 as usize >= self.num_ports {
121-
PdError::InvalidPort.into()
122-
} else {
123-
Ok(self.addr[port.0 as usize])
124-
}
123+
Ok(*self.addr.get(port.0 as usize).ok_or(PdError::InvalidPort)?)
125124
}
126125

127126
/// Returns number of ports
@@ -602,7 +601,7 @@ impl<B: I2c> Tps6699x<B> {
602601
register: u8,
603602
out_spr_pdos: &mut [T],
604603
out_epr_pdos: &mut [T],
605-
) -> Result<(usize, usize), DeviceError<B::Error, ExpectedPdo>> {
604+
) -> Result<(usize, usize), DeviceError<B::Error, RxCapsError>> {
606605
// Clamp to the maximum number of PDOs
607606
let num_pdos = if !out_epr_pdos.is_empty() {
608607
EPR_PDO_START_INDEX + out_epr_pdos.len()
@@ -626,12 +625,16 @@ impl<B: I2c> Tps6699x<B> {
626625
let num_sprs = out_spr_pdos.len().min(rx_caps.num_valid_pdos() as usize);
627626
for (i, pdo) in out_spr_pdos.iter_mut().enumerate().take(num_sprs) {
628627
// SPR PDOs start at index 0
629-
*pdo = rx_caps[i];
628+
*pdo = *rx_caps
629+
.get(i)
630+
.ok_or(DeviceError::Error(Error::Pd(PdError::InvalidParams)))?;
630631
}
631632

632633
let num_eprs = out_epr_pdos.len().min(rx_caps.num_valid_epr_pdos() as usize);
633634
for (i, pdo) in out_epr_pdos.iter_mut().enumerate().take(num_eprs) {
634-
*pdo = rx_caps[EPR_PDO_START_INDEX + i];
635+
*pdo = *rx_caps
636+
.get(EPR_PDO_START_INDEX + i)
637+
.ok_or(DeviceError::Error(Error::Pd(PdError::InvalidParams)))?;
635638
}
636639

637640
Ok((num_sprs, num_eprs))
@@ -645,7 +648,7 @@ impl<B: I2c> Tps6699x<B> {
645648
port: LocalPortId,
646649
out_spr_pdos: &mut [source::Pdo],
647650
out_epr_pdos: &mut [source::Pdo],
648-
) -> Result<(usize, usize), DeviceError<B::Error, ExpectedPdo>> {
651+
) -> Result<(usize, usize), DeviceError<B::Error, RxCapsError>> {
649652
self.get_rx_caps(port, registers::rx_caps::RX_SRC_ADDR, out_spr_pdos, out_epr_pdos)
650653
.await
651654
}
@@ -658,7 +661,7 @@ impl<B: I2c> Tps6699x<B> {
658661
port: LocalPortId,
659662
out_spr_pdos: &mut [sink::Pdo],
660663
out_epr_pdos: &mut [sink::Pdo],
661-
) -> Result<(usize, usize), DeviceError<B::Error, ExpectedPdo>> {
664+
) -> Result<(usize, usize), DeviceError<B::Error, RxCapsError>> {
662665
self.get_rx_caps(port, registers::rx_caps::RX_SNK_ADDR, out_spr_pdos, out_epr_pdos)
663666
.await
664667
}

src/asynchronous/interrupt.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,8 @@ pub trait InterruptController {
2727
port: LocalPortId,
2828
enabled: bool,
2929
) -> Result<Self::Guard, Error<Self::BusError>> {
30-
if port.0 as usize >= MAX_SUPPORTED_PORTS {
31-
return PdError::InvalidPort.into();
32-
}
33-
3430
let mut state = self.interrupts_enabled().await?;
35-
state[port.0 as usize] = enabled;
31+
*state.get_mut(port.0 as usize).ok_or(PdError::InvalidPort)? = enabled;
3632
self.enable_interrupts_guarded(state).await
3733
}
3834

src/command/mod.rs

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,55 +4,44 @@ use bincode::error::{DecodeError, EncodeError};
44
use bincode::{Decode, Encode};
55
use embedded_usb_pd::PdError;
66

7+
use crate::u32_from_str;
8+
79
pub mod gcdm;
810
pub mod muxr;
911
pub mod trig;
1012
pub mod vdms;
1113

12-
/// Length of a command
13-
const CMD_LEN: usize = 4;
14-
1514
/// TaskResult is only defined for lower 4 bits
1615
pub const CMD_4CC_TASK_RETURN_CODE_MASK: u8 = 0x0F;
1716

18-
/// Converts a 4-byte string into a u32
19-
const fn u32_from_str(value: &str) -> u32 {
20-
if value.len() != CMD_LEN {
21-
panic!("Invalid command string")
22-
}
23-
24-
let bytes = value.as_bytes();
25-
u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).to_le()
26-
}
27-
2817
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2918
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
3019
#[repr(u32)]
3120
pub enum Command {
3221
/// Previous command succeeded
3322
Success = 0,
3423
/// Invalid Command
35-
Invalid = u32_from_str("!CMD"),
24+
Invalid = u32_from_str(*b"!CMD"),
3625
/// Reset command
37-
Gaid = u32_from_str("GAID"),
26+
Gaid = u32_from_str(*b"GAID"),
3827

3928
/// Tomcat firmware update mode enter
40-
Tfus = u32_from_str("TFUs"),
29+
Tfus = u32_from_str(*b"TFUs"),
4130
/// Tomcat firmware update mode init
42-
Tfui = u32_from_str("TFUi"),
31+
Tfui = u32_from_str(*b"TFUi"),
4332
/// Tomcat firmware update mode query
44-
Tfuq = u32_from_str("TFUq"),
33+
Tfuq = u32_from_str(*b"TFUq"),
4534
/// Tomcat firmware update mode exit
46-
Tfue = u32_from_str("TFUe"),
35+
Tfue = u32_from_str(*b"TFUe"),
4736
/// Tomcat firmware update data
48-
Tfud = u32_from_str("TFUd"),
37+
Tfud = u32_from_str(*b"TFUd"),
4938
/// Tomcat firmware update complete
50-
Tfuc = u32_from_str("TFUc"),
39+
Tfuc = u32_from_str(*b"TFUc"),
5140

5241
/// System ready to sink
53-
Srdy = u32_from_str("SRDY"),
42+
Srdy = u32_from_str(*b"SRDY"),
5443
/// SRDY reset
55-
Sryr = u32_from_str("SRYR"),
44+
Sryr = u32_from_str(*b"SRYR"),
5645

5746
/// Re-evaluate the Autonegotiate Sink register.
5847
///
@@ -61,10 +50,10 @@ pub enum Command {
6150
///
6251
/// # Output
6352
/// [`ReturnValue`]
64-
Aneg = u32_from_str("ANeg"),
53+
Aneg = u32_from_str(*b"ANeg"),
6554

6655
/// Trigger an Input GPIO event
67-
Trig = u32_from_str("Trig"),
56+
Trig = u32_from_str(*b"Trig"),
6857

6958
/// Clear the dead battery flag.
7059
///
@@ -73,7 +62,7 @@ pub enum Command {
7362
///
7463
/// # Output
7564
/// [`ReturnValue`]
76-
Dbfg = u32_from_str("DBfg"),
65+
Dbfg = u32_from_str(*b"DBfg"),
7766

7867
/// Repeat transactions on I2C3m under certain conditions.
7968
///
@@ -82,7 +71,7 @@ pub enum Command {
8271
///
8372
/// # Output
8473
/// [`ReturnValue`]
85-
Muxr = u32_from_str("MuxR"),
74+
Muxr = u32_from_str(*b"MuxR"),
8675

8776
/// PD Data Reset
8877
///
@@ -91,7 +80,7 @@ pub enum Command {
9180
///
9281
/// # Output
9382
/// [`ReturnValue`]
94-
Drst = u32_from_str("DRST"),
83+
Drst = u32_from_str(*b"DRST"),
9584

9685
/// Send VDM.
9786
///
@@ -100,7 +89,7 @@ pub enum Command {
10089
///
10190
/// # Output
10291
/// None
103-
VDMs = u32_from_str("VDMs"),
92+
VDMs = u32_from_str(*b"VDMs"),
10493

10594
/// Execute a UCSI command
10695
///
@@ -109,7 +98,7 @@ pub enum Command {
10998
///
11099
/// # Output
111100
/// [`embedded_usb_pd::ucsi::lpm::ResponseData`]
112-
Ucsi = u32_from_str("UCSI"),
101+
Ucsi = u32_from_str(*b"UCSI"),
113102

114103
/// Get custom discovered modes
115104
///
@@ -118,7 +107,7 @@ pub enum Command {
118107
///
119108
/// # Output
120109
/// [`gcdm::DiscoveredMode`]
121-
GCdm = u32_from_str("GCdm"),
110+
GCdm = u32_from_str(*b"GCdm"),
122111
}
123112

124113
impl TryFrom<u32> for Command {

src/lib.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ impl<BE, T> From<DeviceError<BE, T>> for embedded_usb_pd::Error<BE> {
5656
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
5757
pub enum Mode {
5858
/// Boot mode
59-
Boot = u32_from_str("BOOT"),
59+
Boot = u32_from_str(*b"BOOT"),
6060
/// Firmware corrupt on both banks
61-
F211 = u32_from_str("F211"),
61+
F211 = u32_from_str(*b"F211"),
6262
/// Before app config
63-
App0 = u32_from_str("APP0"),
63+
App0 = u32_from_str(*b"APP0"),
6464
/// After app config
65-
App1 = u32_from_str("APP1"),
65+
App1 = u32_from_str(*b"APP1"),
6666
/// App FW waiting for power
67-
Wtpr = u32_from_str("WTPR"),
67+
Wtpr = u32_from_str(*b"WTPR"),
6868
}
6969

7070
impl PartialEq<u32> for Mode {
@@ -102,12 +102,8 @@ impl Into<[u8; 4]> for Mode {
102102

103103
const U32_STR_LEN: usize = 4;
104104
/// Converts a 4-byte string into a u32
105-
pub(crate) const fn u32_from_str(value: &str) -> u32 {
106-
if value.len() != U32_STR_LEN {
107-
panic!("Invalid command string")
108-
}
109-
let bytes = value.as_bytes();
110-
u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).to_le()
105+
pub(crate) const fn u32_from_str(bytes: [u8; U32_STR_LEN]) -> u32 {
106+
u32::from_le_bytes(bytes).to_le()
111107
}
112108

113109
/// Common unit test functions

0 commit comments

Comments
 (0)