Skip to content

Commit d80f21d

Browse files
committed
feat(rust/core)!: define cancellation in a sensible way
Closes #3454.
1 parent 16ceb85 commit d80f21d

File tree

9 files changed

+173
-82
lines changed

9 files changed

+173
-82
lines changed

rust/core/src/sync.rs

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@ pub trait Optionable {
4444
fn get_option_double(&self, key: Self::Option) -> Result<f64>;
4545
}
4646

47+
/// A handle to cancel an in-progress operation on a connection.
48+
///
49+
/// This is a separated handle because otherwise it would be impossible to
50+
/// call a `cancel` method on a connection or statement itself.
51+
pub trait CancelHandle: Send {
52+
/// Cancel the in-progress operation on a connection.
53+
fn try_cancel(&self) -> Result<()>;
54+
}
55+
56+
/// A cancellation handle that does nothing (because cancellation is unsupported).
57+
pub struct NoOpCancellationHandle;
58+
59+
impl CancelHandle for NoOpCancellationHandle {
60+
fn try_cancel(&self) -> Result<()> {
61+
Ok(())
62+
}
63+
}
64+
4765
/// A handle to an ADBC driver.
4866
pub trait Driver {
4967
type DatabaseType: Database;
@@ -76,6 +94,11 @@ pub trait Database: Optionable<Option = OptionDatabase> {
7694
&self,
7795
opts: impl IntoIterator<Item = (options::OptionConnection, OptionValue)>,
7896
) -> Result<Self::ConnectionType>;
97+
98+
/// Get a handle to cancel operations on this database.
99+
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
100+
Box::new(NoOpCancellationHandle {})
101+
}
79102
}
80103

81104
/// A handle to an ADBC connection.
@@ -94,8 +117,10 @@ pub trait Connection: Optionable<Option = OptionConnection> {
94117
/// Allocate and initialize a new statement.
95118
fn new_statement(&mut self) -> Result<Self::StatementType>;
96119

97-
/// Cancel the in-progress operation on a connection.
98-
fn cancel(&mut self) -> Result<()>;
120+
/// Get a handle to cancel operations on this connection.
121+
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
122+
Box::new(NoOpCancellationHandle {})
123+
}
99124

100125
/// Get metadata about the database/driver.
101126
///
@@ -455,13 +480,15 @@ pub trait Statement: Optionable<Option = OptionStatement> {
455480
/// expected to be executed repeatedly, call [Statement::prepare] first.
456481
fn set_substrait_plan(&mut self, plan: impl AsRef<[u8]>) -> Result<()>;
457482

458-
/// Cancel execution of an in-progress query.
483+
/// Get a handle to cancel operations on this statement.
459484
///
460-
/// This can be called during [Statement::execute] (or similar), or while
461-
/// consuming a result set returned from such.
485+
/// The resulting handle can be called during [Statement::execute] (or
486+
/// similar), or while consuming a result set returned from such.
462487
///
463488
/// # Since
464489
///
465490
/// ADBC API revision 1.1.0
466-
fn cancel(&mut self) -> Result<()>;
491+
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
492+
Box::new(NoOpCancellationHandle {})
493+
}
467494
}

rust/driver/datafusion/src/lib.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -735,10 +735,6 @@ impl Connection for DataFusionConnection {
735735
})
736736
}
737737

738-
fn cancel(&mut self) -> adbc_core::error::Result<()> {
739-
todo!()
740-
}
741-
742738
fn get_info(
743739
&self,
744740
codes: Option<std::collections::HashSet<adbc_core::options::InfoCode>>,
@@ -984,10 +980,6 @@ impl Statement for DataFusionStatement {
984980
self.substrait_plan = Some(Plan::decode(plan.as_ref()).unwrap());
985981
Ok(())
986982
}
987-
988-
fn cancel(&mut self) -> adbc_core::error::Result<()> {
989-
todo!()
990-
}
991983
}
992984

993985
#[cfg(feature = "ffi")]

rust/driver/dummy/src/lib.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -294,16 +294,24 @@ impl Connection for DummyConnection {
294294
Ok(Self::StatementType::default())
295295
}
296296

297-
// This method is used to test that errors round-trip correctly.
298-
fn cancel(&mut self) -> Result<()> {
299-
let mut error = Error::with_message_and_status("message", Status::Cancelled);
300-
error.vendor_code = constants::ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA;
301-
error.sqlstate = [1, 2, 3, 4, 5];
302-
error.details = Some(vec![
303-
("key1".into(), b"AAA".into()),
304-
("key2".into(), b"ZZZZZ".into()),
305-
]);
306-
Err(error)
297+
/// This method is used to test that errors round-trip correctly.
298+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
299+
struct CancelHandle;
300+
301+
impl adbc_core::CancelHandle for CancelHandle {
302+
fn try_cancel(&self) -> Result<()> {
303+
let mut error = Error::with_message_and_status("message", Status::Cancelled);
304+
error.vendor_code = constants::ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA;
305+
error.sqlstate = [1, 2, 3, 4, 5];
306+
error.details = Some(vec![
307+
("key1".into(), b"AAA".into()),
308+
("key2".into(), b"ZZZZZ".into()),
309+
]);
310+
Err(error)
311+
}
312+
}
313+
314+
Box::new(CancelHandle)
307315
}
308316

309317
fn commit(&mut self) -> Result<()> {
@@ -854,10 +862,6 @@ impl Statement for DummyStatement {
854862
Ok(())
855863
}
856864

857-
fn cancel(&mut self) -> Result<()> {
858-
Ok(())
859-
}
860-
861865
fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
862866
maybe_panic("StatementExecuteQuery");
863867
let batch = get_table_data();

rust/driver/dummy/tests/driver_exporter_dummy.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,14 @@ fn test_connection_get_info() {
366366

367367
#[test]
368368
fn test_connection_cancel() {
369-
let (_, _, mut exported_connection, _) = get_exported();
370-
let (_, _, mut native_connection, _) = get_native();
369+
let (_, _, exported_connection, _) = get_exported();
370+
let (_, _, native_connection, _) = get_native();
371+
372+
let exported_handle = exported_connection.get_cancel_handle();
373+
let native_handle = native_connection.get_cancel_handle();
371374

372-
let exported_error = exported_connection.cancel().unwrap_err();
373-
let native_error = native_connection.cancel().unwrap_err();
375+
let exported_error = exported_handle.try_cancel().unwrap_err();
376+
let native_error = native_handle.try_cancel().unwrap_err();
374377

375378
assert_eq!(exported_error, native_error);
376379
}
@@ -569,11 +572,14 @@ fn test_statement_bind_stream() {
569572

570573
#[test]
571574
fn test_statement_cancel() {
572-
let (_, _, _, mut exported_statement) = get_exported();
573-
let (_, _, _, mut native_statement) = get_native();
575+
let (_, _, _, exported_statement) = get_exported();
576+
let (_, _, _, native_statement) = get_native();
577+
578+
let exported_handle = exported_statement.get_cancel_handle();
579+
let native_handle = native_statement.get_cancel_handle();
574580

575-
exported_statement.cancel().unwrap();
576-
native_statement.cancel().unwrap();
581+
exported_handle.try_cancel().unwrap();
582+
native_handle.try_cancel().unwrap();
577583
}
578584

579585
#[test]

rust/driver/snowflake/src/connection.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ impl adbc_core::Connection for Connection {
7070
self.0.new_statement().map(Statement)
7171
}
7272

73-
fn cancel(&mut self) -> Result<()> {
74-
self.0.cancel()
73+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
74+
self.0.get_cancel_handle()
7575
}
7676

7777
fn get_info(

rust/driver/snowflake/src/statement.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ impl adbc_core::Statement for Statement {
9696
self.0.set_substrait_plan(plan)
9797
}
9898

99-
fn cancel(&mut self) -> Result<()> {
100-
self.0.cancel()
99+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
100+
self.0.get_cancel_handle()
101101
}
102102
}

rust/driver_manager/src/lib.rs

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,31 @@ pub struct ManagedConnection {
10271027
inner: Arc<ManagedConnectionInner>,
10281028
}
10291029

1030+
struct ConnectionCancelHandle {
1031+
inner: std::sync::Weak<ManagedConnectionInner>,
1032+
}
1033+
1034+
impl adbc_core::CancelHandle for ConnectionCancelHandle {
1035+
fn try_cancel(&self) -> Result<()> {
1036+
if let Some(inner) = self.inner.upgrade() {
1037+
if let AdbcVersion::V100 = inner.database.driver.version {
1038+
return Err(Error::with_message_and_status(
1039+
ERR_CANCEL_UNSUPPORTED,
1040+
Status::NotImplemented,
1041+
));
1042+
}
1043+
let driver = &inner.database.driver.driver;
1044+
let mut connection = inner.connection.lock().unwrap();
1045+
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1046+
let method = driver_method!(driver, ConnectionCancel);
1047+
let status = unsafe { method(connection.deref_mut(), &mut error) };
1048+
check_status(status, error)
1049+
} else {
1050+
Ok(())
1051+
}
1052+
}
1053+
}
1054+
10301055
impl ManagedConnection {
10311056
fn ffi_driver(&self) -> &adbc_ffi::FFI_AdbcDriver {
10321057
&self.inner.database.driver.driver
@@ -1125,19 +1150,10 @@ impl Connection for ManagedConnection {
11251150
Ok(Self::StatementType { inner })
11261151
}
11271152

1128-
fn cancel(&mut self) -> Result<()> {
1129-
if let AdbcVersion::V100 = self.driver_version() {
1130-
return Err(Error::with_message_and_status(
1131-
ERR_CANCEL_UNSUPPORTED,
1132-
Status::NotImplemented,
1133-
));
1134-
}
1135-
let driver = self.ffi_driver();
1136-
let mut connection = self.inner.connection.lock().unwrap();
1137-
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1138-
let method = driver_method!(driver, ConnectionCancel);
1139-
let status = unsafe { method(connection.deref_mut(), &mut error) };
1140-
check_status(status, error)
1153+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
1154+
Box::new(ConnectionCancelHandle {
1155+
inner: Arc::downgrade(&self.inner),
1156+
})
11411157
}
11421158

11431159
fn commit(&mut self) -> Result<()> {
@@ -1401,6 +1417,31 @@ impl ManagedStatement {
14011417
}
14021418
}
14031419

1420+
struct StatementCancelHandle {
1421+
inner: std::sync::Weak<ManagedStatementInner>,
1422+
}
1423+
1424+
impl adbc_core::CancelHandle for StatementCancelHandle {
1425+
fn try_cancel(&self) -> Result<()> {
1426+
if let Some(inner) = self.inner.upgrade() {
1427+
if let AdbcVersion::V100 = inner.connection.database.driver.version {
1428+
return Err(Error::with_message_and_status(
1429+
ERR_CANCEL_UNSUPPORTED,
1430+
Status::NotImplemented,
1431+
));
1432+
}
1433+
let driver = &inner.connection.database.driver.driver;
1434+
let mut statement = inner.statement.lock().unwrap();
1435+
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1436+
let method = driver_method!(driver, StatementCancel);
1437+
let status = unsafe { method(statement.deref_mut(), &mut error) };
1438+
check_status(status, error)
1439+
} else {
1440+
Ok(())
1441+
}
1442+
}
1443+
}
1444+
14041445
impl Statement for ManagedStatement {
14051446
fn bind(&mut self, batch: RecordBatch) -> Result<()> {
14061447
let driver = self.ffi_driver();
@@ -1425,19 +1466,10 @@ impl Statement for ManagedStatement {
14251466
Ok(())
14261467
}
14271468

1428-
fn cancel(&mut self) -> Result<()> {
1429-
if let AdbcVersion::V100 = self.driver_version() {
1430-
return Err(Error::with_message_and_status(
1431-
ERR_CANCEL_UNSUPPORTED,
1432-
Status::NotImplemented,
1433-
));
1434-
}
1435-
let driver = self.ffi_driver();
1436-
let mut statement = self.inner.statement.lock().unwrap();
1437-
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
1438-
let method = driver_method!(driver, StatementCancel);
1439-
let status = unsafe { method(statement.deref_mut(), &mut error) };
1440-
check_status(status, error)
1469+
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
1470+
Box::new(StatementCancelHandle {
1471+
inner: Arc::downgrade(&self.inner),
1472+
})
14411473
}
14421474

14431475
fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {

rust/driver_manager/tests/driver_manager_sqlite.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,10 @@ fn test_connection_get_option() {
128128
fn test_connection_cancel() {
129129
let mut driver = get_driver();
130130
let database = get_database(&mut driver);
131-
let mut connection = database.new_connection().unwrap();
131+
let connection = database.new_connection().unwrap();
132132

133-
let error = connection.cancel().unwrap_err();
133+
let handle = connection.get_cancel_handle();
134+
let error = handle.try_cancel().unwrap_err();
134135
assert_eq!(error.status, Status::NotImplemented);
135136
}
136137

@@ -285,9 +286,10 @@ fn test_statement_cancel() {
285286
let mut driver = get_driver();
286287
let database = get_database(&mut driver);
287288
let mut connection = database.new_connection().unwrap();
288-
let mut statement = connection.new_statement().unwrap();
289+
let statement = connection.new_statement().unwrap();
289290

290-
let error = statement.cancel().unwrap_err();
291+
let handle = statement.get_cancel_handle();
292+
let error = handle.try_cancel().unwrap_err();
291293
assert_eq!(error.status, Status::NotImplemented);
292294
}
293295

0 commit comments

Comments
 (0)