Skip to content

Commit aae758f

Browse files
authored
fix(cubesql): Query cancellation for simple query protocol (#5987)
1 parent 30548db commit aae758f

File tree

4 files changed

+108
-39
lines changed

4 files changed

+108
-39
lines changed

rust/cubesql/Cargo.lock

Lines changed: 30 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/cubesql/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pretty_assertions = "1.0.0"
6060
insta = "1.12"
6161
mysql_async = "0.29"
6262
portpicker = "0.1.1"
63-
tokio-postgres = { version = "0.7.6", features = ["with-chrono-0_4", "runtime"] }
63+
tokio-postgres = { version = "0.7.7", features = ["with-chrono-0_4", "runtime"] }
6464
rust_decimal = { version = "1.23", features = ["db-tokio-postgres"] }
6565
pg_interval = "0.4.1"
6666
criterion = { version = "0.4.0", features = ["html_reports"] }

rust/cubesql/cubesql/e2e/tests/postgres.rs

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,22 +264,59 @@ impl PostgresIntegrationTestSuite {
264264
}
265265
}
266266

267-
async fn test_cancel(&self) -> RunResult<()> {
268-
let cancel_token = self.client.cancel_token();
267+
async fn test_cancel_execute_prepared(&self) -> RunResult<()> {
268+
let client = PostgresIntegrationTestSuite::create_client(
269+
format!("host=127.0.0.1 port={} user=test password=test", self.port)
270+
.parse()
271+
.unwrap(),
272+
)
273+
.await;
274+
275+
let cancel_token = client.cancel_token();
269276
let cancel = async move {
270-
tokio::time::sleep(Duration::from_millis(1000)).await;
277+
sleep(Duration::from_millis(1000)).await;
271278

272279
cancel_token.cancel_query(NoTls).await
273280
};
274281

275282
// testing_blocking tables will neven finish. It's a special testing table
276-
let sleep = self
277-
.client
278-
.batch_execute("SELECT * FROM information_schema.testing_blocking");
283+
let sleep = client.batch_execute("SELECT * FROM information_schema.testing_blocking");
284+
285+
match join!(sleep, cancel) {
286+
(Err(ref e), Ok(())) if e.code() == Some(&SqlState::QUERY_CANCELED) => {}
287+
res => panic!(
288+
"unexpected return, prepared must be cancelled, actual: {:?}",
289+
res
290+
),
291+
};
292+
293+
Ok(())
294+
}
295+
296+
async fn test_cancel_simple_query(&self) -> RunResult<()> {
297+
let client = PostgresIntegrationTestSuite::create_client(
298+
format!("host=127.0.0.1 port={} user=test password=test", self.port)
299+
.parse()
300+
.unwrap(),
301+
)
302+
.await;
303+
304+
let cancel_token = client.cancel_token();
305+
let cancel = async move {
306+
sleep(Duration::from_millis(1000)).await;
307+
308+
cancel_token.cancel_query(NoTls).await
309+
};
310+
311+
// testing_blocking tables will neven finish. It's a special testing table
312+
let sleep = client.simple_query("SELECT * FROM information_schema.testing_blocking");
279313

280314
match join!(sleep, cancel) {
281315
(Err(ref e), Ok(())) if e.code() == Some(&SqlState::QUERY_CANCELED) => {}
282-
t => panic!("unexpected return {:?}", t),
316+
(_, err) => panic!(
317+
"unexpected return, simple query must be cancelled, actual: {:?}",
318+
err
319+
),
283320
};
284321

285322
Ok(())
@@ -830,7 +867,8 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
830867
}
831868

832869
async fn run(&mut self) -> RunResult<()> {
833-
self.test_cancel().await?;
870+
self.test_cancel_simple_query().await?;
871+
self.test_cancel_execute_prepared().await?;
834872
self.test_prepare().await?;
835873
self.test_extended_error().await?;
836874
self.test_prepare_empty_query().await?;

rust/cubesql/cubesql/src/sql/postgres/shim.rs

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -931,20 +931,27 @@ impl AsyncPostgresShim {
931931

932932
tokio::select! {
933933
_ = cancel.cancelled() => {
934-
if let Some(qtrace) = qtrace {
935-
qtrace.set_statement_error_message("Execution cancelled by user");
936-
}
937934
self.session.state.end_query();
938935

939936
// We don't return error, because query can contains multiple statements
940937
// then cancel request will cancel only one query
941938
self.write(protocol::ErrorResponse::query_canceled()).await?;
939+
if let Some(qtrace) = qtrace {
940+
qtrace.set_statement_error_message("Execution cancelled by user");
941+
}
942942

943943
Ok(())
944944
},
945945
res = self.process_simple_query(stmt, meta, cancel.clone(), qtrace) => {
946946
self.session.state.end_query();
947947

948+
if cancel.is_cancelled() {
949+
self.write(protocol::ErrorResponse::query_canceled()).await?;
950+
if let Some(qtrace) = qtrace {
951+
qtrace.set_statement_error_message("Execution cancelled by user");
952+
}
953+
}
954+
948955
res
949956
},
950957
}
@@ -1362,25 +1369,30 @@ impl AsyncPostgresShim {
13621369
cancel: CancellationToken,
13631370
) -> Result<(), ConnectionError> {
13641371
let mut writer = BatchWriter::new(portal.get_format());
1365-
let completion = portal.execute(&mut writer, max_rows).await?;
13661372

1367-
if cancel.is_cancelled() {
1368-
return Ok(());
1369-
}
1370-
1371-
// Special handling for special queries, such as DISCARD ALL.
1372-
if let Some(description) = portal.get_description()? {
1373-
match description.len() {
1374-
0 => self.write(protocol::NoData::new()).await?,
1375-
_ => self.write(description).await?,
1376-
};
1377-
}
1373+
tokio::select! {
1374+
_ = cancel.cancelled() => {
1375+
// TODO: Cancellation handling via errors?
1376+
return Ok(());
1377+
},
1378+
res = portal.execute(&mut writer, max_rows) => {
1379+
let completion = res?;
1380+
1381+
// Special handling for special queries, such as DISCARD ALL.
1382+
if let Some(description) = portal.get_description()? {
1383+
match description.len() {
1384+
0 => self.write(protocol::NoData::new()).await?,
1385+
_ => self.write(description).await?,
1386+
};
1387+
}
13781388

1379-
if writer.has_data() {
1380-
buffer::write_direct(&mut self.socket, writer).await?;
1381-
};
1389+
if writer.has_data() {
1390+
buffer::write_direct(&mut self.socket, writer).await?;
1391+
};
13821392

1383-
self.write_completion(completion).await
1393+
self.write_completion(completion).await
1394+
}
1395+
}
13841396
}
13851397

13861398
/// Pipeline of Execution

0 commit comments

Comments
 (0)