From b474516d777f8796f10caddfee718e1468f722b0 Mon Sep 17 00:00:00 2001 From: "Eric B. Ridge" Date: Wed, 19 Jul 2023 01:51:27 +0000 Subject: [PATCH] Round 1 of cleaning up `SpiTupleTable` - Create `SpiTupleTable::wrap(...)` to construct a new one from the global `pg_sys::SPI_tuptable`, update `SpiCursor` and `Query` impls accordingly. This gets rid of `SpiClient::prepare_tuple_table()`. - Make its members private to fix cross-module leakiness - When dropped, `SpiTupleTable` needs to free its internal `pg_sys::SPITupleTable` pointer Other changes: - Make the lifetime of what `SpiCursor::fetch()` returns explicit - There's no need to set `pg_sys::SPI_tuptable` to NULL before calling SPI_execute/SPI_cursor_fetch -- Postgres does that for us - add a test that passed before this work and still does --- pgrx-tests/src/tests/spi_tests.rs | 25 +++++++++++ pgrx/src/spi/client.rs | 24 +---------- pgrx/src/spi/cursor.rs | 12 ++---- pgrx/src/spi/query.rs | 9 +--- pgrx/src/spi/tuple.rs | 70 +++++++++++++++++++++++++++---- 5 files changed, 93 insertions(+), 47 deletions(-) diff --git a/pgrx-tests/src/tests/spi_tests.rs b/pgrx-tests/src/tests/spi_tests.rs index 52f02845d4..ef6095e877 100644 --- a/pgrx-tests/src/tests/spi_tests.rs +++ b/pgrx-tests/src/tests/spi_tests.rs @@ -242,6 +242,31 @@ mod tests { }) } + #[pg_test] + fn two_cursors_at_the_same_time() -> Result<(), pgrx::spi::Error> { + Spi::connect(|client| { + let mut cursor1 = client.open_cursor("SELECT * FROM generate_series(1, 20)", None); + let mut cursor2 = client.open_cursor("SELECT * FROM generate_series(40, 60)", None); + + let first_5 = cursor1.fetch(5)?; + let second_5 = cursor2.fetch(5)?; + + let first_5 = Vec::from_iter(first_5.map(|row| row.get::(1))); + let second_5 = Vec::from_iter(second_5.map(|row| row.get::(1))); + + assert_eq!( + first_5, + vec![Ok(Some(1)), Ok(Some(2)), Ok(Some(3)), Ok(Some(4)), Ok(Some(5))] + ); + assert_eq!( + second_5, + vec![Ok(Some(40)), Ok(Some(41)), Ok(Some(42)), Ok(Some(43)), Ok(Some(44))] + ); + + Ok(()) + }) + } + #[pg_test] fn test_cursor_by_name() -> Result<(), pgrx::spi::Error> { let cursor_name = Spi::connect(|client| { diff --git a/pgrx/src/spi/client.rs b/pgrx/src/spi/client.rs index 172923de90..a54cb3827f 100644 --- a/pgrx/src/spi/client.rs +++ b/pgrx/src/spi/client.rs @@ -5,6 +5,7 @@ use std::ptr::NonNull; use crate::pg_sys::{self, PgOid}; use crate::spi::{PreparedStatement, Query, Spi, SpiCursor, SpiError, SpiResult, SpiTupleTable}; +#[derive(Debug)] pub struct SpiClient { // We need `SpiClient` to be publicly accessible but not constructable because we rely // on it being properly constructed in order for its Drop impl, which calles `pg_sys::SPI_finish()`, @@ -76,29 +77,6 @@ impl SpiClient { query.execute(self, limit, args) } - pub(super) fn prepare_tuple_table( - &self, - status_code: i32, - ) -> std::result::Result { - Ok(SpiTupleTable { - status_code: Spi::check_status(status_code)?, - // SAFETY: no concurrent access - table: unsafe { pg_sys::SPI_tuptable.as_mut()}, - #[cfg(any(feature = "pg11", feature = "pg12"))] - size: unsafe { pg_sys::SPI_processed as usize }, - #[cfg(not(any(feature = "pg11", feature = "pg12")))] - // SAFETY: no concurrent access - size: unsafe { - if pg_sys::SPI_tuptable.is_null() { - pg_sys::SPI_processed as usize - } else { - (*pg_sys::SPI_tuptable).numvals as usize - } - }, - current: -1, - }) - } - /// Set up a cursor that will execute the specified query /// /// Rows may be then fetched using [`SpiCursor::fetch`]. diff --git a/pgrx/src/spi/cursor.rs b/pgrx/src/spi/cursor.rs index a7433b0670..952c1d96e4 100644 --- a/pgrx/src/spi/cursor.rs +++ b/pgrx/src/spi/cursor.rs @@ -3,7 +3,7 @@ use std::ptr::NonNull; use crate::pg_sys; -use super::{SpiClient, SpiError, SpiOkCodes, SpiTupleTable}; +use super::{SpiClient, SpiOkCodes, SpiResult, SpiTupleTable}; type CursorName = String; @@ -67,18 +67,14 @@ pub struct SpiCursor<'client> { pub(crate) client: &'client SpiClient, } -impl SpiCursor<'_> { +impl<'client> SpiCursor<'client> { /// Fetch up to `count` rows from the cursor, moving forward /// /// If `fetch` runs off the end of the available rows, an empty [`SpiTupleTable`] is returned. - pub fn fetch(&mut self, count: libc::c_long) -> std::result::Result { - // SAFETY: no concurrent access - unsafe { - pg_sys::SPI_tuptable = std::ptr::null_mut(); - } + pub fn fetch(&mut self, count: libc::c_long) -> SpiResult> { // SAFETY: SPI functions to create/find cursors fail via elog, so self.ptr is valid if we successfully set it unsafe { pg_sys::SPI_cursor_fetch(self.ptr.as_mut(), true, count) } - Ok(self.client.prepare_tuple_table(SpiOkCodes::Fetch as i32)?) + SpiTupleTable::wrap(&self.client, SpiOkCodes::Fetch as i32) } /// Consume the cursor, returning its name diff --git a/pgrx/src/spi/query.rs b/pgrx/src/spi/query.rs index 4c6fc24a1d..60a0090ab3 100644 --- a/pgrx/src/spi/query.rs +++ b/pgrx/src/spi/query.rs @@ -62,11 +62,6 @@ impl<'client> Query<'client> for &str { limit: Option, arguments: Self::Arguments, ) -> SpiResult> { - // SAFETY: no concurrent access - unsafe { - pg_sys::SPI_tuptable = std::ptr::null_mut(); - } - let src = CString::new(self).expect("query contained a null byte"); let status_code = match arguments { Some(args) => { @@ -99,7 +94,7 @@ impl<'client> Query<'client> for &str { }, }; - Ok(client.prepare_tuple_table(status_code)?) + SpiTupleTable::wrap(client, status_code) } fn open_cursor(self, client: &'client SpiClient, args: Self::Arguments) -> SpiCursor<'client> { @@ -238,7 +233,7 @@ impl<'client: 'stmt, 'stmt> Query<'client> for &'stmt PreparedStatement<'client> ) }; - Ok(client.prepare_tuple_table(status_code)?) + SpiTupleTable::wrap(client, status_code) } fn open_cursor(self, client: &'client SpiClient, args: Self::Arguments) -> SpiCursor<'client> { diff --git a/pgrx/src/spi/tuple.rs b/pgrx/src/spi/tuple.rs index 214fb95777..d58765f80c 100644 --- a/pgrx/src/spi/tuple.rs +++ b/pgrx/src/spi/tuple.rs @@ -8,19 +8,56 @@ use crate::memcxt::PgMemoryContexts; use crate::pg_sys::panic::ErrorReportable; use crate::pg_sys::{self, PgOid}; use crate::prelude::*; +use crate::spi::SpiClient; -use super::{SpiError, SpiErrorCodes, SpiOkCodes, SpiResult}; +use super::{SpiError, SpiErrorCodes, SpiResult}; #[derive(Debug)] pub struct SpiTupleTable<'client> { - #[allow(dead_code)] - pub(super) status_code: SpiOkCodes, - pub(super) table: Option<&'client mut pg_sys::SPITupleTable>, - pub(super) size: usize, - pub(super) current: isize, + // SpiTupleTable borrows global state setup by the active SpiClient. It doesn't use the client + // directly, but we need to make sure we don't outlive it, so here it is + _client: PhantomData<&'client SpiClient>, + + // and this is that global state. In ::wrap(), this comes from whatever the current value of + // `pg_sys::SPI_tuptable` happens to be. Postgres may change where SPI_tuptable points + // throughout the lifetime of an active SpiClient, but it doesn't mutate (or deallocate) what + // it happens to point to This allows us to have multiple active SpiTupleTables + // within a Spi connection. Whatever this points to is freed via `pg_sys::SPI_freetuptable()` + // when we're dropped. + table: Option>, + size: usize, + current: isize, } impl<'client> SpiTupleTable<'client> { + /// Wraps the current global `pg_sys::SPI_tuptable` as a new [`SpiTupleTable`] instance, with + /// a lifetime tied to the specified [`SpiClient`]. + pub(super) fn wrap(_client: &'client SpiClient, last_spi_status_code: i32) -> SpiResult { + Spi::check_status(last_spi_status_code)?; + + unsafe { + // + // SAFETY: The unsafeness here is that we're accessing static globals. Fortunately, + // Postgres is not multi-threaded so we're okay to do this + // + + // different Postgres get the tuptable size different ways + #[cfg(any(feature = "pg11", feature = "pg12"))] + let size = pg_sys::SPI_processed as usize; + + #[cfg(not(any(feature = "pg11", feature = "pg12")))] + let size = if pg_sys::SPI_tuptable.is_null() { + pg_sys::SPI_processed as usize + } else { + (*pg_sys::SPI_tuptable).numvals as usize + }; + + let tuptable = pg_sys::SPI_tuptable; + + Ok(Self { _client: PhantomData, table: NonNull::new(tuptable), size, current: -1 }) + } + } + /// `SpiTupleTable`s are positioned before the start, for iteration purposes. /// /// This method moves the position to the first row. If there are no rows, this @@ -76,9 +113,12 @@ impl<'client> SpiTupleTable<'client> { fn get_spi_tuptable( &self, ) -> SpiResult<(*mut pg_sys::SPITupleTable, *mut pg_sys::TupleDescData)> { - let table = self.table.as_deref().ok_or(SpiError::NoTupleTable)?; - // SAFETY: we just assured that `table` is not null - Ok((table as *const _ as *mut _, table.tupdesc)) + let table = self.table.map(|table| table.as_ptr()).ok_or(SpiError::NoTupleTable)?; + let tupdesc = unsafe { + // SAFETY: we just assured that `table` is not null + table.as_mut().unwrap().tupdesc + }; + Ok((table, tupdesc)) } pub fn get_heap_tuple(&self) -> SpiResult>> { @@ -298,6 +338,18 @@ impl<'client> Iterator for SpiTupleTable<'client> { } } +impl Drop for SpiTupleTable<'_> { + fn drop(&mut self) { + unsafe { + // SAFETY: self.table was created by Postgres from whatever `pg_sys::SPI_tuptable` pointed + // to at the time this SpiTupleTable was constructed + if let Some(ptr) = self.table.take() { + pg_sys::SPI_freetuptable(ptr.as_ptr()) + } + } + } +} + /// Represents a single `pg_sys::Datum` inside a `SpiHeapTupleData` pub struct SpiHeapTupleDataEntry<'client> { datum: Option,