1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15+ use std:: sync:: atomic:: AtomicBool ;
16+ use std:: sync:: atomic:: Ordering ;
1517use std:: sync:: Arc ;
1618
1719use arrow_flight:: FlightData ;
1820use arrow_flight:: SchemaAsIpc ;
1921use arrow_ipc:: writer;
2022use arrow_ipc:: writer:: IpcWriteOptions ;
2123use arrow_schema:: Schema as ArrowSchema ;
22- use async_stream :: stream ;
24+ use common_base :: base :: tokio ;
2325use common_exception:: ErrorCode ;
2426use common_exception:: Result ;
2527use common_expression:: DataBlock ;
@@ -28,7 +30,10 @@ use common_sql::plans::Plan;
2830use common_sql:: PlanExtras ;
2931use common_sql:: Planner ;
3032use common_storages_fuse:: TableContext ;
31- use futures_util:: StreamExt ;
33+ use futures:: Stream ;
34+ use futures:: StreamExt ;
35+ use serde:: Deserialize ;
36+ use serde:: Serialize ;
3237use tonic:: Status ;
3338
3439use super :: status;
@@ -37,6 +42,9 @@ use super::FlightSqlServiceImpl;
3742use crate :: interpreters:: InterpreterFactory ;
3843use crate :: sessions:: Session ;
3944
45+ /// A app_metakey which indicates the data is a progress type
46+ static H_PROGRESS : u8 = 0x01 ;
47+
4048impl FlightSqlServiceImpl {
4149 pub ( crate ) fn schema_to_flight_data ( data_schema : DataSchema ) -> FlightData {
4250 let arrow_schema = ArrowSchema :: from ( & data_schema) ;
@@ -98,49 +106,109 @@ impl FlightSqlServiceImpl {
98106 Ok ( affected_rows as i64 )
99107 }
100108
101- #[ async_backtrace:: framed]
102109 pub ( super ) async fn execute_query (
103110 & self ,
104111 session : Arc < Session > ,
105112 plan : & Plan ,
106113 plan_extras : & PlanExtras ,
107114 ) -> Result < DoGetStream > {
115+ let is_native_client = session. get_status ( ) . read ( ) . is_native_client ;
116+
108117 let context = session
109118 . create_query_context ( )
110119 . await
111120 . map_err ( |e| status ! ( "Could not create_query_context" , e) ) ?;
112121
113122 context. attach_query_str ( plan. to_string ( ) , plan_extras. stament . to_mask_sql ( ) ) ;
114123 let interpreter = InterpreterFactory :: get ( context. clone ( ) , plan) . await ?;
124+
115125 let data_schema = interpreter. schema ( ) ;
116- let schema_flight_data = Self :: schema_to_flight_data ( ( * data_schema ) . clone ( ) ) ;
126+ let data_stream = interpreter . execute ( context . clone ( ) ) . await ? ;
117127
118- let mut data_stream = interpreter. execute ( context. clone ( ) ) . await ?;
128+ let is_finished = Arc :: new ( AtomicBool :: new ( false ) ) ;
129+ let is_finished_clone = is_finished. clone ( ) ;
130+ let ( sender, receiver) = tokio:: sync:: mpsc:: channel ( 2 ) ;
131+ let _ = sender
132+ . send ( Ok ( Self :: schema_to_flight_data ( ( * data_schema) . clone ( ) ) ) )
133+ . await ;
134+
135+ let s1 = sender. clone ( ) ;
136+ tokio:: spawn ( async move {
137+ let mut data_stream = data_stream;
119138
120- let stream = stream ! {
121- yield Ok ( schema_flight_data) ;
122139 while let Some ( block) = data_stream. next ( ) . await {
123140 match block {
124141 Ok ( block) => {
125- match Self :: block_to_flight_data( block, & data_schema) {
126- Ok ( flight_data) => {
127- yield Ok ( flight_data)
128- }
129- Err ( err) => {
130- yield Err ( status!( "Could not convert batches" , err) )
131- }
132- }
142+ let res =
143+ match FlightSqlServiceImpl :: block_to_flight_data ( block, & data_schema) {
144+ Ok ( flight_data) => Ok ( flight_data) ,
145+ Err ( err) => Err ( status ! ( "Could not convert batches" , err) ) ,
146+ } ;
147+
148+ let _ = s1. send ( res) . await ;
133149 }
134150 Err ( err) => {
135- yield Err ( status!( "Could not convert batches" , err) )
151+ let _ = s1
152+ . send ( Err ( status ! ( "Could not convert batches" , err) ) )
153+ . await ;
136154 }
137- } ;
155+ }
138156 }
157+ is_finished_clone. store ( true , Ordering :: SeqCst ) ;
158+ } ) ;
139159
140- // to hold session ref until stream is all consumed
141- let _ = session. get_id( ) ;
142- } ;
160+ if is_native_client {
161+ tokio:: spawn ( async move {
162+ let total_scan_value = context. get_total_scan_value ( ) ;
163+ let mut current_scan_value = context. get_scan_progress_value ( ) ;
143164
144- Ok ( Box :: pin ( stream) )
165+ let mut interval = tokio:: time:: interval ( tokio:: time:: Duration :: from_millis ( 20 ) ) ;
166+ while !is_finished. load ( Ordering :: SeqCst ) {
167+ interval. tick ( ) . await ;
168+
169+ let progress = context. get_scan_progress_value ( ) ;
170+ if progress. rows == current_scan_value. rows {
171+ continue ;
172+ }
173+ current_scan_value = progress;
174+
175+ let progress = ProgressValue {
176+ total_rows : total_scan_value. rows ,
177+ total_bytes : total_scan_value. bytes ,
178+
179+ read_rows : current_scan_value. rows ,
180+ read_bytes : current_scan_value. bytes ,
181+ } ;
182+
183+ let progress = serde_json:: to_vec ( & progress) . unwrap ( ) ;
184+ let progress_flight_data = FlightData {
185+ app_metadata : vec ! [ H_PROGRESS ] . into ( ) ,
186+ data_body : progress. into ( ) ,
187+ ..Default :: default ( )
188+ } ;
189+ let _ = sender. send ( Ok ( progress_flight_data) ) . await ;
190+ }
191+ } ) ;
192+ }
193+
194+ fn receiver_to_stream < T > (
195+ receiver : tokio:: sync:: mpsc:: Receiver < T > ,
196+ ) -> impl Stream < Item = T > {
197+ futures:: stream:: unfold ( receiver, |mut receiver| async {
198+ receiver. recv ( ) . await . map ( |value| ( value, receiver) )
199+ } )
200+ }
201+
202+ let st = receiver_to_stream ( receiver) ;
203+ Ok ( Box :: pin ( st) )
145204 }
146205}
206+
207+ #[ derive( Serialize , Deserialize , Debug ) ]
208+ struct ProgressValue {
209+ pub total_rows : usize ,
210+ pub total_bytes : usize ,
211+
212+ pub read_rows : usize ,
213+ pub read_bytes : usize ,
214+ }
0 commit comments