Skip to content

Commit ab63c6c

Browse files
authored
Merge pull request #1 from kosiew/pr-1311
Improve KeyboardInterrupt handling during Arrow C stream reads in Python bindings
2 parents 908d6c8 + 5615d83 commit ab63c6c

File tree

2 files changed

+56
-26
lines changed

2 files changed

+56
-26
lines changed

python/tests/test_dataframe.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3286,50 +3286,69 @@ def test_arrow_c_stream_interrupted():
32863286

32873287
reader = pa.RecordBatchReader.from_stream(df)
32883288

3289-
interrupted = False
3290-
interrupt_error = None
3291-
query_started = threading.Event()
3289+
read_started = threading.Event()
3290+
read_exception = []
3291+
read_thread_id = None
32923292
max_wait_time = 5.0
32933293

32943294
def trigger_interrupt():
3295-
start_time = time.time()
3296-
while not query_started.is_set():
3297-
time.sleep(0.1)
3298-
if time.time() - start_time > max_wait_time:
3299-
msg = f"Query did not start within {max_wait_time} seconds"
3300-
raise RuntimeError(msg)
3295+
"""Wait for read to start, then raise KeyboardInterrupt in read thread."""
3296+
if not read_started.wait(timeout=max_wait_time):
3297+
msg = f"Read operation did not start within {max_wait_time} seconds"
3298+
raise RuntimeError(msg)
33013299

3302-
thread_id = threading.main_thread().ident
3303-
if thread_id is None:
3304-
msg = "Cannot get main thread ID"
3300+
if read_thread_id is None:
3301+
msg = "Cannot get read thread ID"
33053302
raise RuntimeError(msg)
33063303

33073304
exception = ctypes.py_object(KeyboardInterrupt)
33083305
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
3309-
ctypes.c_long(thread_id), exception
3306+
ctypes.c_long(read_thread_id), exception
33103307
)
33113308
if res != 1:
33123309
ctypes.pythonapi.PyThreadState_SetAsyncExc(
3313-
ctypes.c_long(thread_id), ctypes.py_object(0)
3310+
ctypes.c_long(read_thread_id), ctypes.py_object(0)
33143311
)
3315-
msg = "Failed to raise KeyboardInterrupt in main thread"
3312+
msg = "Failed to raise KeyboardInterrupt in read thread"
33163313
raise RuntimeError(msg)
33173314

3315+
def read_stream():
3316+
"""Consume the reader, which should be interrupted."""
3317+
nonlocal read_thread_id
3318+
read_thread_id = threading.get_ident()
3319+
try:
3320+
read_started.set()
3321+
result = reader.read_all()
3322+
# If we get here, the read completed without interruption
3323+
read_exception.append(RuntimeError("Read completed without interruption"))
3324+
except KeyboardInterrupt:
3325+
read_exception.append(KeyboardInterrupt)
3326+
except Exception as e:
3327+
read_exception.append(e)
3328+
3329+
read_thread = threading.Thread(target=read_stream)
3330+
read_thread.daemon = True
3331+
read_thread.start()
3332+
33183333
interrupt_thread = threading.Thread(target=trigger_interrupt)
33193334
interrupt_thread.daemon = True
33203335
interrupt_thread.start()
33213336

3322-
try:
3323-
query_started.set()
3324-
# consume the reader which should block and be interrupted
3325-
reader.read_all()
3326-
except KeyboardInterrupt:
3327-
interrupted = True
3328-
except Exception as e: # pragma: no cover - unexpected errors
3329-
interrupt_error = e
3337+
# Wait for the read operation with a timeout
3338+
read_thread.join(timeout=10.0)
33303339

3331-
if not interrupted:
3332-
pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
3340+
if read_thread.is_alive():
3341+
pytest.fail("Stream read operation timed out after 10 seconds")
3342+
3343+
# Verify we got the expected KeyboardInterrupt
3344+
if not read_exception:
3345+
pytest.fail("No exception was raised during stream read")
3346+
3347+
# Check if we got KeyboardInterrupt directly or wrapped in another exception
3348+
exception = read_exception[0]
3349+
if not (isinstance(exception, type(KeyboardInterrupt)) or
3350+
"KeyboardInterrupt" in str(exception)):
3351+
pytest.fail(f"Expected KeyboardInterrupt, got: {exception}")
33333352

33343353
interrupt_thread.join(timeout=1.0)
33353354

src/utils.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::ffi::CString;
1819
use std::future::Future;
1920
use std::sync::{Arc, OnceLock};
2021
use std::time::Duration;
@@ -84,7 +85,17 @@ where
8485
tokio::select! {
8586
res = &mut fut => break Ok(res),
8687
_ = sleep(INTERVAL_CHECK_SIGNALS) => {
87-
Python::attach(|py| py.check_signals())?;
88+
Python::attach(|py| {
89+
// Execute a no-op Python statement to trigger signal processing.
90+
// This is necessary because py.check_signals() alone doesn't
91+
// actually check for signals - it only raises an exception if
92+
// a signal was already set during a previous Python API call.
93+
// Running even trivial Python code forces the interpreter to
94+
// process any pending signals (like KeyboardInterrupt).
95+
let code = CString::new("pass").unwrap();
96+
py.run(code.as_c_str(), None, None)?;
97+
py.check_signals()
98+
})?;
8899
}
89100
}
90101
}

0 commit comments

Comments
 (0)