Skip to content

Commit 8ecfcb9

Browse files
committed
Add optional table name metadata in schema fields
This commit includes the following changes: (1) add `FlightSqlServiceConfig` with a `schema_with_metadata` option; (2) modify `get_schema_for_plan()` to optionally add "table_name" metadata to each Arrow field.
1 parent 374872e commit 8ecfcb9

File tree

4 files changed

+237
-7
lines changed

4 files changed

+237
-7
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/// Configuration options for the Flight SQL service.
///
/// The [`Default`] value leaves every option disabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FlightSqlServiceConfig {
    /// When true, includes table names in field metadata under the "table_name" key.
    /// This allows clients to identify the source table or alias for each column in query results.
    pub schema_with_metadata: bool,
}

impl FlightSqlServiceConfig {
    /// Creates a configuration with every option at its default value
    /// (`schema_with_metadata` is `false`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod config;
12
pub mod service;
23
pub mod session;
34
pub mod state;

datafusion-flight-sql-server/src/service.rs

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use datafusion::{
4848
use datafusion_substrait::{
4949
logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
5050
};
51+
5152
use futures::{Stream, StreamExt, TryStreamExt};
5253
use log::info;
5354
use once_cell::sync::Lazy;
@@ -56,6 +57,7 @@ use prost::Message;
5657
use tonic::transport::Server;
5758
use tonic::{Request, Response, Status, Streaming};
5859

60+
use super::config::FlightSqlServiceConfig;
5961
use super::session::{SessionStateProvider, StaticSessionStateProvider};
6062
use super::state::{CommandTicket, QueryHandle};
6163

@@ -65,6 +67,7 @@ type Result<T, E = Status> = std::result::Result<T, E>;
6567
/// Flight SQL service: holds a session-state provider, optional SQL options,
/// and the service configuration.
pub struct FlightSqlService {
6668
/// Boxed `SessionStateProvider` trait object (presumably supplies the session state per request — confirm against `session.rs`).
provider: Box<dyn SessionStateProvider>,
6769
/// Optional `SQLOptions`; the constructor shown in this diff initializes it to `None`.
sql_options: Option<SQLOptions>,
70+
/// Service configuration; its `schema_with_metadata` flag is passed to `get_schema_for_plan`.
config: FlightSqlServiceConfig,
6871
}
6972

7073
impl FlightSqlService {
@@ -78,6 +81,15 @@ impl FlightSqlService {
7881
Self {
7982
provider,
8083
sql_options: None,
84+
config: FlightSqlServiceConfig::default(),
85+
}
86+
}
87+
88+
/// Replaces the FlightSqlServiceConfig with the provided config.
89+
pub fn with_config(self, config: FlightSqlServiceConfig) -> Self {
90+
Self {
91+
config,
92+
..self
8193
}
8294
}
8395

@@ -303,7 +315,7 @@ impl ArrowFlightSqlService for FlightSqlService {
303315
.await
304316
.map_err(df_error_to_status)?;
305317

306-
let dataset_schema = get_schema_for_plan(&plan);
318+
let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
307319

308320
// Form the response ticket (that the client will pass back to DoGet)
309321
let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
@@ -342,7 +354,7 @@ impl ArrowFlightSqlService for FlightSqlService {
342354

343355
let flight_descriptor = request.into_inner();
344356

345-
let dataset_schema = get_schema_for_plan(&plan);
357+
let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
346358

347359
// Form the response ticket (that the client will pass back to DoGet)
348360
let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
@@ -381,7 +393,7 @@ impl ArrowFlightSqlService for FlightSqlService {
381393
.await
382394
.map_err(df_error_to_status)?;
383395

384-
let dataset_schema = get_schema_for_plan(&plan);
396+
let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
385397

386398
// Form the response ticket (that the client will pass back to DoGet)
387399
let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
@@ -881,7 +893,7 @@ impl ArrowFlightSqlService for FlightSqlService {
881893
.await
882894
.map_err(df_error_to_status)?;
883895

884-
let dataset_schema = get_schema_for_plan(&plan);
896+
let dataset_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
885897
let parameter_schema = parameter_schema_for_plan(&plan).map_err(|e| e.as_ref().clone())?;
886898

887899
let dataset_schema =
@@ -1017,9 +1029,33 @@ fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
10171029
}
10181030

10191031
/// Return the schema for the specified logical plan
1020-
fn get_schema_for_plan(logical_plan: &LogicalPlan) -> SchemaRef {
1021-
// gather real schema, but only
1022-
let schema = Schema::from(logical_plan.schema().as_ref()).into();
1032+
fn get_schema_for_plan(logical_plan: &LogicalPlan, with_metadata: bool) -> SchemaRef {
1033+
let schema: SchemaRef = if with_metadata {
1034+
// Get the DFSchema which contains table qualifiers
1035+
let df_schema = logical_plan.schema();
1036+
1037+
// Convert to Arrow Schema and add table name metadata to fields
1038+
let fields_with_metadata: Vec<_> = df_schema
1039+
.iter()
1040+
.map(|(qualifier, field)| {
1041+
// If there's a table qualifier, add it as metadata
1042+
if let Some(table_ref) = qualifier {
1043+
let mut metadata = field.metadata().clone();
1044+
metadata.insert("table_name".to_string(), table_ref.to_string());
1045+
field.as_ref().clone().with_metadata(metadata)
1046+
} else {
1047+
field.as_ref().clone()
1048+
}
1049+
})
1050+
.collect();
1051+
1052+
Arc::new(Schema::new_with_metadata(
1053+
fields_with_metadata,
1054+
df_schema.as_ref().metadata().clone(),
1055+
))
1056+
} else {
1057+
Arc::new(Schema::from(logical_plan.schema().as_ref()))
1058+
};
10231059

10241060
// Use an empty FlightDataEncoder to determine the schema of the encoded flight data.
10251061
// This is necessary as the schema can change based on dictionary hydration behavior.
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
use std::sync::Arc;
2+
3+
use arrow::{
4+
array::{Int32Array, RecordBatch, StringArray},
5+
datatypes::{DataType, Field, Schema},
6+
};
7+
use arrow_flight::sql::client::FlightSqlServiceClient;
8+
use datafusion::{
9+
datasource::MemTable,
10+
execution::context::{SessionContext, SessionState},
11+
};
12+
use datafusion_flight_sql_server::{config::FlightSqlServiceConfig, service::FlightSqlService};
13+
use tokio::time::{sleep, Duration};
14+
use tonic::transport::{Channel, Endpoint};
15+
16+
fn create_test_session() -> SessionState {
17+
let ctx = SessionContext::new();
18+
let schema = Arc::new(Schema::new(vec![
19+
Field::new("id", DataType::Int32, false),
20+
Field::new("name", DataType::Utf8, false),
21+
]));
22+
23+
let batch = RecordBatch::try_new(
24+
schema.clone(),
25+
vec![
26+
Arc::new(Int32Array::from(vec![1, 2, 3])),
27+
Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
28+
],
29+
)
30+
.unwrap();
31+
32+
let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
33+
ctx.register_table("users", Arc::new(table)).unwrap();
34+
35+
let orders_schema = Arc::new(Schema::new(vec![
36+
Field::new("order_id", DataType::Int32, false),
37+
Field::new("user_id", DataType::Int32, false),
38+
Field::new("amount", DataType::Int32, false),
39+
]));
40+
41+
let orders_batch = RecordBatch::try_new(
42+
orders_schema.clone(),
43+
vec![
44+
Arc::new(Int32Array::from(vec![100, 101, 102, 103])),
45+
Arc::new(Int32Array::from(vec![1, 2, 1, 3])),
46+
Arc::new(Int32Array::from(vec![50, 75, 100, 25])),
47+
],
48+
)
49+
.unwrap();
50+
51+
let orders_table = MemTable::try_new(orders_schema, vec![vec![orders_batch]]).unwrap();
52+
ctx.register_table("orders", Arc::new(orders_table))
53+
.unwrap();
54+
55+
ctx.state()
56+
}
57+
58+
/// Spawns a `FlightSqlService` with `schema_with_metadata` enabled on a
/// background task serving at `addr`, then waits 500 ms before returning.
async fn start_test_server(addr: String, state: SessionState) {
    let service = FlightSqlService::new(state).with_config(FlightSqlServiceConfig {
        schema_with_metadata: true,
    });

    // The join handle is discarded, so the server task is not awaited here.
    let server = async move {
        service
            .serve(addr)
            .await
            .expect("Server should start successfully")
    };
    tokio::spawn(server);

    // Fixed delay — presumably gives the listener time to bind before a client connects.
    sleep(Duration::from_millis(500)).await;
}
73+
74+
async fn create_test_client(addr: &str) -> FlightSqlServiceClient<Channel> {
75+
let endpoint = Endpoint::new(addr.to_string()).expect("Valid endpoint");
76+
let channel = endpoint.connect().await.expect("Connection successful");
77+
FlightSqlServiceClient::new(channel)
78+
}
79+
80+
#[tokio::test]
81+
async fn test_schema_contains_table_name_metadata() {
82+
let addr = "0.0.0.0:50071";
83+
let state = create_test_session();
84+
start_test_server(addr.to_string(), state).await;
85+
86+
let mut client = create_test_client(&format!("http://{}", addr)).await;
87+
88+
let flight_info = client
89+
.execute("SELECT id, name FROM users".to_string(), None)
90+
.await
91+
.expect("Query should succeed");
92+
93+
let schema = flight_info
94+
.try_decode_schema()
95+
.expect("Should decode schema");
96+
97+
for field in schema.fields() {
98+
assert!(
99+
field.metadata().contains_key("table_name"),
100+
"Field {} should have table_name metadata",
101+
field.name()
102+
);
103+
104+
assert_eq!(
105+
field.metadata().get("table_name").unwrap(),
106+
"users",
107+
"Field {} should have table_name='users'",
108+
field.name()
109+
);
110+
}
111+
}
112+
113+
#[tokio::test]
114+
async fn test_schema_metadata_with_subquery_and_join() {
115+
let addr = "0.0.0.0:50072";
116+
let state = create_test_session();
117+
start_test_server(addr.to_string(), state).await;
118+
119+
let mut client = create_test_client(&format!("http://{}", addr)).await;
120+
121+
let query = r#"
122+
SELECT u.id, u.name, o.amount
123+
FROM users u
124+
JOIN (SELECT * FROM orders WHERE AMOUNT > 25 ) o
125+
ON u.id = o.user_id
126+
"#;
127+
128+
let flight_info = client
129+
.execute(query.to_string(), None)
130+
.await
131+
.expect("Query with subquery and join should succeed");
132+
133+
let schema = flight_info
134+
.try_decode_schema()
135+
.expect("Should decode schema");
136+
137+
assert_eq!(schema.fields().len(), 3, "Should have 3 fields");
138+
139+
// Both fields should have table_name metadata pointing to 'u'
140+
let id_field = schema.field(0);
141+
let name_field = schema.field(1);
142+
143+
// This field should have table_name metadata pointing to 'o'
144+
let amount_field = schema.field(2);
145+
146+
assert_eq!(id_field.name(), "id");
147+
assert_eq!(name_field.name(), "name");
148+
assert_eq!(amount_field.name(), "amount");
149+
150+
assert!(
151+
id_field.metadata().contains_key("table_name"),
152+
"id field should have table_name metadata"
153+
);
154+
assert_eq!(
155+
id_field.metadata().get("table_name").unwrap(),
156+
"u",
157+
"id field should have table_name='u'"
158+
);
159+
160+
assert!(
161+
name_field.metadata().contains_key("table_name"),
162+
"name field should have table_name metadata"
163+
);
164+
assert_eq!(
165+
name_field.metadata().get("table_name").unwrap(),
166+
"u",
167+
"name field should have table_name='u'"
168+
);
169+
170+
assert!(
171+
amount_field.metadata().contains_key("table_name"),
172+
"amount field should have table_name metadata"
173+
);
174+
assert_eq!(
175+
amount_field.metadata().get("table_name").unwrap(),
176+
"o",
177+
"amount field should have table_name='o'"
178+
);
179+
}

0 commit comments

Comments
 (0)