Skip to content

Commit 5901b8d

Browse files
authored
Feat/flight sql bearer auth example (#22)
Add Bearer Authentication example for Flight SQL server
1 parent 4b6dcea commit 5901b8d

File tree

2 files changed

+239
-0
lines changed

2 files changed

+239
-0
lines changed

datafusion-flight-sql-server/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ once_cell = "1.21"
2727
prost = "0.13"
2828
tonic.workspace = true
2929
async-trait.workspace = true
30+
tonic-async-interceptor = "0.12.0"
3031

3132
[dev-dependencies]
3233
tokio.workspace = true
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
//! # DataFusion Flight SQL Server with Bearer Token Authentication Example
2+
//!
3+
//! This example demonstrates how to integrate Bearer Token authentication into a
4+
//! DataFusion Flight SQL server.
5+
//!
6+
//! Key components:
7+
//! - `bearer_auth_interceptor`: A tonic interceptor that validates "Authorization: Bearer <token>"
8+
//! headers. For simplicity, it uses a hardcoded list of valid tokens ("token1", "token2").
9+
//! - `UserData`: A simple struct holding user information (e.g., user_id) extracted from
10+
//! a valid token and inserted into request extensions.
11+
//! - `MySessionStateProvider`: A custom `SessionStateProvider` that retrieves `UserData`
12+
//! from request extensions. This allows tailoring the `SessionState` for the request,
13+
//! although in this example, it primarily clones a base context after successful authentication.
14+
//! It also registers a "test" CSV table for querying.
15+
//!
16+
//! ## Running the Example
17+
//!
18+
//! The server will start on `0.0.0.0:50051`. The `main` function includes client code
19+
//! that attempts to connect and perform a `GetTables` Flight SQL action using:
20+
//! 1. A valid token ("token1") - Expected to succeed.
21+
//! 2. An invalid token ("invalidtoken") - Expected to fail authentication.
22+
//! 3. No token - Expected to fail authentication.
23+
//!
24+
//! Observe the server console output for messages from the interceptor and session provider,
25+
//! and the client output for the success/failure of each attempt.
26+
27+
use std::time::Duration;
28+
29+
use arrow_flight::flight_service_server::FlightServiceServer;
30+
use arrow_flight::sql::client::FlightSqlServiceClient;
31+
use arrow_flight::sql::CommandGetTables;
32+
use async_trait::async_trait;
33+
use datafusion::error::{DataFusionError, Result};
34+
use datafusion::execution::context::SessionState; // SessionState is not in prelude
35+
use datafusion::prelude::*; // Covers SessionContext, CsvReadOptions, etc.
36+
use datafusion_flight_sql_server::service::FlightSqlService;
37+
use datafusion_flight_sql_server::session::SessionStateProvider;
38+
use tokio::time::sleep;
39+
use tonic::transport::{Channel, Endpoint, Server};
40+
use tonic::{Request, Status};
41+
42+
// UserData struct remains the same
43+
#[derive(Clone, Debug)]
44+
pub struct UserData {
45+
pub user_id: u32,
46+
}
47+
48+
// bearer_auth_interceptor remains the same
49+
async fn bearer_auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
50+
let auth_header = req
51+
.metadata()
52+
.get("authorization")
53+
.ok_or_else(|| Status::unauthenticated("no authorization provided"))?;
54+
55+
let auth_str = auth_header
56+
.to_str()
57+
.map_err(|_| Status::unauthenticated("invalid authorization header encoding"))?;
58+
59+
if !auth_str.starts_with("Bearer ") {
60+
return Err(Status::unauthenticated(
61+
"invalid authorization header format",
62+
));
63+
}
64+
65+
let token = &auth_str["Bearer ".len()..];
66+
67+
let user_data = match token {
68+
"token1" => UserData { user_id: 1 },
69+
"token2" => UserData { user_id: 2 },
70+
_ => return Err(Status::unauthenticated("invalid token")),
71+
};
72+
73+
req.extensions_mut().insert(user_data);
74+
Ok(req)
75+
}
76+
77+
// Updated MySessionStateProvider
78+
// #[derive(Debug, Clone)] // Removed Default
79+
pub struct MySessionStateProvider {
80+
base_context: SessionContext,
81+
}
82+
83+
impl MySessionStateProvider {
84+
async fn try_new() -> Result<Self> {
85+
// datafusion::error::Result
86+
let ctx = SessionContext::new();
87+
// Construct path to test.csv relative to CARGO_MANIFEST_DIR of the datafusion-flight-sql-server crate
88+
let csv_path = concat!(env!("CARGO_MANIFEST_DIR"), "/examples/test.csv");
89+
ctx.register_csv("test", csv_path, CsvReadOptions::new())
90+
.await?;
91+
Ok(Self { base_context: ctx })
92+
}
93+
}
94+
95+
#[async_trait]
96+
impl SessionStateProvider for MySessionStateProvider {
97+
async fn new_context(&self, request: &Request<()>) -> Result<SessionState, Status> {
98+
// tonic::Result for Status
99+
if let Some(user_data) = request.extensions().get::<UserData>() {
100+
println!(
101+
"Session context for user_id: {}. Cloning base context.",
102+
user_data.user_id
103+
);
104+
let state = self.base_context.state().clone();
105+
// Optional: Customize state based on user_data
106+
// state.set_config_option("datafusion.user_id", &user_data.user_id.to_string()).map_err(|e| Status::internal(format!("Failed to set config: {}",e)))?;
107+
Ok(state)
108+
} else {
109+
Err(Status::unauthenticated(
110+
"User data not found in request extensions (MySessionStateProvider)",
111+
))
112+
}
113+
}
114+
}
115+
116+
// Updated new_client_with_auth function
117+
async fn new_client_with_auth(
118+
dsn: String,
119+
token: Option<String>,
120+
) -> Result<FlightSqlServiceClient<Channel>> {
121+
// datafusion::error::Result
122+
let endpoint = Endpoint::from_shared(dsn.clone())
123+
.map_err(|e| DataFusionError::External(format!("Invalid DSN {}: {}", dsn, e).into()))?
124+
.connect_timeout(std::time::Duration::from_secs(10));
125+
126+
let channel = endpoint.connect().await.map_err(|e| {
127+
DataFusionError::External(format!("Failed to connect to {}: {}", dsn, e).into())
128+
})?;
129+
130+
let mut service_client = FlightSqlServiceClient::new(channel);
131+
if let Some(token_str) = token.clone() {
132+
service_client.set_header("authorization", format!("Bearer {}", token_str));
133+
}
134+
Ok(service_client)
135+
}
136+
137+
#[tokio::main]
138+
async fn main() -> Result<()> {
139+
// datafusion::error::Result<()>
140+
// Server Setup
141+
let dsn: String = "0.0.0.0:50051".to_string();
142+
let state_provider = Box::new(MySessionStateProvider::try_new().await?);
143+
let base_service = FlightSqlService::new_with_provider(state_provider);
144+
let svc: FlightServiceServer<FlightSqlService> = FlightServiceServer::new(base_service);
145+
let addr: std::net::SocketAddr = dsn.parse().map_err(|e| {
146+
DataFusionError::External(format!("Invalid address format {}: {}", dsn, e).into())
147+
})?;
148+
149+
tokio::spawn(async move {
150+
println!(
151+
"Bearer Authentication Flight SQL server listening on {}",
152+
addr
153+
);
154+
if let Err(e) = Server::builder()
155+
.layer(tonic_async_interceptor::async_interceptor(
156+
bearer_auth_interceptor,
157+
))
158+
.add_service(svc)
159+
.serve(addr)
160+
.await
161+
{
162+
eprintln!("Server error: {}", e);
163+
}
164+
});
165+
166+
// Wait for server to run
167+
sleep(Duration::from_secs(3)).await;
168+
169+
// Client Setup and Testing
170+
let client_dsn = "http://localhost:50051".to_string();
171+
172+
// Test Case 1: Valid Token
173+
println!("\nAttempting GetTables with valid token (token1)...");
174+
match new_client_with_auth(client_dsn.clone(), Some("token1".to_string())).await {
175+
Ok(mut client) => {
176+
let request = CommandGetTables {
177+
catalog: None,
178+
db_schema_filter_pattern: None,
179+
table_name_filter_pattern: None,
180+
table_types: vec![],
181+
include_schema: false,
182+
};
183+
match client.get_tables(request).await {
184+
Ok(response) => {
185+
println!("GetTables with token1 SUCCEEDED. Response: {:?}", response)
186+
}
187+
Err(e) => eprintln!("GetTables with token1 FAILED: {}", e),
188+
}
189+
}
190+
Err(e) => eprintln!("Failed to create client with token1: {}", e),
191+
}
192+
193+
// Test Case 2: Invalid Token
194+
println!("\nAttempting GetTables with invalid token (invalidtoken)...");
195+
match new_client_with_auth(client_dsn.clone(), Some("invalidtoken".to_string())).await {
196+
Ok(mut client) => {
197+
let request = CommandGetTables {
198+
catalog: None,
199+
db_schema_filter_pattern: None,
200+
table_name_filter_pattern: None,
201+
table_types: vec![],
202+
include_schema: false,
203+
};
204+
match client.get_tables(request).await {
205+
Ok(response) => println!(
206+
"GetTables with invalidtoken SUCCEEDED (unexpected). Response: {:?}",
207+
response
208+
),
209+
Err(e) => eprintln!("GetTables with invalidtoken FAILED (as expected): {:?}", e),
210+
}
211+
}
212+
Err(e) => eprintln!("Failed to create client with invalidtoken: {}", e),
213+
}
214+
215+
// Test Case 3: No Token
216+
println!("\nAttempting GetTables with no token...");
217+
match new_client_with_auth(client_dsn.clone(), None).await {
218+
Ok(mut client) => {
219+
let request = CommandGetTables {
220+
catalog: None,
221+
db_schema_filter_pattern: None,
222+
table_name_filter_pattern: None,
223+
table_types: vec![],
224+
include_schema: false,
225+
};
226+
match client.get_tables(request).await {
227+
Ok(response) => println!(
228+
"GetTables with no token SUCCEEDED (unexpected). Response: {:?}",
229+
response
230+
),
231+
Err(e) => eprintln!("GetTables with no token FAILED (as expected): {:?}", e),
232+
}
233+
}
234+
Err(e) => eprintln!("Failed to create client with no token: {}", e),
235+
}
236+
237+
Ok(())
238+
}

0 commit comments

Comments
 (0)