diff --git a/esp-hal/src/spi/master/dma.rs b/esp-hal/src/spi/master/dma.rs index e4adde9815..49d62de844 100644 --- a/esp-hal/src/spi/master/dma.rs +++ b/esp-hal/src/spi/master/dma.rs @@ -1,4 +1,5 @@ use core::{ + cell::Cell, cmp::min, mem::ManuallyDrop, ops::{Deref, DerefMut}, @@ -66,7 +67,7 @@ impl<'d> Spi<'d, Blocking> { /// ``` #[instability::unstable] pub fn with_dma(self, channel: impl DmaChannelFor>) -> SpiDma<'d, Blocking> { - SpiDma::new(self.spi, self.pins, channel.degrade()) + SpiDma::new(self, channel.degrade()) } } @@ -114,14 +115,10 @@ pub struct SpiDma<'d, Dm> where Dm: DriverMode, { - pub(crate) spi: AnySpi<'d>, + spi: SpiWrapper<'d>, pub(crate) channel: Channel>>, - tx_transfer_in_progress: bool, - rx_transfer_in_progress: bool, #[cfg(all(esp32, spi_address_workaround))] address_buffer: DmaTxBuf, - guard: PeripheralGuard, - pins: SpiPinGuard, } impl crate::private::Sealed for SpiDma<'_, Dm> where Dm: DriverMode {} @@ -129,27 +126,25 @@ impl crate::private::Sealed for SpiDma<'_, Dm> where Dm: DriverMode {} impl<'d> SpiDma<'d, Blocking> { /// Converts the SPI instance into async mode. #[instability::unstable] - pub fn into_async(mut self) -> SpiDma<'d, Async> { - self.set_interrupt_handler(self.spi.info().async_handler); + pub fn into_async(self) -> SpiDma<'d, Async> { + self.spi + .set_interrupt_handler(self.spi.info().async_handler); SpiDma { spi: self.spi, channel: self.channel.into_async(), - tx_transfer_in_progress: self.tx_transfer_in_progress, - rx_transfer_in_progress: self.rx_transfer_in_progress, #[cfg(all(esp32, spi_address_workaround))] address_buffer: self.address_buffer, - guard: self.guard, - pins: self.pins, } } pub(super) fn new( - spi: AnySpi<'d>, - pins: SpiPinGuard, + spi_driver: Spi<'d, Blocking>, channel: PeripheralDmaChannel>, ) -> Self { + let spi = spi_driver.spi; + let channel = Channel::new(channel); - channel.runtime_ensure_compatible(&spi); + channel.runtime_ensure_compatible(&spi.spi); #[cfg(all(esp32, spi_address_workaround))] let address_buffer = { use crate::dma::DmaDescriptor; @@ -170,17 +165,16 @@ impl<'d> SpiDma<'d, Blocking> { )) }; - let guard = PeripheralGuard::new(spi.info().peripheral); + let (_info, state) = spi.spi.dma_parts(); + + state.tx_transfer_in_progress.set(false); + state.rx_transfer_in_progress.set(false); Self { spi, channel, #[cfg(all(esp32, spi_address_workaround))] address_buffer, - tx_transfer_in_progress: false, - rx_transfer_in_progress: false, - guard, - pins, } } @@ -241,19 +235,15 @@ impl<'d> SpiDma<'d, Async> { SpiDma { spi: self.spi, channel: self.channel.into_blocking(), - tx_transfer_in_progress: self.tx_transfer_in_progress, - rx_transfer_in_progress: self.rx_transfer_in_progress, #[cfg(all(esp32, spi_address_workaround))] address_buffer: self.address_buffer, - guard: self.guard, - pins: self.pins, } } async fn wait_for_idle_async(&mut self) { - if self.rx_transfer_in_progress { + if self.dma_driver().state.rx_transfer_in_progress.get() { _ = DmaRxFuture::new(&mut self.channel.rx).await; - self.rx_transfer_in_progress = false; + self.dma_driver().state.rx_transfer_in_progress.set(false); } struct Fut(Driver); @@ -292,19 +282,19 @@ impl<'d> SpiDma<'d, Async> { Fut(self.driver()).await; } - if self.tx_transfer_in_progress { + if self.dma_driver().state.tx_transfer_in_progress.get() { // In case DMA TX buffer is bigger than what the SPI consumes, stop the DMA. if !self.channel.tx.is_done() { self.channel.tx.stop_transfer(); } - self.tx_transfer_in_progress = false; + self.dma_driver().state.tx_transfer_in_progress.set(false); } } } impl core::fmt::Debug for SpiDma<'_, Dm> where - Dm: DriverMode, + Dm: DriverMode + core::fmt::Debug, { /// Formats the `SpiDma` instance for debugging purposes. /// @@ -329,6 +319,10 @@ impl SpiDma<'_, Dm> where Dm: DriverMode, { + fn spi(&self) -> &SpiWrapper<'_> { + &self.spi + } + fn driver(&self) -> Driver { Driver { info: self.spi.info(), @@ -339,7 +333,8 @@ where fn dma_driver(&self) -> DmaDriver { DmaDriver { driver: self.driver(), - dma_peripheral: self.spi.dma_peripheral(), + dma_peripheral: self.spi().dma_peripheral(), + state: self.spi().dma_state(), } } @@ -347,7 +342,7 @@ where if self.driver().busy() { return false; } - if self.rx_transfer_in_progress { + if self.dma_driver().state.rx_transfer_in_progress.get() { // If this is an asymmetric transfer and the RX side is smaller, the RX channel // will never be "done" as it won't have enough descriptors/buffer to receive // the EOF bit from the SPI. So instead the RX channel will hit @@ -366,8 +361,8 @@ where while !self.is_done() { // Wait for the SPI to become idle } - self.rx_transfer_in_progress = false; - self.tx_transfer_in_progress = false; + self.dma_driver().state.rx_transfer_in_progress.set(false); + self.dma_driver().state.tx_transfer_in_progress.set(false); fence(Ordering::Acquire); } @@ -388,8 +383,14 @@ where return Err(Error::MaxDmaTransferSizeExceeded); } - self.rx_transfer_in_progress = bytes_to_read > 0; - self.tx_transfer_in_progress = bytes_to_write > 0; + self.dma_driver() + .state + .rx_transfer_in_progress + .set(bytes_to_read > 0); + self.dma_driver() + .state + .tx_transfer_in_progress + .set(bytes_to_write > 0); unsafe { self.dma_driver().start_transfer_dma( full_duplex, @@ -432,9 +433,8 @@ where address.mode(), )?; - // FIXME: we could use self.start_transfer_dma if the address buffer was part of - // the (yet-to-be-created) State struct. - self.tx_transfer_in_progress = true; + // FIXME: we could use self.start_transfer_dma + self.dma_driver().state.tx_transfer_in_progress.set(true); unsafe { self.dma_driver().start_transfer_dma( false, @@ -448,23 +448,18 @@ where } fn cancel_transfer(&mut self) { - // The SPI peripheral is controlling how much data we transfer, so let's - // update its counter. - // 0 doesn't take effect on ESP32 and cuts the currently transmitted byte - // immediately. - // 1 seems to stop after transmitting the current byte which is somewhat less - // impolite. - if self.tx_transfer_in_progress || self.rx_transfer_in_progress { + let state = self.dma_driver().state; + if state.tx_transfer_in_progress.get() || state.rx_transfer_in_progress.get() { self.dma_driver().abort_transfer(); // We need to stop the DMA transfer, too. - if self.tx_transfer_in_progress { + if state.tx_transfer_in_progress.get() { self.channel.tx.stop_transfer(); - self.tx_transfer_in_progress = false; + state.tx_transfer_in_progress.set(false); } - if self.rx_transfer_in_progress { + if state.rx_transfer_in_progress.get() { self.channel.rx.stop_transfer(); - self.rx_transfer_in_progress = false; + state.rx_transfer_in_progress.set(false); } } } @@ -1247,10 +1242,17 @@ where pub(super) struct DmaDriver { driver: Driver, dma_peripheral: crate::dma::DmaPeripheral, + state: &'static DmaState, } impl DmaDriver { fn abort_transfer(&self) { + // The SPI peripheral is controlling how much data we transfer, so let's + // update its counter. + // 0 doesn't take effect on ESP32 and cuts the currently transmitted byte + // immediately. + // 1 seems to stop after transmitting the current byte which is somewhat less + // impolite. self.driver.configure_datalen(1, 1); self.driver.update(); } @@ -1389,12 +1391,8 @@ impl<'d> DmaEligible for AnySpi<'d> { type Dma = crate::dma::AnySpiDmaChannel<'d>; fn dma_peripheral(&self) -> crate::dma::DmaPeripheral { - match &self.0 { - #[cfg(spi_master_spi2)] - any::Inner::Spi2(_) => crate::dma::DmaPeripheral::Spi2, - #[cfg(spi_master_spi3)] - any::Inner::Spi3(_) => crate::dma::DmaPeripheral::Spi3, - } + let (info, _state) = self.dma_parts(); + info.dma_peripheral } } @@ -1496,3 +1494,64 @@ where Ok(()) } } + +struct DmaInfo { + dma_peripheral: crate::dma::DmaPeripheral, +} +struct DmaState { + tx_transfer_in_progress: Cell, + rx_transfer_in_progress: Cell, +} + +// SAFETY: State belongs to the currently constructed driver instance. As such, it'll not be +// accessed concurrently in multiple threads. +unsafe impl Sync for DmaState {} + +for_each_spi_master!( + (all $( ($peri:ident, $sys:ident, $sclk:ident $_cs:tt $_sio:tt $(, $is_qspi:tt)?)),* ) => { + impl AnySpi<'_> { + #[inline(always)] + fn dma_parts(&self) -> (&'static DmaInfo, &'static DmaState) { + match &self.0 { + $( + super::any::Inner::$sys(_spi) => { + static DMA_INFO: DmaInfo = DmaInfo { + dma_peripheral: crate::dma::DmaPeripheral::$sys, + }; + + static DMA_STATE: DmaState = DmaState { + tx_transfer_in_progress: Cell::new(false), + rx_transfer_in_progress: Cell::new(false), + }; + + (&DMA_INFO, &DMA_STATE) + } + )* + } + } + + #[inline(always)] + fn dma_state(&self) -> &'static DmaState { + let (_, state) = self.dma_parts(); + state + } + + #[inline(always)] + fn dma_info(&self) -> &'static DmaInfo { + let (info, _) = self.dma_parts(); + info + } + } + }; +); + +impl SpiWrapper<'_> { + fn dma_state(&self) -> &'static DmaState { + self.spi.dma_state() + } + + #[inline(always)] + fn dma_peripheral(&self) -> crate::dma::DmaPeripheral { + self.spi.dma_peripheral() + } +} diff --git a/esp-hal/src/spi/master/mod.rs b/esp-hal/src/spi/master/mod.rs index 94103623cc..5ef00c6a2e 100644 --- a/esp-hal/src/spi/master/mod.rs +++ b/esp-hal/src/spi/master/mod.rs @@ -37,8 +37,10 @@ #[cfg(esp32)] use core::cell::Cell; use core::{ + cell::UnsafeCell, future::Future, marker::PhantomData, + mem::MaybeUninit, pin::Pin, task::{Context, Poll}, }; @@ -639,6 +641,16 @@ struct SpiPinGuard { sio_pins: [PinGuard; SIO_PIN_COUNT], } +impl SpiPinGuard { + const fn new_unconnected() -> Self { + Self { + sclk_pin: PinGuard::new_unconnected(), + cs_pin: PinGuard::new_unconnected(), + sio_pins: [const { PinGuard::new_unconnected() }; SIO_PIN_COUNT], + } + } +} + /// Configuration errors. #[non_exhaustive] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -659,6 +671,66 @@ impl core::fmt::Display for ConfigError { } } } + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct SpiWrapper<'d> { + spi: AnySpi<'d>, + _guard: PeripheralGuard, +} + +impl<'d> SpiWrapper<'d> { + fn new(spi: impl Instance + 'd) -> Self { + let p = spi.info().peripheral; + let this = Self { + spi: spi.degrade(), + _guard: PeripheralGuard::new(p), + }; + + // Initialize state + unsafe { + this.state() + .pins + .get() + .write(MaybeUninit::new(SpiPinGuard::new_unconnected())) + } + + this + } + + fn info(&self) -> &'static Info { + self.spi.info() + } + + fn state(&self) -> &'static State { + self.spi.state() + } + + fn disable_peri_interrupt_on_all_cores(&self) { + self.spi.disable_peri_interrupt_on_all_cores(); + } + + fn set_interrupt_handler(&self, handler: InterruptHandler) { + self.spi.set_interrupt_handler(handler); + } + + fn pins(&mut self) -> &mut SpiPinGuard { + unsafe { + // SAFETY: we "own" the state, we are allowed to borrow it mutably + self.state().pins() + } + } +} + +impl Drop for SpiWrapper<'_> { + fn drop(&mut self) { + unsafe { + // SAFETY: we "own" the state, we are allowed to deinit it + self.spi.state().deinit(); + } + } +} + #[procmacros::doc_replace] /// SPI peripheral driver /// @@ -684,10 +756,8 @@ impl core::fmt::Display for ConfigError { #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Spi<'d, Dm: DriverMode> { - spi: AnySpi<'d>, + spi: SpiWrapper<'d>, _mode: PhantomData, - guard: PeripheralGuard, - pins: SpiPinGuard, } impl Sealed for Spi<'_, Dm> {} @@ -715,17 +785,9 @@ impl<'d> Spi<'d, Blocking> { /// # {after_snippet} /// ``` pub fn new(spi: impl Instance + 'd, config: Config) -> Result { - let guard = PeripheralGuard::new(spi.info().peripheral); - let mut this = Spi { _mode: PhantomData, - guard, - pins: SpiPinGuard { - sclk_pin: PinGuard::new_unconnected(), - cs_pin: PinGuard::new_unconnected(), - sio_pins: [const { PinGuard::new_unconnected() }; SIO_PIN_COUNT], - }, - spi: spi.degrade(), + spi: SpiWrapper::new(spi), }; this.driver().init(); @@ -754,8 +816,6 @@ impl<'d> Spi<'d, Blocking> { Spi { spi: self.spi, _mode: PhantomData, - guard: self.guard, - pins: self.pins, } } @@ -803,8 +863,6 @@ impl<'d> Spi<'d, Async> { Spi { spi: self.spi, _mode: PhantomData, - guard: self.guard, - pins: self.pins, } } @@ -901,7 +959,7 @@ macro_rules! def_with_sio_pin { #[doc = concat!(" to the SIO", stringify!($n), " output and input signals.")] #[instability::unstable] pub fn $fn(mut self, sio: impl PeripheralInput<'d> + PeripheralOutput<'d>) -> Self { - self.pins.sio_pins[$n] = self.connect_sio_pin(sio.into(), $n); + self.spi.pins().sio_pins[$n] = self.connect_sio_pin(sio.into(), $n); self } @@ -964,7 +1022,8 @@ where /// # {after_snippet} /// ``` pub fn with_sck(mut self, sclk: impl PeripheralOutput<'d>) -> Self { - self.pins.sclk_pin = self.connect_output_pin(sclk.into(), self.driver().info.sclk); + let info = self.spi.info(); + self.spi.pins().sclk_pin = self.connect_output_pin(sclk.into(), info.sclk); self } @@ -992,8 +1051,7 @@ where /// # {after_snippet} /// ``` pub fn with_mosi(mut self, mosi: impl PeripheralOutput<'d>) -> Self { - self.pins.sio_pins[0] = self.connect_sio_output_pin(mosi.into(), 0); - + self.spi.pins().sio_pins[0] = self.connect_sio_output_pin(mosi.into(), 0); self } @@ -1043,7 +1101,7 @@ where /// signal. #[instability::unstable] pub fn with_sio0(mut self, mosi: impl PeripheralInput<'d> + PeripheralOutput<'d>) -> Self { - self.pins.sio_pins[0] = self.connect_sio_pin(mosi.into(), 0); + self.spi.pins().sio_pins[0] = self.connect_sio_pin(mosi.into(), 0); self } @@ -1061,7 +1119,7 @@ where /// signal. #[instability::unstable] pub fn with_sio1(mut self, sio1: impl PeripheralInput<'d> + PeripheralOutput<'d>) -> Self { - self.pins.sio_pins[1] = self.connect_sio_pin(sio1.into(), 1); + self.spi.pins().sio_pins[1] = self.connect_sio_pin(sio1.into(), 1); self } @@ -1094,7 +1152,9 @@ where /// mechanism to select which CS line to use. #[instability::unstable] pub fn with_cs(mut self, cs: impl PeripheralOutput<'d>) -> Self { - self.pins.cs_pin = self.connect_output_pin(cs.into(), self.driver().info.cs(0)); + let info = self.spi.info(); + self.spi.pins().cs_pin = self.connect_output_pin(cs.into(), info.cs(0)); + self } @@ -2417,6 +2477,8 @@ for_each_spi_master! { static STATE: State = State { waker: AtomicWaker::new(), + pins: UnsafeCell::new(MaybeUninit::uninit()), + #[cfg(esp32)] esp32_hack: Esp32Hack { timing_miso_delay: Cell::new(None), @@ -2441,19 +2503,43 @@ impl QspiInstance for AnySpi<'_> {} #[doc(hidden)] pub struct State { waker: AtomicWaker, + pins: UnsafeCell>, #[cfg(esp32)] esp32_hack: Esp32Hack, } +impl State { + // Syntactic helper to get a mutable reference to the pin guard. + // + // Intended to be called in `SpiWrapper::pins` only + // + // # Safety + // + // The caller must ensure that Rust's aliasing rules are upheld. + #[allow( + clippy::mut_from_ref, + reason = "Safety requirements ensure this is okay" + )] + unsafe fn pins(&self) -> &mut SpiPinGuard { + unsafe { (&mut *self.pins.get()).assume_init_mut() } + } + + unsafe fn deinit(&self) { + unsafe { + let mut old = self.pins.get().replace(MaybeUninit::uninit()); + old.assume_init_drop(); + } + } +} + #[cfg(esp32)] struct Esp32Hack { timing_miso_delay: Cell>, extra_dummy: Cell, } -#[cfg(esp32)] -unsafe impl Sync for Esp32Hack {} +unsafe impl Sync for State {} #[ram] fn handle_async(info: &'static Info, state: &'static State) {