Skip to content

Commit e790432

Browse files
Start adding flightsql command (#325)
1 parent 20c01df commit e790432

File tree

8 files changed

+240
-25
lines changed

8 files changed

+240
-25
lines changed

crates/datafusion-app/src/flightsql.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717

1818
use std::sync::Arc;
1919

20-
use arrow_flight::sql::client::FlightSqlServiceClient;
20+
use arrow_flight::{
21+
decode::FlightRecordBatchStream, sql::client::FlightSqlServiceClient, FlightInfo,
22+
};
2123
#[cfg(feature = "flightsql")]
2224
use base64::engine::{general_purpose::STANDARD, Engine as _};
2325
use datafusion::{
2426
error::{DataFusionError, Result as DFResult},
2527
physical_plan::stream::RecordBatchStreamAdapter,
2628
sql::parser::DFParser,
2729
};
28-
use log::{error, info, warn};
30+
use log::{debug, error, info, warn};
2931

3032
use color_eyre::eyre::{self, Result};
3133
use tokio::sync::Mutex;
@@ -203,4 +205,43 @@ impl FlightSQLContext {
203205
return Err(DataFusionError::External("Missing client".into()));
204206
}
205207
}
208+
209+
pub async fn get_catalogs_flight_info(&self) -> DFResult<FlightInfo> {
210+
let client = self.client.clone();
211+
let mut guard = client.lock().await;
212+
if let Some(client) = guard.as_mut() {
213+
client
214+
.get_catalogs()
215+
.await
216+
.map_err(|e| DataFusionError::ArrowError(e, None))
217+
} else {
218+
Err(DataFusionError::External(
219+
"No FlightSQL client configured. Add one in `~/.config/dft/config.toml`".into(),
220+
))
221+
}
222+
}
223+
224+
pub async fn do_get(&self, flight_info: FlightInfo) -> DFResult<Vec<FlightRecordBatchStream>> {
225+
let client = self.client.clone();
226+
let mut guard = client.lock().await;
227+
if let Some(client) = guard.as_mut() {
228+
let mut streams = Vec::new();
229+
for endpoint in flight_info.endpoint {
230+
if let Some(ticket) = endpoint.ticket {
231+
let stream = client
232+
.do_get(ticket.into_request())
233+
.await
234+
.map_err(|e| DataFusionError::ArrowError(e, None))?;
235+
streams.push(stream);
236+
} else {
237+
debug!("No ticket for endpoint: {endpoint}");
238+
}
239+
}
240+
Ok(streams)
241+
} else {
242+
Err(DataFusionError::External(
243+
"No FlightSQL client configured. Add one in `~/.config/dft/config.toml`".into(),
244+
))
245+
}
246+
}
206247
}

crates/datafusion-app/src/local.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use super::stats::{ExecutionDurationStats, ExecutionStats};
4343
#[cfg(feature = "udfs-wasm")]
4444
use super::wasm::create_wasm_udfs;
4545
#[cfg(feature = "observability")]
46-
use crate::observability::ObservabilityContext;
46+
use {crate::config::ObservabilityConfig, crate::observability::ObservabilityContext};
4747

4848
/// Structure for executing queries
4949
///
@@ -72,6 +72,24 @@ pub struct ExecutionContext {
7272
observability: ObservabilityContext,
7373
}
7474

75+
impl Default for ExecutionContext {
76+
fn default() -> Self {
77+
let cfg = SessionConfig::new().with_information_schema(true);
78+
let session_ctx = SessionContext::new_with_config(cfg);
79+
#[cfg(feature = "observability")]
80+
let observability =
81+
ObservabilityContext::try_new(ObservabilityConfig::default(), "test").unwrap();
82+
Self {
83+
config: ExecutionConfig::default(),
84+
session_ctx,
85+
ddl_path: None,
86+
executor: None,
87+
#[cfg(feature = "observability")]
88+
observability,
89+
}
90+
}
91+
}
92+
7593
impl std::fmt::Debug for ExecutionContext {
7694
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7795
f.debug_struct("ExecutionContext").finish()

src/args.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,14 @@ impl DftArgs {
127127
}
128128
}
129129

130+
#[derive(Clone, Debug, Subcommand)]
131+
pub enum FlightSqlCommand {
132+
/// Executes `GetFlightInfo` and `DoGet` on the provided SQL query
133+
StatementQuery { sql: String },
134+
/// Executes `GetCatalogsFlightInfo` and `DoGet`
135+
GetCatalogs,
136+
}
137+
130138
#[derive(Clone, Debug, Subcommand)]
131139
pub enum Command {
132140
/// Start a HTTP server
@@ -139,6 +147,13 @@ pub enum Command {
139147
#[clap(long, help = "Set the port to be used for serving metrics")]
140148
metrics_addr: Option<SocketAddr>,
141149
},
150+
/// Make a request to a FlightSQL server
151+
#[cfg(feature = "flightsql")]
152+
#[command(name = "flightsql")]
153+
FlightSql {
154+
#[clap(subcommand)]
155+
command: FlightSqlCommand,
156+
},
142157
/// Start a FlightSQL server
143158
#[cfg(feature = "flightsql")]
144159
#[command(name = "serve-flightsql")]

src/cli/mod.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use std::io::Write;
3939
use std::path::{Path, PathBuf};
4040
#[cfg(feature = "flightsql")]
4141
use {
42+
crate::args::{Command, FlightSqlCommand},
4243
datafusion_app::{
4344
config::{AuthConfig, FlightSQLConfig},
4445
flightsql::FlightSQLContext,
@@ -82,6 +83,31 @@ impl CliApp {
8283
Ok(())
8384
}
8485

86+
#[cfg(feature = "flightsql")]
87+
async fn handle_flightsql_command(&self, command: FlightSqlCommand) -> color_eyre::Result<()> {
88+
use futures::stream;
89+
90+
match command {
91+
FlightSqlCommand::StatementQuery { sql } => self.exec_from_flightsql(sql, 0).await,
92+
FlightSqlCommand::GetCatalogs => {
93+
let flight_info = self
94+
.app_execution
95+
.flightsql_ctx()
96+
.get_catalogs_flight_info()
97+
.await?;
98+
let streams = self
99+
.app_execution
100+
.flightsql_ctx()
101+
.do_get(flight_info)
102+
.await?;
103+
let flight_batch_stream = stream::select_all(streams);
104+
self.print_any_stream(flight_batch_stream).await;
105+
106+
Ok(())
107+
}
108+
}
109+
}
110+
85111
/// Execute the provided sql, which was passed as an argument from CLI.
86112
///
87113
/// Optionally, use the FlightSQL client for execution.
@@ -92,6 +118,11 @@ impl CliApp {
92118

93119
self.validate_args()?;
94120

121+
#[cfg(feature = "flightsql")]
122+
if let Some(Command::FlightSql { command }) = &self.args.command {
123+
return self.handle_flightsql_command(command.clone()).await;
124+
};
125+
95126
#[cfg(not(feature = "flightsql"))]
96127
match (
97128
self.args.files.is_empty(),
@@ -603,7 +634,7 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> {
603634
let mut app_execution = AppExecution::new(execution_ctx);
604635
#[cfg(feature = "flightsql")]
605636
{
606-
if cli.flightsql {
637+
if cli.flightsql || matches!(cli.command, Some(Command::FlightSql { .. })) {
607638
let auth = AuthConfig {
608639
basic_auth: config.flightsql_client.auth.basic_auth,
609640
bearer_token: config.flightsql_client.auth.bearer_token,

src/main.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,16 @@ async fn app_entry_point(cli: DftArgs) -> Result<()> {
6767
}
6868

6969
#[cfg(feature = "flightsql")]
70-
if let Some(Command::ServeFlightSql { .. }) = cli.command {
71-
server::flightsql::try_run(cli.clone(), cfg.clone()).await?;
72-
return Ok(());
70+
{
71+
if matches!(cli.command, Some(Command::FlightSql { .. })) {
72+
cli::try_run(cli, cfg).await?;
73+
return Ok(());
74+
} else if let Some(Command::ServeFlightSql { .. }) = cli.command {
75+
server::flightsql::try_run(cli.clone(), cfg.clone()).await?;
76+
return Ok(());
77+
}
7378
}
79+
7480
#[cfg(feature = "http")]
7581
{
7682
if let Some(Command::ServeHttp { .. }) = cli.command {

src/server/flightsql/service.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
2020
use arrow_flight::error::FlightError;
2121
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
2222
use arrow_flight::sql::server::FlightSqlService;
23-
use arrow_flight::sql::{Any, CommandStatementQuery, SqlInfo, TicketStatementQuery};
23+
use arrow_flight::sql::{
24+
Any, CommandGetCatalogs, CommandStatementQuery, SqlInfo, TicketStatementQuery,
25+
};
2426
use arrow_flight::{FlightDescriptor, FlightEndpoint, FlightInfo, Ticket};
2527
use color_eyre::Result;
2628
use datafusion::logical_expr::LogicalPlan;
@@ -218,6 +220,31 @@ impl FlightSqlServiceImpl {
218220
impl FlightSqlService for FlightSqlServiceImpl {
219221
type FlightService = FlightSqlServiceImpl;
220222

223+
async fn get_flight_info_catalogs(
224+
&self,
225+
_query: CommandGetCatalogs,
226+
request: Request<FlightDescriptor>,
227+
) -> Result<Response<FlightInfo>, Status> {
228+
counter!("requests", "endpoint" => "get_flight_info").increment(1);
229+
let start = Timestamp::now();
230+
let request_id = uuid::Uuid::new_v4();
231+
let query = "SELECT DISTINCT table_catalog FROM information_schema.tables".to_string();
232+
let res = self
233+
.get_flight_info_statement_handler(query, request_id, request)
234+
.await;
235+
236+
// TODO: Move recording to after response is sent to not impact response latency
237+
self.record_request(
238+
start,
239+
Some(request_id.to_string()),
240+
res.as_ref().err(),
241+
"/get_flight_info_catalogs".to_string(),
242+
"get_flight_info_catalogs_latency_ms",
243+
)
244+
.await;
245+
res
246+
}
247+
221248
async fn get_flight_info_statement(
222249
&self,
223250
query: CommandStatementQuery,

tests/cli_cases/tpch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async fn test_custom_config_with_s3() {
7373
);
7474
let config = config_builder.build("my_config.toml");
7575

76-
let a = Command::cargo_bin("dft")
76+
Command::cargo_bin("dft")
7777
.unwrap()
7878
.arg("--config")
7979
.arg(config.path)

tests/extension_cases/flightsql.rs

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
use std::{io::Read, time::Duration};
1919

2020
use assert_cmd::Command;
21-
use datafusion_dft::test_utils::fixture::{TestFixture, TestFlightSqlServiceImpl};
21+
use datafusion_app::local::ExecutionContext;
22+
use datafusion_dft::{
23+
execution::AppExecution,
24+
server::flightsql::service::FlightSqlServiceImpl,
25+
test_utils::fixture::{TestFixture, TestFlightSqlServiceImpl},
26+
};
2227

2328
use crate::{
2429
cli_cases::{contains_str, sql_in_file},
@@ -603,24 +608,34 @@ async fn test_output_parquet() {
603608

604609
let cloned_path = path.clone();
605610

606-
let sql = "SELECT 1".to_string();
607-
Command::cargo_bin("dft")
608-
.unwrap()
609-
.arg("-c")
610-
.arg(sql.clone())
611-
.arg("-o")
612-
.arg(cloned_path)
613-
.assert()
614-
.success();
611+
tokio::task::spawn_blocking(|| {
612+
let sql = "SELECT 1".to_string();
613+
Command::cargo_bin("dft")
614+
.unwrap()
615+
.arg("-c")
616+
.arg(sql.clone())
617+
.arg("--flightsql")
618+
.arg("-o")
619+
.arg(cloned_path)
620+
.timeout(Duration::from_secs(5))
621+
.assert()
622+
.success();
623+
})
624+
.await
625+
.unwrap();
615626

616627
let read_sql = format!("SELECT * FROM '{}'", path.to_str().unwrap());
617628

618-
let assert = Command::cargo_bin("dft")
619-
.unwrap()
620-
.arg("-c")
621-
.arg(read_sql)
622-
.assert()
623-
.success();
629+
let assert = tokio::task::spawn_blocking(|| {
630+
Command::cargo_bin("dft")
631+
.unwrap()
632+
.arg("-c")
633+
.arg(read_sql)
634+
.assert()
635+
.success()
636+
})
637+
.await
638+
.unwrap();
624639

625640
let expected = r#"
626641
+----------+
@@ -633,3 +648,65 @@ async fn test_output_parquet() {
633648

634649
fixture.shutdown_and_wait().await;
635650
}
651+
652+
#[tokio::test]
653+
async fn test_flightsql_query_command() {
654+
let test_server = TestFlightSqlServiceImpl::new();
655+
let fixture = TestFixture::new(test_server.service(), "127.0.0.1:50051").await;
656+
657+
let assert = tokio::task::spawn_blocking(|| {
658+
let sql = "SELECT 1".to_string();
659+
Command::cargo_bin("dft")
660+
.unwrap()
661+
.arg("flightsql")
662+
.arg("statement-query")
663+
.arg(sql.clone())
664+
.timeout(Duration::from_secs(5))
665+
.assert()
666+
.success()
667+
})
668+
.await
669+
.unwrap();
670+
671+
let expected = r#"
672+
+----------+
673+
| Int64(1) |
674+
+----------+
675+
| 1 |
676+
+----------+"#;
677+
678+
assert.stdout(contains_str(expected));
679+
680+
fixture.shutdown_and_wait().await;
681+
}
682+
683+
#[tokio::test]
684+
async fn test_flightsql_get_catalogs() {
685+
let ctx = ExecutionContext::default();
686+
let exec = AppExecution::new(ctx);
687+
let test_server = FlightSqlServiceImpl::new(exec);
688+
let fixture = TestFixture::new(test_server.service(), "127.0.0.1:50051").await;
689+
690+
let assert = tokio::task::spawn_blocking(|| {
691+
Command::cargo_bin("dft")
692+
.unwrap()
693+
.arg("flightsql")
694+
.arg("get-catalogs")
695+
.timeout(Duration::from_secs(5))
696+
.assert()
697+
.success()
698+
})
699+
.await
700+
.unwrap();
701+
702+
let expected = r#"
703+
+---------------+
704+
| table_catalog |
705+
+---------------+
706+
| datafusion |
707+
+---------------+"#;
708+
709+
assert.stdout(contains_str(expected));
710+
711+
fixture.shutdown_and_wait().await;
712+
}

0 commit comments

Comments
 (0)