diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index dd5f962b2..64220ce9c 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -14,9 +14,12 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import ctypes
 import datetime
 import os
 import re
+import threading
+import time
 from typing import Any
 
 import pyarrow as pa
@@ -2060,3 +2063,121 @@ def test_fill_null_all_null_column(ctx):
     # Check that all nulls were filled
     result = filled_df.collect()[0]
     assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
+
+
+def test_collect_interrupted():
+    """Test that a long-running query can be interrupted with Ctrl-C.
+
+    This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
+    exception in the main thread during a long-running query execution.
+    """
+    # Create a context and a DataFrame with a query that will run for a while
+    ctx = SessionContext()
+
+    # Create enough input data that the join below takes a noticeable amount of time
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    # Register tables
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
+    # Create a large join operation that will take time to process
+    df = ctx.sql("""
+        WITH t1_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) / 1.5 AS c,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+            FROM t1
+            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+        ),
+        t2_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) * 2.5 AS e,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+            FROM t2
+            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+        )
+        SELECT
+            t1.a, t1.b, t1.c, t1.d,
+            t2.a AS a2, t2.b AS b2, t2.e, t2.f
+        FROM t1_expanded t1
+        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+        WHERE t1.a > 100 AND t2.a > 100
+    """)
+
+    # Flag to track if the query was interrupted
+    interrupted = False
+    interrupt_error = None
+    main_thread = threading.main_thread()
+
+    # Shared flag to indicate query execution has started
+    query_started = threading.Event()
+    max_wait_time = 5.0  # Maximum wait time in seconds
+
+    # This function will be run in a separate thread and will raise
+    # KeyboardInterrupt in the main thread
+    def trigger_interrupt():
+        """Poll for query start, then raise KeyboardInterrupt in the main thread"""
+        # Poll for query to start with small sleep intervals
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)  # Small sleep between checks
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        # Check if thread ID is available
+        thread_id = main_thread.ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
+        # Use ctypes to raise exception in main thread
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            # If res is 0, the thread ID was invalid
+            # If res > 1, we modified multiple threads
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), ctypes.py_object(0)
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    # Start a thread to trigger the interrupt
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    # we mark as daemon so the test process can exit even if this thread doesn't finish
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    # Execute the query and expect it to be interrupted
+    try:
+        # Signal that we're about to start the query
+        query_started.set()
+        df.collect()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:
+        interrupt_error = e
+
+    # Assert that the query was interrupted properly
+    if not interrupted:
+        pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")
+
+    # Make sure the interrupt thread has finished
+    interrupt_thread.join(timeout=1.0)
diff --git a/src/catalog.rs b/src/catalog.rs
index 1e189a5aa..83f8d08cb 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -97,7 +97,7 @@ impl PyDatabase {
     }
 
     fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
-        if let Some(table) = wait_for_future(py, self.database.table(name))? {
+        if let Some(table) = wait_for_future(py, self.database.table(name))?? {
             Ok(PyTable::new(table))
         } else {
             Err(PyDataFusionError::Common(format!(
diff --git a/src/context.rs b/src/context.rs
index cc3d8e8e9..b0af566e4 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, PyDataFusionResult};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::sort_expr::PySortExpr;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
@@ -375,7 +375,7 @@ impl PySessionContext {
             None => {
                 let state = self.ctx.state();
                 let schema = options.infer_schema(&state, &table_path);
-                wait_for_future(py, schema)?
+                wait_for_future(py, schema)??
             }
         };
         let config = ListingTableConfig::new(table_path)
@@ -400,7 +400,7 @@ impl PySessionContext {
 
     /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
     pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
         let result = self.ctx.sql(query);
-        let df = wait_for_future(py, result)?;
+        let df = wait_for_future(py, result)??;
         Ok(PyDataFrame::new(df))
     }
 
@@ -417,7 +417,7 @@ impl PySessionContext {
             SQLOptions::new()
         };
         let result = self.ctx.sql_with_options(query, options);
-        let df = wait_for_future(py, result)?;
+        let df = wait_for_future(py, result)??;
         Ok(PyDataFrame::new(df))
     }
 
@@ -451,7 +451,7 @@ impl PySessionContext {
 
         self.ctx.register_table(&*table_name, Arc::new(table))?;
 
-        let table = wait_for_future(py, self._table(&table_name))?;
+        let table = wait_for_future(py, self._table(&table_name))??;
 
         let df = PyDataFrame::new(table);
         Ok(df)
@@ -650,7 +650,7 @@ impl PySessionContext {
             .collect();
 
         let result = self.ctx.register_parquet(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
         Ok(())
     }
 
@@ -693,11 +693,11 @@ impl PySessionContext {
         if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
-            wait_for_future(py, result)?;
+            wait_for_future(py, result)??;
         } else {
             let path = path.extract::<String>()?;
             let result = self.ctx.register_csv(name, &path, options);
-            wait_for_future(py, result)?;
+            wait_for_future(py, result)??;
         }
 
         Ok(())
@@ -734,7 +734,7 @@ impl PySessionContext {
         options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_json(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
 
         Ok(())
     }
@@ -764,7 +764,7 @@ impl PySessionContext {
         options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_avro(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
 
         Ok(())
     }
@@ -825,9 +825,19 @@ impl PySessionContext {
     }
 
     pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
-        let x = wait_for_future(py, self.ctx.table(name))
+        let res = wait_for_future(py, self.ctx.table(name))
             .map_err(|e| PyKeyError::new_err(e.to_string()))?;
-        Ok(PyDataFrame::new(x))
+        match res {
+            Ok(df) => Ok(PyDataFrame::new(df)),
+            Err(e) => {
+                if let datafusion::error::DataFusionError::Plan(msg) = &e {
+                    if msg.contains("No table named") {
+                        return Err(PyKeyError::new_err(msg.to_string()));
+                    }
+                }
+                Err(py_datafusion_err(e))
+            }
+        }
     }
 
     pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
@@ -865,10 +875,10 @@ impl PySessionContext {
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
             let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result)?
+            wait_for_future(py, result)??
         } else {
             let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result)?
+            wait_for_future(py, result)??
         };
         Ok(PyDataFrame::new(df))
     }
@@ -915,12 +925,12 @@ impl PySessionContext {
             let paths = path.extract::<Vec<String>>()?;
             let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
             let result = self.ctx.read_csv(paths, options);
-            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            let df = PyDataFrame::new(wait_for_future(py, result)??);
             Ok(df)
         } else {
             let path = path.extract::<String>()?;
             let result = self.ctx.read_csv(path, options);
-            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            let df = PyDataFrame::new(wait_for_future(py, result)??);
             Ok(df)
         }
     }
@@ -958,7 +968,7 @@ impl PySessionContext {
             .collect();
 
         let result = self.ctx.read_parquet(path, options);
-        let df = PyDataFrame::new(wait_for_future(py, result)?);
+        let df = PyDataFrame::new(wait_for_future(py, result)??);
         Ok(df)
     }
 
@@ -978,10 +988,10 @@ impl PySessionContext {
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
             let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future)?
+            wait_for_future(py, read_future)??
         } else {
             let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future)?
+            wait_for_future(py, read_future)??
         };
         Ok(PyDataFrame::new(df))
     }
@@ -1021,8 +1031,8 @@ impl PySessionContext {
         let plan = plan.plan.clone();
         let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
             rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
-        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
-        Ok(PyRecordBatchStream::new(stream?))
+        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+        Ok(PyRecordBatchStream::new(stream))
     }
 }
diff --git a/src/dataframe.rs b/src/dataframe.rs
index ece8c4e0f..7711a0782 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -43,7 +43,7 @@ use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;
 
 use crate::catalog::PyTable;
-use crate::errors::{py_datafusion_err, PyDataFusionError};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
@@ -233,7 +233,7 @@ impl PyDataFrame {
         let (batches, has_more) = wait_for_future(
             py,
             collect_record_batches_to_display(self.df.as_ref().clone(), config),
-        )?;
+        )??;
         if batches.is_empty() {
             // This should not be reached, but do it for safety since we index into the vector below
             return Ok("No data to display".to_string());
@@ -256,7 +256,7 @@ impl PyDataFrame {
         let (batches, has_more) = wait_for_future(
             py,
             collect_record_batches_to_display(self.df.as_ref().clone(), config),
-        )?;
+        )??;
         if batches.is_empty() {
             // This should not be reached, but do it for safety since we index into the vector below
             return Ok("No data to display".to_string());
@@ -288,7 +288,7 @@ impl PyDataFrame {
 
     /// Calculate summary statistics for a DataFrame
     fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
         let df = self.df.as_ref().clone();
-        let stat_df = wait_for_future(py, df.describe())?;
+        let stat_df = wait_for_future(py, df.describe())??;
         Ok(Self::new(stat_df))
     }
 
@@ -391,7 +391,7 @@ impl PyDataFrame {
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
     fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect())
+        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
             .map_err(PyDataFusionError::from)?;
         // cannot use PyResult<Vec<RecordBatch>> return type due to
         // https://github.com/PyO3/pyo3/issues/1813
@@ -400,14 +400,14 @@ impl PyDataFrame {
 
     /// Cache DataFrame.
     fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
-        let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
+        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
         Ok(Self::new(df))
     }
 
     /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
     /// maintaining the input partitioning.
     fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
+        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
             .map_err(PyDataFusionError::from)?;
 
         batches
@@ -511,7 +511,7 @@ impl PyDataFrame {
 
     /// Get the execution plan for this `DataFrame`
     fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
-        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
+        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
         Ok(plan.into())
     }
 
@@ -624,7 +624,7 @@ impl PyDataFrame {
                 DataFrameWriteOptions::new(),
                 Some(csv_options),
             ),
-        )?;
+        )??;
         Ok(())
     }
 
@@ -685,7 +685,7 @@ impl PyDataFrame {
                 DataFrameWriteOptions::new(),
                 Option::from(options),
             ),
-        )?;
+        )??;
         Ok(())
     }
 
@@ -697,7 +697,7 @@ impl PyDataFrame {
                 .as_ref()
                 .clone()
                 .write_json(path, DataFrameWriteOptions::new(), None),
-        )?;
+        )??;
         Ok(())
     }
 
@@ -720,7 +720,7 @@ impl PyDataFrame {
         py: Python<'py>,
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
-        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
+        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
         let mut schema: Schema = self.df.schema().to_owned().into();
 
         if let Some(schema_capsule) = requested_schema {
@@ -753,8 +753,8 @@ impl PyDataFrame {
         let df = self.df.as_ref().clone();
         let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
             rt.spawn(async move { df.execute_stream().await });
-        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
-        Ok(PyRecordBatchStream::new(stream?))
+        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+        Ok(PyRecordBatchStream::new(stream))
     }
 
     fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
@@ -763,14 +763,11 @@ impl PyDataFrame {
         let df = self.df.as_ref().clone();
         let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
             rt.spawn(async move { df.execute_stream_partitioned().await });
-        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
+        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
+            .map_err(py_datafusion_err)?
+            .map_err(py_datafusion_err)?;
 
-        match stream {
-            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
-            _ => Err(PyValueError::new_err(
-                "Unable to execute stream partitioned",
-            )),
-        }
+        Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
     }
 
     /// Convert to pandas dataframe with pyarrow
@@ -815,7 +812,7 @@ impl PyDataFrame {
 
     // Executes this DataFrame to get the total number of rows.
     fn count(&self, py: Python) -> PyDataFusionResult<usize> {
-        Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
+        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
     }
 
     /// Fill null values with a specified value for specific columns
@@ -841,7 +838,7 @@ impl PyDataFrame {
 /// Print DataFrame
 fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
     // Get string representation of record batches
-    let batches = wait_for_future(py, df.collect())?;
+    let batches = wait_for_future(py, df.collect())??;
     let batches_as_string = pretty::pretty_format_batches(&batches);
     let result = match batches_as_string {
         Ok(batch) => format!("DataFrame()\n{batch}"),
diff --git a/src/record_batch.rs b/src/record_batch.rs
index ec61c263f..a85f05423 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -63,7 +63,7 @@ impl PyRecordBatchStream {
     fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
         let stream = self.stream.clone();
-        wait_for_future(py, next_stream(stream, true))
+        wait_for_future(py, next_stream(stream, true))?
     }
 
     fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
diff --git a/src/substrait.rs b/src/substrait.rs
index 1fefc0bbd..4da3738fb 100644
--- a/src/substrait.rs
+++ b/src/substrait.rs
@@ -72,7 +72,7 @@ impl PySubstraitSerializer {
         path: &str,
         py: Python,
     ) -> PyDataFusionResult<()> {
-        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))?;
+        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))??;
         Ok(())
     }
 
@@ -94,19 +94,20 @@ impl PySubstraitSerializer {
         ctx: PySessionContext,
         py: Python,
     ) -> PyDataFusionResult<PyObject> {
-        let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))?;
+        let proto_bytes: Vec<u8> =
+            wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))??;
         Ok(PyBytes::new(py, &proto_bytes).into())
     }
 
     #[staticmethod]
     pub fn deserialize(path: &str, py: Python) -> PyDataFusionResult<PyPlan> {
-        let plan = wait_for_future(py, serializer::deserialize(path))?;
+        let plan = wait_for_future(py, serializer::deserialize(path))??;
         Ok(PyPlan { plan: *plan })
     }
 
     #[staticmethod]
     pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyDataFusionResult<PyPlan> {
-        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))?;
+        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))??;
         Ok(PyPlan { plan: *plan })
     }
 }
@@ -143,7 +144,7 @@ impl PySubstraitConsumer {
     ) -> PyDataFusionResult<PyLogicalPlan> {
         let session_state = ctx.ctx.state();
         let result = consumer::from_substrait_plan(&session_state, &plan.plan);
-        let logical_plan = wait_for_future(py, result)?;
+        let logical_plan = wait_for_future(py, result)??;
         Ok(PyLogicalPlan::new(logical_plan))
     }
 }
diff --git a/src/utils.rs b/src/utils.rs
index 0a24ab254..90d654385 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -15,19 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
-use crate::common::data_type::PyScalarValue;
-use crate::errors::{PyDataFusionError, PyDataFusionResult};
-use crate::TokioRuntime;
-use datafusion::common::ScalarValue;
-use datafusion::execution::context::SessionContext;
-use datafusion::logical_expr::Volatility;
-use pyo3::exceptions::PyValueError;
+use crate::{
+    common::data_type::PyScalarValue,
+    errors::{PyDataFusionError, PyDataFusionResult},
+    TokioRuntime,
+};
+use datafusion::{
+    common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
+};
 use pyo3::prelude::*;
-use pyo3::types::PyCapsule;
-use std::future::Future;
-use std::sync::OnceLock;
-use tokio::runtime::Runtime;
-
+use pyo3::{exceptions::PyValueError, types::PyCapsule};
+use std::{future::Future, sync::OnceLock, time::Duration};
+use tokio::{runtime::Runtime, time::sleep};
 /// Utility to get the Tokio Runtime from Python
 #[inline]
 pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
@@ -47,14 +46,31 @@ pub(crate) fn get_global_ctx() -> &'static SessionContext {
     CTX.get_or_init(SessionContext::new)
 }
 
-/// Utility to collect rust futures with GIL released
-pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
+/// Utility to collect rust futures with GIL released and respond to
+/// Python interrupts such as ``KeyboardInterrupt``. If a signal is
+/// received while the future is running, the future is aborted and the
+/// corresponding Python exception is raised.
+pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
 where
     F: Future + Send,
     F::Output: Send,
 {
     let runtime: &Runtime = &get_tokio_runtime().0;
-    py.allow_threads(|| runtime.block_on(f))
+    const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
+
+    py.allow_threads(|| {
+        runtime.block_on(async {
+            tokio::pin!(fut);
+            loop {
+                tokio::select! {
+                    res = &mut fut => break Ok(res),
+                    _ = sleep(INTERVAL_CHECK_SIGNALS) => {
+                        Python::with_gil(|py| py.check_signals())?;
+                    }
+                }
+            }
+        })
+    })
 }
 
 pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
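
Note for reviewers: `wait_for_future` now wraps the awaited future's output in a `PyResult`, which is why most call sites above gained a second `?`. A minimal sketch of the resulting pattern, using the `sql()` wrapper from src/context.rs in this patch (the sketch itself is illustrative and not part of the diff):

    // Sketch only: the double unwrap at a typical call site after this change.
    // The outer `?` surfaces a PyErr raised by py.check_signals() in wait_for_future
    // (e.g. a KeyboardInterrupt delivered while the query runs); the inner `?`
    // surfaces the DataFusionError returned by the awaited future itself.
    pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
        let result = self.ctx.sql(query);
        // wait_for_future(py, result) : PyResult<Result<DataFrame, DataFusionError>>
        let df = wait_for_future(py, result)??;
        Ok(PyDataFrame::new(df))
    }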