
Commit f485318

feat: flight sql support send progress into client (#10908)
1 parent: 7cbb8fd

10 files changed (+135, -26 lines)


src/query/catalog/src/table_context.rs

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ pub trait TableContext: Send + Sync {
     /// This method builds a `dyn Table`, which provides table specific io methods the plan needs.
     fn build_table_from_source_plan(&self, plan: &DataSourcePlan) -> Result<Arc<dyn Table>>;
 
+    fn incr_total_scan_value(&self, value: ProgressValues);
+    fn get_total_scan_value(&self) -> ProgressValues;
+
     fn get_scan_progress(&self) -> Arc<Progress>;
     fn get_scan_progress_value(&self) -> ProgressValues;
     fn get_write_progress(&self) -> Arc<Progress>;

src/query/service/src/servers/flight_sql/flight_sql_service/query.rs

Lines changed: 89 additions & 21 deletions
@@ -12,14 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::sync::atomic::AtomicBool;
+use std::sync::atomic::Ordering;
 use std::sync::Arc;
 
 use arrow_flight::FlightData;
 use arrow_flight::SchemaAsIpc;
 use arrow_ipc::writer;
 use arrow_ipc::writer::IpcWriteOptions;
 use arrow_schema::Schema as ArrowSchema;
-use async_stream::stream;
+use common_base::base::tokio;
 use common_exception::ErrorCode;
 use common_exception::Result;
 use common_expression::DataBlock;
@@ -28,7 +30,10 @@ use common_sql::plans::Plan;
 use common_sql::PlanExtras;
 use common_sql::Planner;
 use common_storages_fuse::TableContext;
-use futures_util::StreamExt;
+use futures::Stream;
+use futures::StreamExt;
+use serde::Deserialize;
+use serde::Serialize;
 use tonic::Status;
 
 use super::status;
@@ -37,6 +42,9 @@ use super::FlightSqlServiceImpl;
 use crate::interpreters::InterpreterFactory;
 use crate::sessions::Session;
 
+/// A app_metakey which indicates the data is a progress type
+static H_PROGRESS: u8 = 0x01;
+
 impl FlightSqlServiceImpl {
     pub(crate) fn schema_to_flight_data(data_schema: DataSchema) -> FlightData {
         let arrow_schema = ArrowSchema::from(&data_schema);
@@ -98,49 +106,109 @@ impl FlightSqlServiceImpl {
         Ok(affected_rows as i64)
     }
 
-    #[async_backtrace::framed]
     pub(super) async fn execute_query(
         &self,
         session: Arc<Session>,
         plan: &Plan,
         plan_extras: &PlanExtras,
     ) -> Result<DoGetStream> {
+        let is_native_client = session.get_status().read().is_native_client;
+
         let context = session
             .create_query_context()
             .await
             .map_err(|e| status!("Could not create_query_context", e))?;
 
         context.attach_query_str(plan.to_string(), plan_extras.stament.to_mask_sql());
         let interpreter = InterpreterFactory::get(context.clone(), plan).await?;
+
         let data_schema = interpreter.schema();
-        let schema_flight_data = Self::schema_to_flight_data((*data_schema).clone());
+        let data_stream = interpreter.execute(context.clone()).await?;
 
-        let mut data_stream = interpreter.execute(context.clone()).await?;
+        let is_finished = Arc::new(AtomicBool::new(false));
+        let is_finished_clone = is_finished.clone();
+        let (sender, receiver) = tokio::sync::mpsc::channel(2);
+        let _ = sender
+            .send(Ok(Self::schema_to_flight_data((*data_schema).clone())))
+            .await;
+
+        let s1 = sender.clone();
+        tokio::spawn(async move {
+            let mut data_stream = data_stream;
 
-        let stream = stream! {
-            yield Ok(schema_flight_data);
             while let Some(block) = data_stream.next().await {
                 match block {
                     Ok(block) => {
-                        match Self::block_to_flight_data(block, &data_schema) {
-                            Ok(flight_data) => {
-                                yield Ok(flight_data)
-                            }
-                            Err(err) => {
-                                yield Err(status!("Could not convert batches", err))
-                            }
-                        }
+                        let res =
+                            match FlightSqlServiceImpl::block_to_flight_data(block, &data_schema) {
+                                Ok(flight_data) => Ok(flight_data),
+                                Err(err) => Err(status!("Could not convert batches", err)),
+                            };
+
+                        let _ = s1.send(res).await;
                     }
                     Err(err) => {
-                        yield Err(status!("Could not convert batches", err))
+                        let _ = s1
+                            .send(Err(status!("Could not convert batches", err)))
+                            .await;
                     }
-                };
+                }
             }
+            is_finished_clone.store(true, Ordering::SeqCst);
+        });
 
-            // to hold session ref until stream is all consumed
-            let _ = session.get_id();
-        };
+        if is_native_client {
+            tokio::spawn(async move {
+                let total_scan_value = context.get_total_scan_value();
+                let mut current_scan_value = context.get_scan_progress_value();
 
-        Ok(Box::pin(stream))
+                let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(20));
+                while !is_finished.load(Ordering::SeqCst) {
+                    interval.tick().await;
+
+                    let progress = context.get_scan_progress_value();
+                    if progress.rows == current_scan_value.rows {
+                        continue;
+                    }
+                    current_scan_value = progress;
+
+                    let progress = ProgressValue {
+                        total_rows: total_scan_value.rows,
+                        total_bytes: total_scan_value.bytes,
+
+                        read_rows: current_scan_value.rows,
+                        read_bytes: current_scan_value.bytes,
+                    };
+
+                    let progress = serde_json::to_vec(&progress).unwrap();
+                    let progress_flight_data = FlightData {
+                        app_metadata: vec![H_PROGRESS].into(),
+                        data_body: progress.into(),
+                        ..Default::default()
+                    };
+                    let _ = sender.send(Ok(progress_flight_data)).await;
+                }
+            });
+        }
+
+        fn receiver_to_stream<T>(
+            receiver: tokio::sync::mpsc::Receiver<T>,
+        ) -> impl Stream<Item = T> {
+            futures::stream::unfold(receiver, |mut receiver| async {
+                receiver.recv().await.map(|value| (value, receiver))
+            })
+        }
+
+        let st = receiver_to_stream(receiver);
+        Ok(Box::pin(st))
     }
 }
+
+#[derive(Serialize, Deserialize, Debug)]
+struct ProgressValue {
+    pub total_rows: usize,
+    pub total_bytes: usize,
+
+    pub read_rows: usize,
+    pub read_bytes: usize,
+}
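For reference, a client consuming this DoGet stream has to tell progress frames apart from ordinary record-batch frames. The sketch below is a hypothetical client-side helper, not part of this commit: it assumes the same arrow_flight and serde_json crates, mirrors the server's ProgressValue layout, and treats any frame whose app_metadata begins with 0x01 (H_PROGRESS) as a progress update.

use arrow_flight::FlightData;
use serde::Deserialize;

// Marker byte the server places in `app_metadata` for progress frames
// (mirrors `H_PROGRESS` in query.rs).
const H_PROGRESS: u8 = 0x01;

// Client-side mirror of the server's JSON `ProgressValue` payload (assumed layout).
#[derive(Deserialize, Debug)]
struct ProgressValue {
    total_rows: usize,
    total_bytes: usize,
    read_rows: usize,
    read_bytes: usize,
}

// Decode a progress frame; returns None for ordinary record-batch data.
fn try_decode_progress(data: &FlightData) -> Option<ProgressValue> {
    if data.app_metadata.first() == Some(&H_PROGRESS) {
        serde_json::from_slice(&data.data_body).ok()
    } else {
        None
    }
}

fn main() {
    // A stand-in frame shaped like the ones the server emits.
    let frame = FlightData {
        app_metadata: vec![H_PROGRESS].into(),
        data_body: br#"{"total_rows":100,"total_bytes":4096,"read_rows":10,"read_bytes":512}"#
            .to_vec()
            .into(),
        ..Default::default()
    };
    if let Some(p) = try_decode_progress(&frame) {
        println!("read {}/{} rows", p.read_rows, p.total_rows);
    }
}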

src/query/service/src/servers/flight_sql/flight_sql_service/service.rs

Lines changed: 5 additions & 0 deletions
@@ -120,6 +120,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
         Status,
     > {
         let remote_addr = request.remote_addr();
+
         let (user, password) = FlightSqlServiceImpl::get_user_password(request.metadata())
             .map_err(Status::invalid_argument)?;
         let session = FlightSqlServiceImpl::auth_user_password(user, password, remote_addr).await?;
@@ -136,6 +137,10 @@ impl FlightSqlService for FlightSqlServiceImpl {
         let metadata = MetadataValue::try_from(str)
             .map_err(|_| Status::internal("authorization not parsable"))?;
         resp.metadata_mut().insert("authorization", metadata);
+
+        session.get_status().write().is_native_client =
+            FlightSqlServiceImpl::get_header_value(request.metadata(), "bendsql").is_some();
+
         self.sessions.insert(token, session);
         Ok(resp)
     }
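The server only spawns the progress task for sessions flagged as native clients, and that flag is set purely by the presence of a "bendsql" metadata key on the handshake request. The following is a rough, hypothetical client-side sketch of such a handshake; the endpoint, credentials, and header value are placeholders and the actual bendsql request flow is not part of this commit.

use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::HandshakeRequest;
use futures::stream;
use tonic::metadata::MetadataValue;
use tonic::Request;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder endpoint; the Flight SQL port is deployment specific.
    let mut client = FlightServiceClient::connect("http://127.0.0.1:8900").await?;

    let mut request = Request::new(stream::iter(vec![HandshakeRequest::default()]));
    // Basic auth, as parsed by get_user_password ("root:" base64-encoded here as a placeholder).
    request
        .metadata_mut()
        .insert("authorization", MetadataValue::try_from("Basic cm9vdDo=")?);
    // Presence of this key is what flips `is_native_client` on the server side.
    request
        .metadata_mut()
        .insert("bendsql", MetadataValue::try_from("1")?);

    let _response = client.handshake(request).await?;
    Ok(())
}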

src/query/service/src/servers/flight_sql/flight_sql_service/session.rs

Lines changed: 10 additions & 5 deletions
@@ -54,13 +54,18 @@ impl FlightSqlServiceImpl {
         }
     }
 
+    pub(super) fn get_header_value(metadata: &MetadataMap, key: &str) -> Option<String> {
+        metadata
+            .get(key)
+            .and_then(|v| v.to_str().ok())
+            .map(|v| v.to_string())
+    }
+
     pub(super) fn get_user_password(metadata: &MetadataMap) -> Result<(String, String), String> {
         let basic = "Basic ";
-        let authorization = metadata
-            .get("authorization")
-            .ok_or("authorization field not present")?
-            .to_str()
-            .map_err(|e| format!("authorization not parsable: {}", e))?;
+        let authorization = Self::get_header_value(metadata, "authorization")
+            .ok_or_else(|| "authorization not parsable".to_string())?;
+
         if !authorization.starts_with(basic) {
             return Err(format!("Auth type not implemented: {authorization}"));
         }

src/query/service/src/sessions/query_ctx.rs

Lines changed: 8 additions & 0 deletions
@@ -236,6 +236,14 @@ impl TableContext for QueryContext {
         }
     }
 
+    fn incr_total_scan_value(&self, value: ProgressValues) {
+        self.shared.total_scan_values.as_ref().incr(&value);
+    }
+
+    fn get_total_scan_value(&self) -> ProgressValues {
+        self.shared.total_scan_values.as_ref().get_values()
+    }
+
     fn get_scan_progress(&self) -> Arc<Progress> {
         self.shared.scan_progress.clone()
     }

src/query/service/src/sessions/query_ctx_shared.rs

Lines changed: 3 additions & 0 deletions
@@ -55,6 +55,8 @@ type DatabaseAndTable = (String, String, String);
 /// FROM table_name_4;
 /// For each subquery, they will share a runtime, session, progress, init_query_id
 pub struct QueryContextShared {
+    /// total_scan_values for scan stats
+    pub(in crate::sessions) total_scan_values: Arc<Progress>,
     /// scan_progress for scan metrics of datablocks (uncompressed)
     pub(in crate::sessions) scan_progress: Arc<Progress>,
     /// write_progress for write/commit metrics of datablocks (uncompressed)
@@ -96,6 +98,7 @@ impl QueryContextShared {
             catalog_manager: CatalogManager::instance(),
             data_operator: DataOperator::instance(),
             init_query_id: Arc::new(RwLock::new(Uuid::new_v4().to_string())),
+            total_scan_values: Arc::new(Progress::create()),
             scan_progress: Arc::new(Progress::create()),
             result_progress: Arc::new(Progress::create()),
             write_progress: Arc::new(Progress::create()),

src/query/service/src/sessions/session_status.rs

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@ use std::time::Instant;
 pub struct SessionStatus {
     pub session_started_at: Instant,
     pub last_query_finished_at: Option<Instant>,
+    pub is_native_client: bool,
 }
 
 impl SessionStatus {
@@ -35,6 +36,7 @@ impl Default for SessionStatus {
         SessionStatus {
             session_started_at: Instant::now(),
             last_query_finished_at: None,
+            is_native_client: false,
         }
     }
 }

src/query/service/src/stream/processor_executor_stream.rs

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ impl PullingExecutorStream {
 impl Stream for PullingExecutorStream {
     type Item = Result<DataBlock>;
 
+    // The ctx can't be wake up, so we can't return Poll::Pending here
     fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let self_ = Pin::get_mut(self);
         match self_.executor.pull_data() {

src/query/service/tests/it/storages/fuse/operations/commit.rs

Lines changed: 8 additions & 0 deletions
@@ -371,6 +371,14 @@ impl TableContext for CtxDelegation {
         todo!()
     }
 
+    fn incr_total_scan_value(&self, _value: ProgressValues) {
+        todo!()
+    }
+
+    fn get_total_scan_value(&self) -> ProgressValues {
+        todo!()
+    }
+
     fn get_scan_progress(&self) -> Arc<Progress> {
         todo!()
     }

src/query/sql/src/executor/table_read_plan.rs

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,7 @@
 use std::collections::BTreeMap;
 use std::sync::Arc;
 
+use common_base::base::ProgressValues;
 use common_catalog::plan::DataSourcePlan;
 use common_catalog::plan::InternalColumn;
 use common_catalog::plan::PartStatistics;
@@ -73,6 +74,11 @@ impl ToReadDataSourcePlan for dyn Table {
             self.read_partitions(ctx.clone(), push_downs.clone()).await
         }?;
 
+        ctx.incr_total_scan_value(ProgressValues {
+            rows: statistics.read_rows,
+            bytes: statistics.read_bytes,
+        });
+
         // We need the partition sha256 to specify the result cache.
         if ctx.get_settings().get_enable_query_result_cache()? {
             let sha = parts.compute_sha256()?;
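Taken together with the query.rs changes, the bookkeeping has two layers: the planner records the planned scan size once (the total), execution keeps bumping the live scan progress, and the Flight SQL progress task reads both to build each ProgressValue frame. A minimal sketch of that interplay, assuming the common_base Progress/ProgressValues API used elsewhere in this commit; the numbers are made up.

use std::sync::Arc;

use common_base::base::Progress;
use common_base::base::ProgressValues;

fn main() {
    // Mirrors the two counters held in QueryContextShared.
    let total_scan_values = Arc::new(Progress::create());
    let scan_progress = Arc::new(Progress::create());

    // Plan time: statistics.read_rows / read_bytes from read_partitions.
    total_scan_values.incr(&ProgressValues {
        rows: 1_000,
        bytes: 1 << 20,
    });

    // Execution time: each scanned block adds to scan_progress.
    scan_progress.incr(&ProgressValues {
        rows: 250,
        bytes: 256 * 1024,
    });

    // What the progress task would serialize into a progress frame.
    let total = total_scan_values.get_values();
    let read = scan_progress.get_values();
    println!("scanned {} of {} rows", read.rows, total.rows);
}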
