Skip to content

Commit c66a5b5

Browse files
feat: Add Bearer Authentication example for Flight SQL server
This commit introduces a new example, `bearer_auth_flight_sql.rs`, demonstrating how to implement Bearer Token authentication in a DataFusion Flight SQL server. The example includes: - A tonic interceptor (`bearer_auth_interceptor`) for validating Bearer tokens from the "authorization" header. - A `UserData` struct to store information extracted from valid tokens, which is then added to request extensions. - A custom `SessionStateProvider` (`MySessionStateProvider`) that retrieves `UserData` from request extensions to create or customize the `SessionState`. This provider also registers a sample CSV table ("test") for querying. - Server-side setup to integrate the interceptor and custom session provider. - Client-side test cases in the `main` function that attempt `GetTables` Flight SQL calls with: - A valid token (expected to succeed). - An invalid token (expected to fail). - No token (expected to fail). - Module-level documentation explaining the example's functionality and components. This example serves as a guide for implementing request-level authentication and session customization in Flight SQL services.
1 parent 4b6dcea commit c66a5b5

File tree

1 file changed

+208
-0
lines changed

1 file changed

+208
-0
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
//! # DataFusion Flight SQL Server with Bearer Token Authentication Example
2+
//!
3+
//! This example demonstrates how to integrate Bearer Token authentication into a
4+
//! DataFusion Flight SQL server.
5+
//!
6+
//! Key components:
7+
//! - `bearer_auth_interceptor`: A tonic interceptor that validates "Authorization: Bearer <token>"
8+
//! headers. For simplicity, it uses a hardcoded list of valid tokens ("token1", "token2").
9+
//! - `UserData`: A simple struct holding user information (e.g., user_id) extracted from
10+
//! a valid token and inserted into request extensions.
11+
//! - `MySessionStateProvider`: A custom `SessionStateProvider` that retrieves `UserData`
12+
//! from request extensions. This allows tailoring the `SessionState` for the request,
13+
//! although in this example, it primarily clones a base context after successful authentication.
14+
//! It also registers a "test" CSV table for querying.
15+
//!
16+
//! ## Running the Example
17+
//!
18+
//! The server will start on `0.0.0.0:50051`. The `main` function includes client code
19+
//! that attempts to connect and perform a `GetTables` Flight SQL action using:
20+
//! 1. A valid token ("token1") - Expected to succeed.
21+
//! 2. An invalid token ("invalidtoken") - Expected to fail authentication.
22+
//! 3. No token - Expected to fail authentication.
23+
//!
24+
//! Observe the server console output for messages from the interceptor and session provider,
25+
//! and the client output for the success/failure of each attempt.
26+
27+
use std::sync::Arc;
28+
use std::time::Duration;
29+
30+
use arrow_flight::sql::client::FlightSqlServiceClient;
31+
use arrow_flight::sql::CommandGetTables;
32+
use async_trait::async_trait;
33+
use datafusion::error::{DataFusionError, Result};
34+
use datafusion::prelude::*; // Covers SessionContext, CsvReadOptions, etc.
35+
use datafusion::execution::context::{SessionConfig, SessionState}; // SessionState is not in prelude
36+
use datafusion::execution::runtime_env::RuntimeEnv;
37+
use datafusion_flight_sql_server::service::FlightSqlService;
38+
use datafusion_flight_sql_server::session::SessionStateProvider;
39+
use tokio::time::sleep;
40+
use tonic::transport::{Channel, Endpoint, Server};
41+
use tonic::{metadata::MetadataValue, Request, Status};
42+
43+
/// Identity attached to a request once its bearer token has been validated.
///
/// A value of this type is inserted into the request extensions by the
/// authentication interceptor and read back when the session state is built.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct UserData {
    /// Numeric id of the user the presented token maps to.
    pub user_id: u32,
}
48+
49+
// bearer_auth_interceptor remains the same
50+
async fn bearer_auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
51+
let auth_header = req
52+
.metadata()
53+
.get("authorization")
54+
.ok_or_else(|| Status::unauthenticated("no authorization provided"))?;
55+
56+
let auth_str = auth_header
57+
.to_str()
58+
.map_err(|_| Status::unauthenticated("invalid authorization header encoding"))?;
59+
60+
if !auth_str.starts_with("Bearer ") {
61+
return Err(Status::unauthenticated(
62+
"invalid authorization header format",
63+
));
64+
}
65+
66+
let token = &auth_str["Bearer ".len()..];
67+
68+
let user_data = match token {
69+
"token1" => UserData { user_id: 1 },
70+
"token2" => UserData { user_id: 2 },
71+
_ => return Err(Status::unauthenticated("invalid token")),
72+
};
73+
74+
req.extensions_mut().insert(user_data);
75+
Ok(req)
76+
}
77+
78+
// Updated MySessionStateProvider
79+
#[derive(Debug, Clone)] // Removed Default
80+
pub struct MySessionStateProvider {
81+
base_context: SessionContext,
82+
}
83+
84+
impl MySessionStateProvider {
85+
async fn try_new() -> Result<Self> { // datafusion::error::Result
86+
let ctx = SessionContext::new();
87+
// Construct path to test.csv relative to CARGO_MANIFEST_DIR of the datafusion-flight-sql-server crate
88+
let csv_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/test.csv");
89+
ctx.register_csv("test", csv_path, CsvReadOptions::new()).await?;
90+
Ok(Self { base_context: ctx })
91+
}
92+
}
93+
94+
#[async_trait]
95+
impl SessionStateProvider for MySessionStateProvider {
96+
async fn new_context(&self, request: &Request<()>) -> Result<SessionState, Status> { // tonic::Result for Status
97+
if let Some(user_data) = request.extensions().get::<UserData>() {
98+
println!(
99+
"Session context for user_id: {}. Cloning base context.",
100+
user_data.user_id
101+
);
102+
let state = self.base_context.state().clone();
103+
// Optional: Customize state based on user_data
104+
// state.set_config_option("datafusion.user_id", &user_data.user_id.to_string()).map_err(|e| Status::internal(format!("Failed to set config: {}",e)))?;
105+
Ok(state)
106+
} else {
107+
Err(Status::unauthenticated(
108+
"User data not found in request extensions (MySessionStateProvider)",
109+
))
110+
}
111+
}
112+
}
113+
114+
// Updated new_client_with_auth function
115+
async fn new_client_with_auth(
116+
dsn: String,
117+
token: Option<String>,
118+
) -> Result<FlightSqlServiceClient<Channel>> { // datafusion::error::Result
119+
let endpoint = Endpoint::from_shared(dsn.clone())
120+
.map_err(|e| DataFusionError::External(format!("Invalid DSN {}: {}", dsn, e).into()))?
121+
.connect_timeout(std::time::Duration::from_secs(10));
122+
123+
let channel = endpoint.connect().await
124+
.map_err(|e| DataFusionError::External(format!("Failed to connect to {}: {}", dsn, e).into()))?;
125+
126+
let service_client = FlightSqlServiceClient::with_interceptor(
127+
channel,
128+
move |mut req: Request<()>| {
129+
if let Some(token_str) = token.clone() {
130+
let bearer_token = format!("Bearer {}", token_str);
131+
match MetadataValue::try_from(&bearer_token) {
132+
Ok(metadata_val) => req.metadata_mut().insert("authorization", metadata_val),
133+
Err(_) => return Err(Status::invalid_argument("Invalid token format for metadata")),
134+
};
135+
}
136+
Ok(req)
137+
},
138+
);
139+
Ok(service_client)
140+
}
141+
142+
fn status_to_df_error(err: tonic::Status) -> DataFusionError {
143+
DataFusionError::External(format!("Tonic status error: {}", err).into())
144+
}
145+
146+
#[tokio::main]
147+
async fn main() -> Result<()> { // datafusion::error::Result<()>
148+
// Server Setup
149+
let dsn: String = "0.0.0.0:50051".to_string();
150+
let state_provider = Arc::new(MySessionStateProvider::try_new().await?);
151+
let base_service = FlightSqlService::new_with_state_provider(state_provider).into_service();
152+
let wrapped_service = tonic::service::interceptor_fn(base_service, bearer_auth_interceptor);
153+
let addr: std::net::SocketAddr = dsn.parse().map_err(|e| DataFusionError::External(format!("Invalid address format {}: {}", dsn, e).into()))?;
154+
155+
tokio::spawn(async move {
156+
println!("Bearer Authentication Flight SQL server listening on {}", addr);
157+
if let Err(e) = Server::builder().add_service(wrapped_service).serve(addr).await {
158+
eprintln!("Server error: {}", e);
159+
}
160+
});
161+
162+
// Wait for server to run
163+
sleep(Duration::from_secs(3)).await;
164+
165+
// Client Setup and Testing
166+
let client_dsn = "http://localhost:50051".to_string();
167+
168+
// Test Case 1: Valid Token
169+
println!("\nAttempting GetTables with valid token (token1)...");
170+
match new_client_with_auth(client_dsn.clone(), Some("token1".to_string())).await {
171+
Ok(mut client) => {
172+
let request = CommandGetTables { catalog: None, db_schema_filter_pattern: None, table_name_filter_pattern: None, table_types: vec![], include_schema: false };
173+
match client.get_tables(request).await {
174+
Ok(response) => println!("GetTables with token1 SUCCEEDED. Response: {:?}", response.into_inner()),
175+
Err(e) => eprintln!("GetTables with token1 FAILED: {}", status_to_df_error(e)),
176+
}
177+
}
178+
Err(e) => eprintln!("Failed to create client with token1: {}", e),
179+
}
180+
181+
// Test Case 2: Invalid Token
182+
println!("\nAttempting GetTables with invalid token (invalidtoken)...");
183+
match new_client_with_auth(client_dsn.clone(), Some("invalidtoken".to_string())).await {
184+
Ok(mut client) => {
185+
let request = CommandGetTables { catalog: None, db_schema_filter_pattern: None, table_name_filter_pattern: None, table_types: vec![], include_schema: false };
186+
match client.get_tables(request).await {
187+
Ok(response) => println!("GetTables with invalidtoken SUCCEEDED (unexpected). Response: {:?}", response.into_inner()),
188+
Err(e) => eprintln!("GetTables with invalidtoken FAILED (as expected): {}", status_to_df_error(e)),
189+
}
190+
}
191+
Err(e) => eprintln!("Failed to create client with invalidtoken: {}", e),
192+
}
193+
194+
// Test Case 3: No Token
195+
println!("\nAttempting GetTables with no token...");
196+
match new_client_with_auth(client_dsn.clone(), None).await {
197+
Ok(mut client) => {
198+
let request = CommandGetTables { catalog: None, db_schema_filter_pattern: None, table_name_filter_pattern: None, table_types: vec![], include_schema: false };
199+
match client.get_tables(request).await {
200+
Ok(response) => println!("GetTables with no token SUCCEEDED (unexpected). Response: {:?}", response.into_inner()),
201+
Err(e) => eprintln!("GetTables with no token FAILED (as expected): {}", status_to_df_error(e)),
202+
}
203+
}
204+
Err(e) => eprintln!("Failed to create client with no token: {}", e),
205+
}
206+
207+
Ok(())
208+
}

0 commit comments

Comments
 (0)