@@ -48,6 +48,7 @@ use datafusion::{
4848use datafusion_substrait:: {
4949 logical_plan:: consumer:: from_substrait_plan, serializer:: deserialize_bytes,
5050} ;
51+
5152use futures:: { Stream , StreamExt , TryStreamExt } ;
5253use log:: info;
5354use once_cell:: sync:: Lazy ;
@@ -56,6 +57,7 @@ use prost::Message;
5657use tonic:: transport:: Server ;
5758use tonic:: { Request , Response , Status , Streaming } ;
5859
60+ use super :: config:: FlightSqlServiceConfig ;
5961use super :: session:: { SessionStateProvider , StaticSessionStateProvider } ;
6062use super :: state:: { CommandTicket , QueryHandle } ;
6163
@@ -65,6 +67,7 @@ type Result<T, E = Status> = std::result::Result<T, E>;
6567pub struct FlightSqlService {
6668 provider : Box < dyn SessionStateProvider > ,
6769 sql_options : Option < SQLOptions > ,
70+ config : FlightSqlServiceConfig ,
6871}
6972
7073impl FlightSqlService {
@@ -78,9 +81,15 @@ impl FlightSqlService {
7881 Self {
7982 provider,
8083 sql_options : None ,
84+ config : FlightSqlServiceConfig :: default ( ) ,
8185 }
8286 }
8387
88+ /// Replaces the FlightSqlServiceConfig with the provided config.
89+ pub fn with_config ( self , config : FlightSqlServiceConfig ) -> Self {
90+ Self { config, ..self }
91+ }
92+
8493 /// Replaces the sql_options with the provided options.
8594 /// These options are used to verify all SQL queries.
8695 /// When None the default [`SQLOptions`] are used.
@@ -303,7 +312,7 @@ impl ArrowFlightSqlService for FlightSqlService {
303312 . await
304313 . map_err ( df_error_to_status) ?;
305314
306- let dataset_schema = get_schema_for_plan ( & plan) ;
315+ let dataset_schema = get_schema_for_plan ( & plan, self . config . schema_with_metadata ) ;
307316
308317 // Form the response ticket (that the client will pass back to DoGet)
309318 let ticket = CommandTicket :: new ( sql:: Command :: CommandStatementQuery ( query) )
@@ -342,7 +351,7 @@ impl ArrowFlightSqlService for FlightSqlService {
342351
343352 let flight_descriptor = request. into_inner ( ) ;
344353
345- let dataset_schema = get_schema_for_plan ( & plan) ;
354+ let dataset_schema = get_schema_for_plan ( & plan, self . config . schema_with_metadata ) ;
346355
347356 // Form the response ticket (that the client will pass back to DoGet)
348357 let ticket = CommandTicket :: new ( sql:: Command :: CommandStatementSubstraitPlan ( query) )
@@ -381,7 +390,7 @@ impl ArrowFlightSqlService for FlightSqlService {
381390 . await
382391 . map_err ( df_error_to_status) ?;
383392
384- let dataset_schema = get_schema_for_plan ( & plan) ;
393+ let dataset_schema = get_schema_for_plan ( & plan, self . config . schema_with_metadata ) ;
385394
386395 // Form the response ticket (that the client will pass back to DoGet)
387396 let ticket = CommandTicket :: new ( sql:: Command :: CommandPreparedStatementQuery ( cmd) )
@@ -881,7 +890,7 @@ impl ArrowFlightSqlService for FlightSqlService {
881890 . await
882891 . map_err ( df_error_to_status) ?;
883892
884- let dataset_schema = get_schema_for_plan ( & plan) ;
893+ let dataset_schema = get_schema_for_plan ( & plan, self . config . schema_with_metadata ) ;
885894 let parameter_schema = parameter_schema_for_plan ( & plan) . map_err ( |e| e. as_ref ( ) . clone ( ) ) ?;
886895
887896 let dataset_schema =
@@ -1017,9 +1026,33 @@ fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
10171026}
10181027
10191028/// Return the schema for the specified logical plan
1020- fn get_schema_for_plan ( logical_plan : & LogicalPlan ) -> SchemaRef {
1021- // gather real schema, but only
1022- let schema = Schema :: from ( logical_plan. schema ( ) . as_ref ( ) ) . into ( ) ;
1029+ fn get_schema_for_plan ( logical_plan : & LogicalPlan , with_metadata : bool ) -> SchemaRef {
1030+ let schema: SchemaRef = if with_metadata {
1031+ // Get the DFSchema which contains table qualifiers
1032+ let df_schema = logical_plan. schema ( ) ;
1033+
1034+ // Convert to Arrow Schema and add table name metadata to fields
1035+ let fields_with_metadata: Vec < _ > = df_schema
1036+ . iter ( )
1037+ . map ( |( qualifier, field) | {
1038+ // If there's a table qualifier, add it as metadata
1039+ if let Some ( table_ref) = qualifier {
1040+ let mut metadata = field. metadata ( ) . clone ( ) ;
1041+ metadata. insert ( "table_name" . to_string ( ) , table_ref. to_string ( ) ) ;
1042+ field. as_ref ( ) . clone ( ) . with_metadata ( metadata)
1043+ } else {
1044+ field. as_ref ( ) . clone ( )
1045+ }
1046+ } )
1047+ . collect ( ) ;
1048+
1049+ Arc :: new ( Schema :: new_with_metadata (
1050+ fields_with_metadata,
1051+ df_schema. as_ref ( ) . metadata ( ) . clone ( ) ,
1052+ ) )
1053+ } else {
1054+ Arc :: new ( Schema :: from ( logical_plan. schema ( ) . as_ref ( ) ) )
1055+ } ;
10231056
10241057 // Use an empty FlightDataEncoder to determine the schema of the encoded flight data.
10251058 // This is necessary as the schema can change based on dictionary hydration behavior.
0 commit comments