From 65e81bbe24ff200c295e444597ba1ec62f9a143f Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Mon, 12 May 2025 08:13:51 -0400 Subject: [PATCH 01/16] Start adding prepared statements --- crates/datafusion-app/src/flightsql.rs | 23 +++++++++- src/args.rs | 7 ++- src/cli/mod.rs | 4 +- src/server/flightsql/service.rs | 63 ++++++++++++++++++++++++-- src/server/http/mod.rs | 8 ++-- 5 files changed, 95 insertions(+), 10 deletions(-) diff --git a/crates/datafusion-app/src/flightsql.rs b/crates/datafusion-app/src/flightsql.rs index 95271df..d17de13 100644 --- a/crates/datafusion-app/src/flightsql.rs +++ b/crates/datafusion-app/src/flightsql.rs @@ -19,7 +19,10 @@ use std::sync::Arc; use arrow_flight::{ decode::FlightRecordBatchStream, - sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables}, + sql::{ + client::{FlightSqlServiceClient, PreparedStatement}, + ActionCreatePreparedStatementRequest, CommandGetDbSchemas, CommandGetTables, + }, FlightInfo, }; #[cfg(feature = "flightsql")] @@ -275,6 +278,24 @@ impl FlightSQLContext { } } + pub async fn create_prepared_statement( + &self, + query: String, + ) -> DFResult> { + let client = Arc::clone(&self.client); + let mut guard = client.lock().await; + if let Some(client) = guard.as_mut() { + client + .prepare(query, None) + .await + .map_err(|e| DataFusionError::ArrowError(e, None)) + } else { + Err(DataFusionError::External( + "No FlightSQL client configured. Add one in `~/.config/dft/config.toml`".into(), + )) + } + } + pub async fn do_get(&self, flight_info: FlightInfo) -> DFResult> { let client = Arc::clone(&self.client); let mut guard = client.lock().await; diff --git a/src/args.rs b/src/args.rs index 5a0775f..fd27dab 100644 --- a/src/args.rs +++ b/src/args.rs @@ -134,7 +134,7 @@ pub enum FlightSqlCommand { StatementQuery { /// The query to execute #[clap(long)] - sql: String, + query: String, }, /// Executes `CommandGetCatalogs` and `DoGet` to return results GetCatalogs, @@ -162,6 +162,11 @@ pub enum FlightSqlCommand { #[clap(long)] table_types: Option>, }, + CreatePreparedStatement { + /// The query for the prepared statement + #[clap(long)] + query: String, + }, } #[derive(Clone, Debug, Subcommand)] diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 01e549f..dc0369b 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -88,7 +88,7 @@ impl CliApp { use futures::stream; match command { - FlightSqlCommand::StatementQuery { sql } => self.exec_from_flightsql(sql, 0).await, + FlightSqlCommand::StatementQuery { query } => self.exec_from_flightsql(query, 0).await, FlightSqlCommand::GetCatalogs => { let flight_info = self .app_execution @@ -149,6 +149,8 @@ impl CliApp { self.print_any_stream(flight_batch_stream).await; Ok(()) } + + FlightSqlCommand::CreatePreparedStatement { query } => Ok(()), } } diff --git a/src/server/flightsql/service.rs b/src/server/flightsql/service.rs index 3e638fc..ce5437f 100644 --- a/src/server/flightsql/service.rs +++ b/src/server/flightsql/service.rs @@ -21,12 +21,17 @@ use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; use arrow_flight::sql::server::FlightSqlService; use arrow_flight::sql::{ - Any, CommandGetCatalogs, CommandGetDbSchemas, CommandGetTables, CommandStatementQuery, SqlInfo, + ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, + CommandGetCatalogs, CommandGetDbSchemas, CommandGetTables, CommandStatementQuery, SqlInfo, TicketStatementQuery, }; -use arrow_flight::{FlightDescriptor, FlightEndpoint, FlightInfo, Ticket}; +use arrow_flight::{ + Action, FlightDescriptor, FlightEndpoint, FlightInfo, IpcMessage, SchemaAsIpc, Ticket, +}; use color_eyre::Result; -use datafusion::logical_expr::LogicalPlan; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::arrow::ipc::writer::IpcWriteOptions; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, Prepare}; use datafusion::prelude::{col, lit}; use datafusion::sql::parser::DFParser; use datafusion_app::local::ExecutionContext; @@ -205,6 +210,44 @@ impl FlightSqlServiceImpl { } } + async fn query_to_create_prepared_statement_action( + &self, + query: String, + ) -> Result { + let ctx = self.execution.session_ctx(); + let state = ctx.state(); + let logical_plan = state.create_logical_plan(&query).await?; + let query_schema = logical_plan.schema().as_arrow().clone(); + let (parameter_data_types, parameter_fields): (Vec, Vec) = logical_plan + .get_parameter_types()? + .into_iter() + .filter_map(|(k, v)| v.map(|v| (v.clone(), Field::new(k, v, false)))) + .collect(); + let parameters_schema = Schema::new(parameter_fields); + + let id = Uuid::new_v4().to_string(); + let builder = LogicalPlanBuilder::new(logical_plan); + let prepared = builder + .prepare(id.clone(), parameter_data_types) + .map_err(|e| Status::internal(e.to_string()))? + .build()?; + ctx.execute_logical_plan(prepared).await?.collect().await?; + debug!("created prepared statement with id {id}"); + let opts = IpcWriteOptions::default(); + let dataset_schema_as_ipc = SchemaAsIpc::new(&query_schema, &opts); + let IpcMessage(dataset_bytes) = IpcMessage::try_from(dataset_schema_as_ipc) + .map_err(|e| Status::internal(e.to_string()))?; + let parameters_schema_as_ipc = SchemaAsIpc::new(¶meters_schema, &opts); + let IpcMessage(parameters_bytes) = IpcMessage::try_from(parameters_schema_as_ipc)?; + debug!("serialized prepared statement"); + let res = ActionCreatePreparedStatementResult { + prepared_statement_handle: id.into_bytes().into(), + dataset_schema: dataset_bytes, + parameter_schema: parameters_bytes, + }; + Ok(res) + } + async fn do_get_statement_handler( &self, request_id: String, @@ -378,6 +421,20 @@ impl FlightSqlService for FlightSqlServiceImpl { res } + async fn do_action_create_prepared_statement( + &self, + query: ActionCreatePreparedStatementRequest, + _request: Request, + ) -> Result { + counter!("requests", "endpoint" => "do_action_create_prepared_statement").increment(1); + let ActionCreatePreparedStatementRequest { query, .. } = query; + let res = self + .query_to_create_prepared_statement_action(query) + .await + .map_err(|e| Status::internal(e.to_string()))?; + Ok(res) + } + async fn do_get_statement( &self, ticket: TicketStatementQuery, diff --git a/src/server/http/mod.rs b/src/server/http/mod.rs index 124c049..4d5ea59 100644 --- a/src/server/http/mod.rs +++ b/src/server/http/mod.rs @@ -103,10 +103,10 @@ impl HttpApp { .await { Ok(_) => { - info!("Shutting down app") + info!("shutting down app") } Err(_) => { - panic!("Error serving HTTP app") + panic!("error serving HTTP app") } } } @@ -133,7 +133,7 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { let mut app_execution = AppExecution::new(execution_ctx); #[cfg(feature = "flightsql")] { - info!("Setting up FlightSQLContext"); + info!("setting up FlightSQLContext"); let auth = AuthConfig { basic_auth: config.flightsql_client.auth.basic_auth.clone(), bearer_token: config.flightsql_client.auth.bearer_token.clone(), @@ -155,7 +155,7 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { app_execution.with_flightsql_ctx(flightsql_context); } } - debug!("Created AppExecution: {app_execution:?}"); + debug!("created AppExecution: {app_execution:?}"); let (addr, metrics_addr) = if let Some(cmd) = cli.command.clone() { match cmd { Command::ServeHttp { From b78070dcb6917543f07a96170ce37b78ea8572fe Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Mon, 12 May 2025 09:23:22 -0400 Subject: [PATCH 02/16] A little more cleanup --- src/cli/mod.rs | 13 ++++++++++++- src/server/flightsql/service.rs | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/cli/mod.rs b/src/cli/mod.rs index dc0369b..7641907 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -150,7 +150,18 @@ impl CliApp { Ok(()) } - FlightSqlCommand::CreatePreparedStatement { query } => Ok(()), + FlightSqlCommand::CreatePreparedStatement { query } => { + let prepared = self + .app_execution + .flightsql_ctx() + .create_prepared_statement(query) + .await?; + let dataset_schema = prepared.dataset_schema()?; + let parameter_schema = prepared.parameter_schema()?; + println!("Created prepared statement with schema:\n{dataset_schema:?}"); + println!("Parameters:\n{parameter_schema:?}"); + Ok(()) + } } } diff --git a/src/server/flightsql/service.rs b/src/server/flightsql/service.rs index ce5437f..3b04b37 100644 --- a/src/server/flightsql/service.rs +++ b/src/server/flightsql/service.rs @@ -31,7 +31,7 @@ use arrow_flight::{ use color_eyre::Result; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::ipc::writer::IpcWriteOptions; -use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, Prepare}; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; use datafusion::prelude::{col, lit}; use datafusion::sql::parser::DFParser; use datafusion_app::local::ExecutionContext; From 3501b89fae4e243091a72b47e2ab5058b5350090 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Tue, 13 May 2025 10:27:34 -0400 Subject: [PATCH 03/16] Create working better --- Cargo.lock | 1 + crates/datafusion-app/Cargo.toml | 1 + crates/datafusion-app/src/flightsql.rs | 41 ++++++++++++++++++++------ crates/datafusion-app/src/local.rs | 2 ++ src/args.rs | 1 + src/cli/mod.rs | 10 +++---- src/server/flightsql/service.rs | 13 ++++++-- 7 files changed, 52 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9054638..ff95c90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1477,6 +1477,7 @@ dependencies = [ "object_store_opendal", "opendal", "parking_lot", + "prost", "serde", "tokio", "tokio-stream", diff --git a/crates/datafusion-app/Cargo.toml b/crates/datafusion-app/Cargo.toml index ebeb8df..409f2cb 100644 --- a/crates/datafusion-app/Cargo.toml +++ b/crates/datafusion-app/Cargo.toml @@ -31,6 +31,7 @@ opendal = { version = "0.51", features = [ "services-huggingface", ], optional = true } parking_lot = "0.12.3" +prost = "0.13.5" serde = { version = "1.0.197", features = ["derive"] } tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] } tokio-stream = { version = "0.1.15", features = ["net"] } diff --git a/crates/datafusion-app/src/flightsql.rs b/crates/datafusion-app/src/flightsql.rs index d17de13..629aa39 100644 --- a/crates/datafusion-app/src/flightsql.rs +++ b/crates/datafusion-app/src/flightsql.rs @@ -20,21 +20,22 @@ use std::sync::Arc; use arrow_flight::{ decode::FlightRecordBatchStream, sql::{ - client::{FlightSqlServiceClient, PreparedStatement}, - ActionCreatePreparedStatementRequest, CommandGetDbSchemas, CommandGetTables, + client::FlightSqlServiceClient, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, Any, CommandGetDbSchemas, CommandGetTables, + ProstMessageExt, }, - FlightInfo, + Action, FlightInfo, }; #[cfg(feature = "flightsql")] use base64::engine::{general_purpose::STANDARD, Engine as _}; +use color_eyre::eyre::{self, Result}; use datafusion::{ error::{DataFusionError, Result as DFResult}, physical_plan::stream::RecordBatchStreamAdapter, sql::parser::DFParser, }; use log::{debug, error, info, warn}; - -use color_eyre::eyre::{self, Result}; +use prost::Message; use tokio::sync::Mutex; use tokio_stream::StreamExt; use tonic::{transport::Channel, IntoRequest}; @@ -46,6 +47,9 @@ use crate::{ config::FlightSQLConfig, flightsql_benchmarks::FlightSQLBenchmarkStats, ExecOptions, ExecResult, }; +pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement"; +pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement"; + pub type FlightSQLClient = Arc>>>; #[derive(Clone, Debug, Default)] @@ -281,14 +285,33 @@ impl FlightSQLContext { pub async fn create_prepared_statement( &self, query: String, - ) -> DFResult> { + ) -> DFResult { let client = Arc::clone(&self.client); let mut guard = client.lock().await; if let Some(client) = guard.as_mut() { - client - .prepare(query, None) + let cmd = ActionCreatePreparedStatementRequest { + query, + transaction_id: None, + }; + let action = Action { + r#type: CREATE_PREPARED_STATEMENT.to_string(), + body: cmd.as_any().encode_to_vec().into(), + }; + let mut result = client + .inner_mut() + .do_action(action) .await - .map_err(|e| DataFusionError::ArrowError(e, None)) + .map_err(|e| DataFusionError::External(e.to_string().into()))? + .into_inner(); + let result = result + .message() + .await + .map_err(|e| DataFusionError::External(e.to_string().into()))? + .unwrap(); + let any = Any::decode(&*result.body) + .map_err(|e| DataFusionError::External(e.to_string().into()))?; + let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap(); + Ok(prepared_result) } else { Err(DataFusionError::External( "No FlightSQL client configured. Add one in `~/.config/dft/config.toml`".into(), diff --git a/crates/datafusion-app/src/local.rs b/crates/datafusion-app/src/local.rs index 0962333..855f47b 100644 --- a/crates/datafusion-app/src/local.rs +++ b/crates/datafusion-app/src/local.rs @@ -70,6 +70,8 @@ pub struct ExecutionContext { /// Observability handlers #[cfg(feature = "observability")] observability: ObservabilityContext, + /// Map of prepared statements where the key is the id of the prepared statement and the value + /// is [`datafusion::logical_expr::Prepare`] that can be reused. } impl std::fmt::Debug for ExecutionContext { diff --git a/src/args.rs b/src/args.rs index fd27dab..3bcd492 100644 --- a/src/args.rs +++ b/src/args.rs @@ -162,6 +162,7 @@ pub enum FlightSqlCommand { #[clap(long)] table_types: Option>, }, + /// Creates a prepared statement on the server CreatePreparedStatement { /// The query for the prepared statement #[clap(long)] diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 7641907..93a1f4e 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -85,6 +85,8 @@ impl CliApp { #[cfg(feature = "flightsql")] async fn handle_flightsql_command(&self, command: FlightSqlCommand) -> color_eyre::Result<()> { + use arrow_flight::IpcMessage; + use datafusion::arrow::datatypes::Schema; use futures::stream; match command { @@ -151,15 +153,13 @@ impl CliApp { } FlightSqlCommand::CreatePreparedStatement { query } => { - let prepared = self + let prepared_result = self .app_execution .flightsql_ctx() .create_prepared_statement(query) .await?; - let dataset_schema = prepared.dataset_schema()?; - let parameter_schema = prepared.parameter_schema()?; - println!("Created prepared statement with schema:\n{dataset_schema:?}"); - println!("Parameters:\n{parameter_schema:?}"); + let handle = String::from_utf8(prepared_result.prepared_statement_handle.to_vec())?; + println!("Created prepared statement: {handle}"); Ok(()) } } diff --git a/src/server/flightsql/service.rs b/src/server/flightsql/service.rs index 3b04b37..75d1bf9 100644 --- a/src/server/flightsql/service.rs +++ b/src/server/flightsql/service.rs @@ -19,11 +19,11 @@ use crate::execution::AppExecution; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; -use arrow_flight::sql::server::FlightSqlService; +use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; use arrow_flight::sql::{ ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, - CommandGetCatalogs, CommandGetDbSchemas, CommandGetTables, CommandStatementQuery, SqlInfo, - TicketStatementQuery, + CommandGetCatalogs, CommandGetDbSchemas, CommandGetTables, CommandPreparedStatementQuery, + CommandStatementQuery, DoPutPreparedStatementResult, SqlInfo, TicketStatementQuery, }; use arrow_flight::{ Action, FlightDescriptor, FlightEndpoint, FlightInfo, IpcMessage, SchemaAsIpc, Ticket, @@ -435,6 +435,13 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(res) } + async fn do_put_prepared_statement_query( + &self, + query: CommandPreparedStatementQuery, + _request: Request, + ) -> Result { + } + async fn do_get_statement( &self, ticket: TicketStatementQuery, From 4085c41eea37b0fc621ac5e9fa584c8760fa7ffd Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Tue, 13 May 2025 12:09:51 -0400 Subject: [PATCH 04/16] Fix --- crates/datafusion-app/src/local.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/datafusion-app/src/local.rs b/crates/datafusion-app/src/local.rs index 855f47b..d79851d 100644 --- a/crates/datafusion-app/src/local.rs +++ b/crates/datafusion-app/src/local.rs @@ -17,12 +17,13 @@ //! [`ExecutionContext`]: DataFusion based execution context for running SQL queries +use std::collections::HashMap; use std::io::Write; use std::path::PathBuf; use std::sync::Arc; use color_eyre::eyre::eyre; -use datafusion::logical_expr::LogicalPlan; +use datafusion::logical_expr::{LogicalPlan, Prepare}; use futures::TryFutureExt; use log::{debug, error, info}; @@ -72,6 +73,7 @@ pub struct ExecutionContext { observability: ObservabilityContext, /// Map of prepared statements where the key is the id of the prepared statement and the value /// is [`datafusion::logical_expr::Prepare`] that can be reused. + prepared_statements: HashMap, } impl std::fmt::Debug for ExecutionContext { From 39d7eb73fe90a60d89be971db0c4d0b235f5c24c Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Wed, 14 May 2025 08:16:57 -0400 Subject: [PATCH 05/16] Go --- crates/datafusion-app/src/local.rs | 3 +++ src/server/flightsql/service.rs | 14 ++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/crates/datafusion-app/src/local.rs b/crates/datafusion-app/src/local.rs index d79851d..ac447de 100644 --- a/crates/datafusion-app/src/local.rs +++ b/crates/datafusion-app/src/local.rs @@ -151,6 +151,7 @@ impl ExecutionContext { ddl_path: config.ddl_path.as_ref().map(PathBuf::from), executor, observability, + prepared_statements: HashMap::new(), } } #[cfg(not(feature = "observability"))] @@ -160,6 +161,7 @@ impl ExecutionContext { session_ctx, ddl_path: config.ddl_path.as_ref().map(PathBuf::from), executor, + prepared_statements: HashMap::new(), } } }; @@ -185,6 +187,7 @@ impl ExecutionContext { executor: None, #[cfg(feature = "observability")] observability, + prepared_statements: HashMap::new(), } } diff --git a/src/server/flightsql/service.rs b/src/server/flightsql/service.rs index 75d1bf9..0b750c7 100644 --- a/src/server/flightsql/service.rs +++ b/src/server/flightsql/service.rs @@ -426,6 +426,8 @@ impl FlightSqlService for FlightSqlServiceImpl { query: ActionCreatePreparedStatementRequest, _request: Request, ) -> Result { + // TODO: Add metadata option to request header which allows stateless option where all + // information required to execute query is passed back to client counter!("requests", "endpoint" => "do_action_create_prepared_statement").increment(1); let ActionCreatePreparedStatementRequest { query, .. } = query; let res = self @@ -435,12 +437,12 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(res) } - async fn do_put_prepared_statement_query( - &self, - query: CommandPreparedStatementQuery, - _request: Request, - ) -> Result { - } + // async fn do_put_prepared_statement_query( + // &self, + // query: CommandPreparedStatementQuery, + // _request: Request, + // ) -> Result { + // } async fn do_get_statement( &self, From 2801a43782477fd28b855d2ad0a6aed8c504afc3 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Wed, 14 May 2025 08:35:36 -0400 Subject: [PATCH 06/16] A little cleanup --- tests/extension_cases/flightsql.rs | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/extension_cases/flightsql.rs b/tests/extension_cases/flightsql.rs index e8f26de..3e5aa38 100644 --- a/tests/extension_cases/flightsql.rs +++ b/tests/extension_cases/flightsql.rs @@ -660,7 +660,7 @@ async fn test_query_command() { .unwrap() .arg("flightsql") .arg("statement-query") - .arg("--sql") + .arg("--query") .arg(sql.clone()) .timeout(Duration::from_secs(5)) .assert() @@ -1054,3 +1054,31 @@ async fn test_get_tables_table_type() { fixture.shutdown_and_wait().await; } + +#[tokio::test] +async fn test_create_prepared_statement() { + let ctx = ExecutionContext::test(); + let exec = AppExecution::new(ctx); + let test_server = FlightSqlServiceImpl::new(exec); + let fixture = TestFixture::new(test_server.service(), "127.0.0.1:50051").await; + + let assert = tokio::task::spawn_blocking(|| { + Command::cargo_bin("dft") + .unwrap() + .arg("flightsql") + .arg("create-prepared-statement") + .arg("--query") + .arg("SELECT 1") + .timeout(Duration::from_secs(5)) + .assert() + .success() + }) + .await + .unwrap(); + + let expected = r#"Created prepared statement"#; + + assert.stdout(contains_str(expected)); + + fixture.shutdown_and_wait().await; +} From e4af6f4a32ab2a18f3c5a62668832773b864c8dd Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Wed, 14 May 2025 09:33:09 -0400 Subject: [PATCH 07/16] Move prepared statement to app --- crates/datafusion-app/src/lib.rs | 1 + .../datafusion-app/src/prepared_statement.rs | 30 +++++++++++++++++++ src/cli/mod.rs | 7 +++-- src/server/flightsql/mod.rs | 1 - src/server/flightsql/service.rs | 7 ++++- 5 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 crates/datafusion-app/src/prepared_statement.rs diff --git a/crates/datafusion-app/src/lib.rs b/crates/datafusion-app/src/lib.rs index ba5635b..fdf3d30 100644 --- a/crates/datafusion-app/src/lib.rs +++ b/crates/datafusion-app/src/lib.rs @@ -27,6 +27,7 @@ pub mod local; pub mod local_benchmarks; #[cfg(feature = "observability")] pub mod observability; +pub mod prepared_statement; pub mod sql_utils; pub mod stats; pub mod tables; diff --git a/crates/datafusion-app/src/prepared_statement.rs b/crates/datafusion-app/src/prepared_statement.rs new file mode 100644 index 0000000..eadccfc --- /dev/null +++ b/crates/datafusion-app/src/prepared_statement.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use prost::Message; + +#[derive(Clone, Message)] +pub struct PreparedStatementHandle { + #[prost(string)] + pub id: String, +} + +impl PreparedStatementHandle { + pub fn new(id: String) -> Self { + Self { id } + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 93a1f4e..f9eef03 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -45,6 +45,7 @@ use { flightsql::FlightSQLContext, flightsql_benchmarks::FlightSQLBenchmarkStats, }, + prost::Message, tonic::IntoRequest, }; @@ -87,6 +88,7 @@ impl CliApp { async fn handle_flightsql_command(&self, command: FlightSqlCommand) -> color_eyre::Result<()> { use arrow_flight::IpcMessage; use datafusion::arrow::datatypes::Schema; + use datafusion_app::prepared_statement::PreparedStatementHandle; use futures::stream; match command { @@ -158,8 +160,9 @@ impl CliApp { .flightsql_ctx() .create_prepared_statement(query) .await?; - let handle = String::from_utf8(prepared_result.prepared_statement_handle.to_vec())?; - println!("Created prepared statement: {handle}"); + let handle = + PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; + println!("Created prepared statement: {}", handle.id); Ok(()) } } diff --git a/src/server/flightsql/mod.rs b/src/server/flightsql/mod.rs index f4c675b..10d1ea1 100644 --- a/src/server/flightsql/mod.rs +++ b/src/server/flightsql/mod.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - pub mod service; use crate::args::{Command, DftArgs}; diff --git a/src/server/flightsql/service.rs b/src/server/flightsql/service.rs index 0b750c7..cec3a79 100644 --- a/src/server/flightsql/service.rs +++ b/src/server/flightsql/service.rs @@ -36,6 +36,7 @@ use datafusion::prelude::{col, lit}; use datafusion::sql::parser::DFParser; use datafusion_app::local::ExecutionContext; use datafusion_app::observability::ObservabilityRequestDetails; +use datafusion_app::prepared_statement::PreparedStatementHandle; use futures::{StreamExt, TryStreamExt}; use jiff::Timestamp; use log::{debug, error, info}; @@ -239,9 +240,10 @@ impl FlightSqlServiceImpl { .map_err(|e| Status::internal(e.to_string()))?; let parameters_schema_as_ipc = SchemaAsIpc::new(¶meters_schema, &opts); let IpcMessage(parameters_bytes) = IpcMessage::try_from(parameters_schema_as_ipc)?; + let handle = PreparedStatementHandle::new(id); debug!("serialized prepared statement"); let res = ActionCreatePreparedStatementResult { - prepared_statement_handle: id.into_bytes().into(), + prepared_statement_handle: handle.encode_to_vec().into(), dataset_schema: dataset_bytes, parameter_schema: parameters_bytes, }; @@ -442,6 +444,9 @@ impl FlightSqlService for FlightSqlServiceImpl { // query: CommandPreparedStatementQuery, // _request: Request, // ) -> Result { + // let CommandPreparedStatementQuery { + // prepared_statement_handle, + // } = query; // } async fn do_get_statement( From d5ad84f5184e0ef0585d5ff65ad23224e5393139 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Wed, 14 May 2025 10:08:42 -0400 Subject: [PATCH 08/16] Add rustup component --- .github/actions/setup-rust/action.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/setup-rust/action.yml b/.github/actions/setup-rust/action.yml index 346c550..c3b0636 100644 --- a/.github/actions/setup-rust/action.yml +++ b/.github/actions/setup-rust/action.yml @@ -33,6 +33,7 @@ runs: rustup install stable rustup default stable rustup component add clippy + rustup component add rustfmt - name: Install Cargo tools shell: bash run: | From 31e05124ff9f7372f0e1cb12e7cada9fa0764140 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Thu, 15 May 2025 10:18:22 -0400 Subject: [PATCH 09/16] Add tests and start adding parameter binding --- crates/datafusion-app/src/local.rs | 21 ++++-- src/server/flightsql/service.rs | 105 ++++++++++++++++++++++++----- 2 files changed, 105 insertions(+), 21 deletions(-) diff --git a/crates/datafusion-app/src/local.rs b/crates/datafusion-app/src/local.rs index ac447de..2d0af90 100644 --- a/crates/datafusion-app/src/local.rs +++ b/crates/datafusion-app/src/local.rs @@ -26,6 +26,7 @@ use color_eyre::eyre::eyre; use datafusion::logical_expr::{LogicalPlan, Prepare}; use futures::TryFutureExt; use log::{debug, error, info}; +use parking_lot::RwLock; use crate::catalog::create_app_catalog; use crate::config::ExecutionConfig; @@ -72,8 +73,8 @@ pub struct ExecutionContext { #[cfg(feature = "observability")] observability: ObservabilityContext, /// Map of prepared statements where the key is the id of the prepared statement and the value - /// is [`datafusion::logical_expr::Prepare`] that can be reused. - prepared_statements: HashMap, + /// is [`datafusion::logical_expr::LogicalPlan`] that can be reused. + prepared_statements: Arc>>, } impl std::fmt::Debug for ExecutionContext { @@ -151,7 +152,7 @@ impl ExecutionContext { ddl_path: config.ddl_path.as_ref().map(PathBuf::from), executor, observability, - prepared_statements: HashMap::new(), + prepared_statements: Arc::new(RwLock::new(HashMap::new())), } } #[cfg(not(feature = "observability"))] @@ -161,7 +162,7 @@ impl ExecutionContext { session_ctx, ddl_path: config.ddl_path.as_ref().map(PathBuf::from), executor, - prepared_statements: HashMap::new(), + prepared_statements: Arc::new(RwLock::new(HashMap::new())), } } }; @@ -187,7 +188,7 @@ impl ExecutionContext { executor: None, #[cfg(feature = "observability")] observability, - prepared_statements: HashMap::new(), + prepared_statements: Arc::new(RwLock::new(HashMap::new())), } } @@ -209,6 +210,16 @@ impl ExecutionContext { &self.executor } + pub fn prepared_statements(&self) -> Arc>> { + Arc::clone(&self.prepared_statements) + } + + pub fn insert_prepared_statement(&self, id: String, logical_plan: LogicalPlan) { + let prepared_statements = Arc::clone(&self.prepared_statements); + let mut prepared_statements = prepared_statements.write(); + prepared_statements.insert(id, logical_plan); + } + /// Return the `ObservabilityCtx` #[cfg(feature = "observability")] pub fn observability(&self) -> &ObservabilityContext { diff --git a/src/server/flightsql/service.rs b/src/server/flightsql/service.rs index cec3a79..a00342d 100644 --- a/src/server/flightsql/service.rs +++ b/src/server/flightsql/service.rs @@ -16,6 +16,7 @@ // under the License. use crate::execution::AppExecution; +use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; @@ -29,6 +30,7 @@ use arrow_flight::{ Action, FlightDescriptor, FlightEndpoint, FlightInfo, IpcMessage, SchemaAsIpc, Ticket, }; use color_eyre::Result; +use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::ipc::writer::IpcWriteOptions; use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; @@ -70,6 +72,10 @@ impl FlightSqlServiceImpl { FlightServiceServer::new(self.clone()) } + pub fn execution_ctx(&self) -> &ExecutionContext { + &self.execution + } + async fn do_get_common_handler( &self, request_id: String, @@ -211,12 +217,13 @@ impl FlightSqlServiceImpl { } } - async fn query_to_create_prepared_statement_action( + pub(crate) async fn query_to_create_prepared_statement_action( &self, query: String, ) -> Result { let ctx = self.execution.session_ctx(); let state = ctx.state(); + // Create the prepared statement plan and extract parameters let logical_plan = state.create_logical_plan(&query).await?; let query_schema = logical_plan.schema().as_arrow().clone(); let (parameter_data_types, parameter_fields): (Vec, Vec) = logical_plan @@ -227,13 +234,17 @@ impl FlightSqlServiceImpl { let parameters_schema = Schema::new(parameter_fields); let id = Uuid::new_v4().to_string(); - let builder = LogicalPlanBuilder::new(logical_plan); - let prepared = builder - .prepare(id.clone(), parameter_data_types) - .map_err(|e| Status::internal(e.to_string()))? - .build()?; - ctx.execute_logical_plan(prepared).await?.collect().await?; + // Make sure the prepared statement builds + // let builder = LogicalPlanBuilder::new(logical_plan); + // let prepared = builder + // .prepare(id.clone(), parameter_data_types) + // .map_err(|e| Status::internal(e.to_string()))? + // .build()?; + // By default save it to `ExecutionContext` state for later use + self.execution + .insert_prepared_statement(id.clone(), logical_plan.clone()); debug!("created prepared statement with id {id}"); + // Serialize the prepared statement let opts = IpcWriteOptions::default(); let dataset_schema_as_ipc = SchemaAsIpc::new(&query_schema, &opts); let IpcMessage(dataset_bytes) = IpcMessage::try_from(dataset_schema_as_ipc) @@ -439,15 +450,34 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(res) } - // async fn do_put_prepared_statement_query( - // &self, - // query: CommandPreparedStatementQuery, - // _request: Request, - // ) -> Result { - // let CommandPreparedStatementQuery { - // prepared_statement_handle, - // } = query; - // } + async fn do_put_prepared_statement_query( + &self, + query: CommandPreparedStatementQuery, + request: Request, + ) -> Result { + let CommandPreparedStatementQuery { + prepared_statement_handle, + } = query; + let handle = PreparedStatementHandle::decode(prepared_statement_handle) + .map_err(|e| Status::internal(e.to_string()))?; + let stream = request.into_inner(); + // The first message contains the flight descriptor and the schema. + // Read the flight descriptor without discarding the schema: + let flight_descriptor: FlightDescriptor = stream + .peek() + .await + .cloned() + .transpose()? + .and_then(|data| data.flight_descriptor) + .expect("first message should contain flight descriptor"); + + // Pass the stream through a decoder + let batches: Vec = FlightRecordBatchStream::new_from_flight_data( + request.into_inner().map_err(|e| e.into()), + ) + .try_collect() + .await?; + } async fn do_get_statement( &self, @@ -515,3 +545,46 @@ fn try_request_id_from_request(request: Request) -> Result { )?; Ok(request_id) } + +#[cfg(test)] +mod tests { + use datafusion_app::local::ExecutionContext; + + use crate::execution::AppExecution; + + use super::FlightSqlServiceImpl; + + fn setup() -> FlightSqlServiceImpl { + let execution = ExecutionContext::test(); + let app_execution = AppExecution::new(execution); + FlightSqlServiceImpl::new(app_execution) + } + + #[tokio::test] + async fn test_create_prepared_statement_without_params() { + let flight = setup(); + let query = "SELECT NOW()".to_string(); + let res = flight + .query_to_create_prepared_statement_action(query) + .await; + assert!(res.is_ok()); + let execution = flight.execution_ctx(); + let prepared_statements = execution.prepared_statements(); + let prepared_statements = prepared_statements.read(); + assert!(prepared_statements.len() == 1); + } + + #[tokio::test] + async fn test_create_prepared_statement_with_param() { + let flight = setup(); + let query = "SELECT * FROM information_schema.tables WHERE table_type = $1"; + let res = flight + .query_to_create_prepared_statement_action(query.to_string()) + .await; + assert!(res.is_ok()); + let execution = flight.execution_ctx(); + let prepared_statements = execution.prepared_statements(); + let prepared_statements = prepared_statements.read(); + assert!(prepared_statements.len() == 1); + } +} From 26901f3eddaf38cb153c127110227f8a3dc41d3e Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Sat, 17 May 2025 11:20:06 -0400 Subject: [PATCH 10/16] More work on putting params --- .../datafusion-app/src/prepared_statement.rs | 11 +- src/args.rs | 11 +- src/cli/mod.rs | 14 ++- src/server/flightsql/service.rs | 103 +++++++++++++++--- 4 files changed, 120 insertions(+), 19 deletions(-) diff --git a/crates/datafusion-app/src/prepared_statement.rs b/crates/datafusion-app/src/prepared_statement.rs index eadccfc..17caa5f 100644 --- a/crates/datafusion-app/src/prepared_statement.rs +++ b/crates/datafusion-app/src/prepared_statement.rs @@ -20,11 +20,16 @@ use prost::Message; #[derive(Clone, Message)] pub struct PreparedStatementHandle { #[prost(string)] - pub id: String, + pub prepared_id: String, + #[prost(string)] + pub request_id: String, } impl PreparedStatementHandle { - pub fn new(id: String) -> Self { - Self { id } + pub fn new(prepared_id: String, request_id: String) -> Self { + Self { + prepared_id, + request_id, + } } } diff --git a/src/args.rs b/src/args.rs index 3bcd492..f58f69b 100644 --- a/src/args.rs +++ b/src/args.rs @@ -19,9 +19,13 @@ use crate::config::get_data_dir; use clap::{Parser, Subcommand}; +use datafusion::{arrow::datatypes::DataType, common::ParamValues, scalar::ScalarValue}; #[cfg(any(feature = "http", feature = "flightsql"))] use std::net::SocketAddr; -use std::path::{Path, PathBuf}; +use std::{ + collections::HashMap, + path::{Path, PathBuf}, +}; const LONG_ABOUT: &str = " dft - DataFusion TUI @@ -168,6 +172,11 @@ pub enum FlightSqlCommand { #[clap(long)] query: String, }, + /// Bind parameters to a prepared statement + DoPutPreparedStatementQuery { + /// The parameters to be bound + params_list: Vec, + }, } #[derive(Clone, Debug, Subcommand)] diff --git a/src/cli/mod.rs b/src/cli/mod.rs index f9eef03..c71c2a1 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -162,7 +162,19 @@ impl CliApp { .await?; let handle = PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; - println!("Created prepared statement: {}", handle.id); + println!("created prepared statement: {}", handle.prepared_id); + Ok(()) + } + + FlightSqlCommand::DoPutPreparedStatementQuery { query } => { + let prepared_result = self + .app_execution + .flightsql_ctx() + .create_prepared_statement(query) + .await?; + let handle = + PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; + println!("created prepared statement: {}", handle.prepared_id); Ok(()) } } diff --git a/src/server/flightsql/service.rs b/src/server/flightsql/service.rs index a00342d..1cf8442 100644 --- a/src/server/flightsql/service.rs +++ b/src/server/flightsql/service.rs @@ -29,12 +29,16 @@ use arrow_flight::sql::{ use arrow_flight::{ Action, FlightDescriptor, FlightEndpoint, FlightInfo, IpcMessage, SchemaAsIpc, Ticket, }; +use color_eyre::eyre::ensure; use color_eyre::Result; use datafusion::arrow::array::RecordBatch; +use datafusion::arrow::compute::concat_batches; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::ipc::writer::IpcWriteOptions; +use datafusion::common::ParamValues; use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; use datafusion::prelude::{col, lit}; +use datafusion::scalar::ScalarValue; use datafusion::sql::parser::DFParser; use datafusion_app::local::ExecutionContext; use datafusion_app::observability::ObservabilityRequestDetails; @@ -52,7 +56,7 @@ use uuid::Uuid; #[derive(Clone)] pub struct FlightSqlServiceImpl { - requests: Arc>>, + requests: Arc>>, execution: ExecutionContext, } @@ -89,7 +93,7 @@ impl FlightSqlServiceImpl { .requests .lock() .map_err(|_| Status::internal("Failed to acquire lock on requests"))?; - guard.get(&id).cloned() + guard.get(&id.to_string()).cloned() }; if let Some(plan) = maybe_plan { let stream = self @@ -175,7 +179,7 @@ impl FlightSqlServiceImpl { .requests .lock() .map_err(|_| Status::internal("failed to acquire lock on requests"))?; - guard.insert(request_id, logical_plan); + guard.insert(request_id.to_string(), logical_plan); Ok(Response::new(info)) } else { @@ -233,7 +237,11 @@ impl FlightSqlServiceImpl { .collect(); let parameters_schema = Schema::new(parameter_fields); - let id = Uuid::new_v4().to_string(); + // A single prepared statement can be reused many times so it gets its own id which can be + // used to invoke it + let prepared_id = Uuid::new_v4().to_string(); + // Id specific to the request + let request_id = Uuid::new_v4().to_string(); // Make sure the prepared statement builds // let builder = LogicalPlanBuilder::new(logical_plan); // let prepared = builder @@ -242,8 +250,8 @@ impl FlightSqlServiceImpl { // .build()?; // By default save it to `ExecutionContext` state for later use self.execution - .insert_prepared_statement(id.clone(), logical_plan.clone()); - debug!("created prepared statement with id {id}"); + .insert_prepared_statement(prepared_id.clone(), logical_plan.clone()); + debug!("created prepared statement with id {prepared_id}"); // Serialize the prepared statement let opts = IpcWriteOptions::default(); let dataset_schema_as_ipc = SchemaAsIpc::new(&query_schema, &opts); @@ -251,7 +259,7 @@ impl FlightSqlServiceImpl { .map_err(|e| Status::internal(e.to_string()))?; let parameters_schema_as_ipc = SchemaAsIpc::new(¶meters_schema, &opts); let IpcMessage(parameters_bytes) = IpcMessage::try_from(parameters_schema_as_ipc)?; - let handle = PreparedStatementHandle::new(id); + let handle = PreparedStatementHandle::new(prepared_id, request_id); debug!("serialized prepared statement"); let res = ActionCreatePreparedStatementResult { prepared_statement_handle: handle.encode_to_vec().into(), @@ -460,10 +468,10 @@ impl FlightSqlService for FlightSqlServiceImpl { } = query; let handle = PreparedStatementHandle::decode(prepared_statement_handle) .map_err(|e| Status::internal(e.to_string()))?; - let stream = request.into_inner(); + let mut stream = request.into_inner(); // The first message contains the flight descriptor and the schema. // Read the flight descriptor without discarding the schema: - let flight_descriptor: FlightDescriptor = stream + stream .peek() .await .cloned() @@ -472,11 +480,30 @@ impl FlightSqlService for FlightSqlServiceImpl { .expect("first message should contain flight descriptor"); // Pass the stream through a decoder - let batches: Vec = FlightRecordBatchStream::new_from_flight_data( - request.into_inner().map_err(|e| e.into()), - ) - .try_collect() - .await?; + let batches: Vec = + FlightRecordBatchStream::new_from_flight_data(stream.map_err(|e| e.into())) + .try_collect() + .await?; + let params = params_from_batches(batches)?; + let plans = &self.execution.prepared_statements(); + let guard = plans.read(); + if let Some(plan) = guard.get(&handle.prepared_id) { + // We make a copy of the prepared statement that we can use for the request + let plan_with_params = plan.clone().with_param_values(params)?; + let requests = Arc::clone(&self.requests); + let mut guard = requests.lock().unwrap(); + guard.insert(handle.request_id, plan_with_params); + let res = DoPutPreparedStatementResult { + prepared_statement_handle: None, + }; + Ok(res) + } else { + Err(Status::internal(format!( + "no prepared statement for id {}", + handle.prepared_id + )) + .into()) + } } async fn do_get_statement( @@ -546,6 +573,54 @@ fn try_request_id_from_request(request: Request) -> Result { Ok(request_id) } +fn params_from_batches(batches: Vec) -> Result { + ensure!( + batches.len() >= 1, + "at least one record batch of params must be provided" + ); + let concatted = concat_batches(&batches[0].schema(), &batches)?; + record_batch_to_param_values(&concatted) +} + +// Copied from https://github.com/datafusion-contrib/datafusion-flight-sql-server/blob/4b6dcea299b7afe0eee3e7e46b6c83f513afb990/datafusion-flight-sql-server/src/service.rs#L1113C1-L1149C2 +// Converts a record batch with a single row into ParamValues +fn record_batch_to_param_values(batch: &RecordBatch) -> Result { + let mut param_values: Vec<(String, Option, ScalarValue)> = Vec::new(); + + let mut is_list = true; + for col_index in 0..batch.num_columns() { + let array = batch.column(col_index); + let scalar = ScalarValue::try_from_array(array, 0)?; + let name = batch + .schema_ref() + .field(col_index) + .name() + .trim_start_matches('$') + .to_string(); + let index = name.parse().ok(); + is_list &= index.is_some(); + param_values.push((name, index, scalar)); + } + if is_list { + let mut values: Vec<(Option, ScalarValue)> = param_values + .into_iter() + .map(|(_name, index, value)| (index, value)) + .collect(); + values.sort_by_key(|(index, _value)| *index); + Ok(values + .into_iter() + .map(|(_index, value)| value) + .collect::>() + .into()) + } else { + Ok(param_values + .into_iter() + .map(|(name, _index, value)| (name, value)) + .collect::>() + .into()) + } +} + #[cfg(test)] mod tests { use datafusion_app::local::ExecutionContext; From 64abbc0a2da8f903672e287a908eeeb5d7e0f0bb Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Mon, 19 May 2025 10:11:44 -0400 Subject: [PATCH 11/16] Binding params --- crates/datafusion-app/src/flightsql.rs | 45 +++++++++++++++++++++++--- src/args.rs | 2 ++ src/cli/mod.rs | 23 +++++++++---- 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/crates/datafusion-app/src/flightsql.rs b/crates/datafusion-app/src/flightsql.rs index 629aa39..31ad1ba 100644 --- a/crates/datafusion-app/src/flightsql.rs +++ b/crates/datafusion-app/src/flightsql.rs @@ -19,28 +19,30 @@ use std::sync::Arc; use arrow_flight::{ decode::FlightRecordBatchStream, + encode::FlightDataEncoderBuilder, sql::{ client::FlightSqlServiceClient, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, CommandGetDbSchemas, CommandGetTables, - ProstMessageExt, + CommandPreparedStatementQuery, ProstMessageExt, }, - Action, FlightInfo, + Action, FlightDescriptor, FlightInfo, PutResult, }; -#[cfg(feature = "flightsql")] use base64::engine::{general_purpose::STANDARD, Engine as _}; use color_eyre::eyre::{self, Result}; use datafusion::{ + arrow::array::RecordBatch, + common::ParamValues, error::{DataFusionError, Result as DFResult}, physical_plan::stream::RecordBatchStreamAdapter, sql::parser::DFParser, }; +use futures::{stream, TryStreamExt}; use log::{debug, error, info, warn}; use prost::Message; use tokio::sync::Mutex; use tokio_stream::StreamExt; use tonic::{transport::Channel, IntoRequest}; -#[cfg(feature = "flightsql")] use crate::config::BasicAuth; use crate::{ @@ -319,6 +321,41 @@ impl FlightSQLContext { } } + pub async fn bind_prepared_statement_params( + &self, + prepared_id: String, + params: RecordBatch, + ) -> DFResult> { + let client = Arc::clone(&self.client); + let mut guard = client.lock().await; + if let Some(client) = guard.as_mut() { + let cmd = CommandPreparedStatementQuery { + prepared_statement_handle: self.handle.clone(), + }; + + let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); + let flight_stream_builder = FlightDataEncoderBuilder::new() + .with_flight_descriptor(Some(descriptor)) + .with_schema(params.schema()); + let flight_data = flight_stream_builder + .build(futures::stream::iter(Ok(params))) + .try_collect::>() + .await + .map_err(|e| DataFusionError::External(e.to_string().into()))?; + + client + .do_put(stream::iter(flight_data)) + .await? + .message() + .await + .map_err(|e| DataFusionError::External(e.to_string().into())) + } else { + Err(DataFusionError::External( + "No FlightSQL client configured. Add one in `~/.config/dft/config.toml`".into(), + )) + } + } + pub async fn do_get(&self, flight_info: FlightInfo) -> DFResult> { let client = Arc::clone(&self.client); let mut guard = client.lock().await; diff --git a/src/args.rs b/src/args.rs index f58f69b..0ad3ffa 100644 --- a/src/args.rs +++ b/src/args.rs @@ -174,6 +174,8 @@ pub enum FlightSqlCommand { }, /// Bind parameters to a prepared statement DoPutPreparedStatementQuery { + /// The prepared statement id to bind parameters to + prepared_id: String, /// The parameters to be bound params_list: Vec, }, diff --git a/src/cli/mod.rs b/src/cli/mod.rs index c71c2a1..d759926 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -166,15 +166,24 @@ impl CliApp { Ok(()) } - FlightSqlCommand::DoPutPreparedStatementQuery { query } => { - let prepared_result = self + FlightSqlCommand::DoPutPreparedStatementQuery { + prepared_id, + params_list, + } => { + let params_batch = params_list_to_params_batch(); + let put_result = self .app_execution .flightsql_ctx() - .create_prepared_statement(query) - .await?; - let handle = - PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; - println!("created prepared statement: {}", handle.prepared_id); + .bind_prepared_statement_params(prepared_id, params_batch) + .await; + // let prepared_result = self + // .app_execution + // .flightsql_ctx() + // .create_prepared_statement(query) + // .await?; + // let handle = + // PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; + // println!("created prepared statement: {}", handle.prepared_id); Ok(()) } } From 6196220105997a115d401ccaed7fa058b4b6365b Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Tue, 20 May 2025 09:00:59 -0400 Subject: [PATCH 12/16] Start towards param binding --- crates/datafusion-app/src/catalog/mod.rs | 23 ++++++++++++++++++ crates/datafusion-app/src/flightsql.rs | 2 +- crates/datafusion-app/src/local.rs | 2 +- src/cli/mod.rs | 30 +++++++++++++++++------- src/execution.rs | 2 +- 5 files changed, 47 insertions(+), 12 deletions(-) diff --git a/crates/datafusion-app/src/catalog/mod.rs b/crates/datafusion-app/src/catalog/mod.rs index 45eec3d..8d8ccaa 100644 --- a/crates/datafusion-app/src/catalog/mod.rs +++ b/crates/datafusion-app/src/catalog/mod.rs @@ -41,9 +41,32 @@ pub fn create_app_catalog( catalog.register_schema("meta", Arc::::clone(&meta_schema))?; let versions_table = try_create_meta_versions_table(app_name, app_version)?; meta_schema.register_table("versions".to_string(), versions_table)?; + #[cfg(feature = "flightsql")] + { + let flightsql_schema = Arc::new(MemorySchemaProvider::new()); + } Ok(Arc::new(catalog)) } +// fn create_flightsql_prepared_statements_table() -> Result> { +// let fields = vec![ +// Field::new("datafusion", DataType::Utf8, false), +// Field::new("datafusion-app", DataType::Utf8, false), +// ]; +// let schema = Arc::new(Schema::new(fields)); +// +// let batches = RecordBatch::try_new( +// Arc::::clone(&schema), +// vec![ +// Arc::new(app_version_arr), +// Arc::new(datafusion_version_arr), +// Arc::new(datafusion_app_version_arr), +// ], +// )?; +// +// Ok(Arc::new(MemTable::try_new(schema, vec![vec![batches]])?)) +// } + fn try_create_meta_versions_table(app_name: &str, app_version: &str) -> Result> { let fields = vec![ Field::new(app_name, DataType::Utf8, false), diff --git a/crates/datafusion-app/src/flightsql.rs b/crates/datafusion-app/src/flightsql.rs index 31ad1ba..0494e98 100644 --- a/crates/datafusion-app/src/flightsql.rs +++ b/crates/datafusion-app/src/flightsql.rs @@ -338,7 +338,7 @@ impl FlightSQLContext { .with_flight_descriptor(Some(descriptor)) .with_schema(params.schema()); let flight_data = flight_stream_builder - .build(futures::stream::iter(Ok(params))) + .build(futures::stream::iter([Ok(params)])) .try_collect::>() .await .map_err(|e| DataFusionError::External(e.to_string().into()))?; diff --git a/crates/datafusion-app/src/local.rs b/crates/datafusion-app/src/local.rs index 2d0af90..a4cc6f4 100644 --- a/crates/datafusion-app/src/local.rs +++ b/crates/datafusion-app/src/local.rs @@ -23,7 +23,7 @@ use std::path::PathBuf; use std::sync::Arc; use color_eyre::eyre::eyre; -use datafusion::logical_expr::{LogicalPlan, Prepare}; +use datafusion::logical_expr::LogicalPlan; use futures::TryFutureExt; use log::{debug, error, info}; use parking_lot::RwLock; diff --git a/src/cli/mod.rs b/src/cli/mod.rs index d759926..dfb44fc 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -84,6 +84,18 @@ impl CliApp { Ok(()) } + async fn save_prepared_result( + &self, + create_prepared_result: ActionCreatePreparedStatementResult, + ) -> Result<()> { + let prepared_statements = self + .app_execution + .execution_ctx() + .session_ctx() + .table() + .await?; + } + #[cfg(feature = "flightsql")] async fn handle_flightsql_command(&self, command: FlightSqlCommand) -> color_eyre::Result<()> { use arrow_flight::IpcMessage; @@ -160,6 +172,7 @@ impl CliApp { .flightsql_ctx() .create_prepared_statement(query) .await?; + self.save_prepared_result(prepared_result); let handle = PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; println!("created prepared statement: {}", handle.prepared_id); @@ -170,20 +183,12 @@ impl CliApp { prepared_id, params_list, } => { - let params_batch = params_list_to_params_batch(); + let params_batch = params_list_to_params_batch(params_list); let put_result = self .app_execution .flightsql_ctx() .bind_prepared_statement_params(prepared_id, params_batch) .await; - // let prepared_result = self - // .app_execution - // .flightsql_ctx() - // .create_prepared_statement(query) - // .await?; - // let handle = - // PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; - // println!("created prepared statement: {}", handle.prepared_id); Ok(()) } } @@ -735,3 +740,10 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { app.execute_files_or_commands().await?; Ok(()) } + +fn params_list_to_params_batch(params: Vec) -> RecordBatch { + let fields = params.iter().enumerate().map(|(i, _s)| { + let name = format!("${}", i + 1); + Field::new() + }); +} diff --git a/src/execution.rs b/src/execution.rs index 708c483..f2662a3 100644 --- a/src/execution.rs +++ b/src/execution.rs @@ -23,7 +23,7 @@ use datafusion::prelude::*; use datafusion_app::flightsql::{FlightSQLClient, FlightSQLContext}; use datafusion_app::{local::ExecutionContext, ExecOptions, ExecResult}; -/// Provides all core execution functionality for execution queries from either a local +/// Provides all core execution functionality for executing queries from either a local /// `SessionContext` or a remote `FlightSQL` service #[derive(Clone, Debug)] pub struct AppExecution { From 02c3b72b644c9de9f29a8ff8e27943087bc41917 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Wed, 21 May 2025 09:48:56 -0400 Subject: [PATCH 13/16] Go --- crates/datafusion-app/src/catalog/mod.rs | 1 + src/cli/mod.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/crates/datafusion-app/src/catalog/mod.rs b/crates/datafusion-app/src/catalog/mod.rs index 8d8ccaa..e27c2c2 100644 --- a/crates/datafusion-app/src/catalog/mod.rs +++ b/crates/datafusion-app/src/catalog/mod.rs @@ -44,6 +44,7 @@ pub fn create_app_catalog( #[cfg(feature = "flightsql")] { let flightsql_schema = Arc::new(MemorySchemaProvider::new()); + catalog.register_schema("flightsql", flightsql_schema)?; } Ok(Arc::new(catalog)) } diff --git a/src/cli/mod.rs b/src/cli/mod.rs index dfb44fc..c7bcb8f 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -172,6 +172,7 @@ impl CliApp { .flightsql_ctx() .create_prepared_statement(query) .await?; + self.save_prepared_result(prepared_result); let handle = PreparedStatementHandle::decode(prepared_result.prepared_statement_handle)?; From f4331cdb2f06097f5562bbd6a4d7c1b3de7b3a9e Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Thu, 22 May 2025 09:34:23 -0400 Subject: [PATCH 14/16] Move db to df app --- crates/datafusion-app/Cargo.toml | 2 +- crates/datafusion-app/src/config.rs | 4 ++ {src => crates/datafusion-app/src}/db.rs | 33 ++++++++++++- crates/datafusion-app/src/lib.rs | 1 + crates/datafusion-app/src/local.rs | 6 +++ src/cli/mod.rs | 3 +- src/config.rs | 62 ++++++++++++------------ src/lib.rs | 2 +- src/server/flightsql/mod.rs | 3 +- src/server/http/mod.rs | 3 +- src/tui/mod.rs | 3 +- 11 files changed, 80 insertions(+), 42 deletions(-) rename {src => crates/datafusion-app/src}/db.rs (96%) diff --git a/crates/datafusion-app/Cargo.toml b/crates/datafusion-app/Cargo.toml index 409f2cb..d444231 100644 --- a/crates/datafusion-app/Cargo.toml +++ b/crates/datafusion-app/Cargo.toml @@ -36,7 +36,7 @@ serde = { version = "1.0.197", features = ["derive"] } tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] } tokio-stream = { version = "0.1.15", features = ["net"] } tonic = { version = "0.12.3", optional = true } -url = { version = "2.5.2", optional = true } +url = { version = "2.5.2", features = ["serde"], optional = true } [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } diff --git a/crates/datafusion-app/src/config.rs b/crates/datafusion-app/src/config.rs index b1d3875..b2f52ab 100644 --- a/crates/datafusion-app/src/config.rs +++ b/crates/datafusion-app/src/config.rs @@ -24,6 +24,8 @@ use datafusion_udfs_wasm::WasmInputDataType; use serde::Deserialize; use std::collections::HashMap; +use crate::db::{default_db_config, DbConfig}; + #[cfg(feature = "s3")] use { color_eyre::Result, @@ -94,6 +96,7 @@ pub struct ExecutionConfig { #[cfg(feature = "observability")] #[serde(default)] pub observability: ObservabilityConfig, + pub db: DbConfig, } impl Default for ExecutionConfig { @@ -111,6 +114,7 @@ impl Default for ExecutionConfig { catalog: default_catalog(), #[cfg(feature = "observability")] observability: default_observability(), + db: default_db_config(), } } } diff --git a/src/db.rs b/crates/datafusion-app/src/db.rs similarity index 96% rename from src/db.rs rename to crates/datafusion-app/src/db.rs index 659c278..312afb5 100644 --- a/src/db.rs +++ b/crates/datafusion-app/src/db.rs @@ -27,8 +27,39 @@ use datafusion::{ prelude::SessionContext, }; use log::info; +use serde::Deserialize; +use url::Url; -use crate::config::DbConfig; +#[derive(Debug, Clone, Deserialize)] +pub struct DbConfig { + #[serde(default = "default_db_path")] + pub path: Url, +} + +impl Default for DbConfig { + fn default() -> Self { + default_db_config() + } +} + +pub fn default_db_config() -> DbConfig { + DbConfig { + path: default_db_path(), + } +} + +fn default_db_path() -> Url { + let base = directories::BaseDirs::new().expect("Base directories should be available"); + let path = base + .data_dir() + .to_path_buf() + .join("dft/") + .to_str() + .unwrap() + .to_string(); + let with_schema = format!("file://{path}"); + Url::parse(&with_schema).unwrap() +} pub async fn register_db(ctx: &SessionContext, db_config: &DbConfig) -> Result<()> { info!("registering tables to database"); diff --git a/crates/datafusion-app/src/lib.rs b/crates/datafusion-app/src/lib.rs index fdf3d30..77801ee 100644 --- a/crates/datafusion-app/src/lib.rs +++ b/crates/datafusion-app/src/lib.rs @@ -17,6 +17,7 @@ pub mod catalog; pub mod config; +pub mod db; pub mod executor; pub mod extensions; #[cfg(feature = "flightsql")] diff --git a/crates/datafusion-app/src/local.rs b/crates/datafusion-app/src/local.rs index a4cc6f4..b92e523 100644 --- a/crates/datafusion-app/src/local.rs +++ b/crates/datafusion-app/src/local.rs @@ -30,6 +30,7 @@ use parking_lot::RwLock; use crate::catalog::create_app_catalog; use crate::config::ExecutionConfig; +use crate::db::register_db; use crate::{ExecOptions, ExecResult}; use color_eyre::eyre::{self, Result}; use datafusion::common::Result as DFResult; @@ -170,6 +171,11 @@ impl ExecutionContext { Ok(ctx) } + pub async fn register_db(&self) -> Result<()> { + register_db(&self.session_ctx, &self.config.db).await; + Ok(()) + } + /// Useful for testing execution functionality pub fn test() -> Self { let cfg = SessionConfig::new().with_information_schema(true); diff --git a/src/cli/mod.rs b/src/cli/mod.rs index c7bcb8f..0160925 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -17,7 +17,6 @@ //! [`CliApp`]: Command Line User Interface use crate::config::AppConfig; -use crate::db::register_db; use crate::{args::DftArgs, execution::AppExecution}; use color_eyre::eyre::eyre; use color_eyre::Result; @@ -717,6 +716,7 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { crate::APP_NAME, env!("CARGO_PKG_VERSION"), )?; + execution_ctx.register_db().await?; #[allow(unused_mut)] let mut app_execution = AppExecution::new(execution_ctx); #[cfg(feature = "flightsql")] @@ -736,7 +736,6 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { app_execution.with_flightsql_ctx(flightsql_ctx); } } - register_db(app_execution.session_ctx(), &config.db).await?; let app = CliApp::new(app_execution, cli.clone()); app.execute_files_or_commands().await?; Ok(()) diff --git a/src/config.rs b/src/config.rs index c2128b4..4ce9749 100644 --- a/src/config.rs +++ b/src/config.rs @@ -192,37 +192,37 @@ fn default_interaction_config() -> InteractionConfig { InteractionConfig::default() } -#[derive(Debug, Clone, Deserialize)] -pub struct DbConfig { - #[serde(default = "default_db_path")] - pub path: Url, -} - -impl Default for DbConfig { - fn default() -> Self { - default_db_config() - } -} - -fn default_db_config() -> DbConfig { - DbConfig { - path: default_db_path(), - } -} - -#[allow(unused)] -fn default_db_path() -> Url { - let base = directories::BaseDirs::new().expect("Base directories should be available"); - let path = base - .data_dir() - .to_path_buf() - .join("dft/") - .to_str() - .unwrap() - .to_string(); - let with_schema = format!("file://{path}"); - Url::parse(&with_schema).unwrap() -} +// #[derive(Debug, Clone, Deserialize)] +// pub struct DbConfig { +// #[serde(default = "default_db_path")] +// pub path: Url, +// } +// +// impl Default for DbConfig { +// fn default() -> Self { +// default_db_config() +// } +// } +// +// fn default_db_config() -> DbConfig { +// DbConfig { +// path: default_db_path(), +// } +// } +// +// #[allow(unused)] +// fn default_db_path() -> Url { +// let base = directories::BaseDirs::new().expect("Base directories should be available"); +// let path = base +// .data_dir() +// .to_path_buf() +// .join("dft/") +// .to_str() +// .unwrap() +// .to_string(); +// let with_schema = format!("file://{path}"); +// Url::parse(&with_schema).unwrap() +// } #[derive(Clone, Debug, Deserialize)] pub struct DisplayConfig { diff --git a/src/lib.rs b/src/lib.rs index 375ea08..7ef962e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ pub mod args; pub mod cli; pub mod config; -pub mod db; +// pub mod db; pub mod execution; #[cfg(any(feature = "flightsql", feature = "http"))] pub mod server; diff --git a/src/server/flightsql/mod.rs b/src/server/flightsql/mod.rs index 10d1ea1..6a2a451 100644 --- a/src/server/flightsql/mod.rs +++ b/src/server/flightsql/mod.rs @@ -18,7 +18,6 @@ pub mod service; use crate::args::{Command, DftArgs}; use crate::config::AppConfig; -use crate::db::register_db; use crate::execution::AppExecution; use color_eyre::{eyre::eyre, Result}; use datafusion_app::config::merge_configs; @@ -185,6 +184,7 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { if cli.run_ddl { execution_ctx.execute_ddl().await; } + execution_ctx.register_db().await?; let app_execution = AppExecution::new(execution_ctx); let (addr, metrics_addr) = if let Some(cmd) = cli.command.clone() { @@ -219,7 +219,6 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { config.flightsql_server.server_metrics_addr, ) }; - register_db(app_execution.session_ctx(), &config.db).await?; let app = FlightSqlApp::try_new(app_execution, &config, addr, metrics_addr).await?; app.run().await; Ok(()) diff --git a/src/server/http/mod.rs b/src/server/http/mod.rs index 4d5ea59..d19b9b8 100644 --- a/src/server/http/mod.rs +++ b/src/server/http/mod.rs @@ -23,7 +23,6 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use crate::{ args::{Command, DftArgs}, config::AppConfig, - db::register_db, execution::AppExecution, }; use axum::Router; @@ -128,6 +127,7 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { if cli.run_ddl { execution_ctx.execute_ddl().await; } + execution_ctx.register_db().await?; #[allow(unused_mut)] let mut app_execution = AppExecution::new(execution_ctx); @@ -188,7 +188,6 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { config.http_server.server_metrics_addr, ) }; - register_db(app_execution.session_ctx(), &config.db).await?; let app = HttpApp::try_new(app_execution, config.clone(), addr, metrics_addr).await?; app.run().await; diff --git a/src/tui/mod.rs b/src/tui/mod.rs index 39a587d..a3a5abf 100644 --- a/src/tui/mod.rs +++ b/src/tui/mod.rs @@ -45,7 +45,6 @@ use tokio_util::sync::CancellationToken; use self::execution::{ExecutionError, ExecutionResultsBatch, TuiExecution}; use self::handlers::{app_event_handler, crossterm_event_handler}; use crate::config::AppConfig; -use crate::db::register_db; use crate::telemetry; use crate::{args::DftArgs, execution::AppExecution}; use datafusion_app::sql_utils::clean_sql; @@ -381,8 +380,8 @@ pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> { crate::APP_NAME, env!("CARGO_PKG_VERSION"), )?; + execution_ctx.register_db().await?; let app_execution = AppExecution::new(execution_ctx); - register_db(app_execution.session_ctx(), &config.db).await?; let app = App::new(state, cli, app_execution); app.run_app().await?; Ok(()) From ae9a3699c9d560d6fbfdb14192b55fba8e2a1034 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Fri, 23 May 2025 08:45:40 -0400 Subject: [PATCH 15/16] More --- src/cli/mod.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 0160925..d5d63a5 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -39,6 +39,7 @@ use std::path::{Path, PathBuf}; #[cfg(feature = "flightsql")] use { crate::args::{Command, FlightSqlCommand}, + arrow_flight::sql::ActionCreatePreparedStatementResult, datafusion_app::{ config::{AuthConfig, FlightSQLConfig}, flightsql::FlightSQLContext, @@ -83,16 +84,34 @@ impl CliApp { Ok(()) } + #[cfg(feature = "flightsql")] async fn save_prepared_result( &self, create_prepared_result: ActionCreatePreparedStatementResult, ) -> Result<()> { + use datafusion::{ + datasource::provider_as_source, + logical_expr::{dml::InsertOp, LogicalPlanBuilder}, + sql::TableReference, + }; + + let prepared_statements_table = + TableReference::full("dft", "flightsql", "prepared_statements"); let prepared_statements = self .app_execution .execution_ctx() .session_ctx() - .table() + .table(prepared_statements_table) .await?; + let (state, logical_plan) = prepared_statements.into_parts(); + let builder = LogicalPlanBuilder::new(logical_plan); + let insert_plan = LogicalPlanBuilder::insert_into( + input, + prepared_statements_table, + provider_as_source(prepared_statements), + InsertOp::Append, + ); + Ok(()) } #[cfg(feature = "flightsql")] From 3e8417d13753587602fba12ca31c7e444f493646 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Wed, 28 May 2025 10:15:38 -0400 Subject: [PATCH 16/16] Updates --- Cargo.lock | 2 ++ crates/datafusion-app/Cargo.toml | 1 + crates/datafusion-app/src/catalog/mod.rs | 23 +++++++++++-- crates/datafusion-app/src/tables/map_table.rs | 2 ++ src/config.rs | 32 ------------------- 5 files changed, 26 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff95c90..0ba7874 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1479,6 +1479,7 @@ dependencies = [ "parking_lot", "prost", "serde", + "serde_json", "tokio", "tokio-stream", "tonic", @@ -4900,6 +4901,7 @@ version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ + "indexmap 2.8.0", "itoa", "memchr", "ryu", diff --git a/crates/datafusion-app/Cargo.toml b/crates/datafusion-app/Cargo.toml index d444231..6c59987 100644 --- a/crates/datafusion-app/Cargo.toml +++ b/crates/datafusion-app/Cargo.toml @@ -33,6 +33,7 @@ opendal = { version = "0.51", features = [ parking_lot = "0.12.3" prost = "0.13.5" serde = { version = "1.0.197", features = ["derive"] } +serde_json = { version = "1.0.140", features = ["indexmap"] } tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] } tokio-stream = { version = "0.1.15", features = ["net"] } tonic = { version = "0.12.3", optional = true } diff --git a/crates/datafusion-app/src/catalog/mod.rs b/crates/datafusion-app/src/catalog/mod.rs index e27c2c2..80865e0 100644 --- a/crates/datafusion-app/src/catalog/mod.rs +++ b/crates/datafusion-app/src/catalog/mod.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +use std::{collections::HashMap, fs::File, path::Path, sync::Arc}; use datafusion::{ arrow::{ @@ -26,13 +26,18 @@ use datafusion::{ catalog::{CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider}, common::Result, datasource::MemTable, + error::DataFusionError, + scalar::ScalarValue, DATAFUSION_VERSION, }; +use indexmap::IndexMap; use crate::config::ExecutionConfig; +type PreparedStatementsMap = IndexMap>; + pub fn create_app_catalog( - _config: &ExecutionConfig, + config: &ExecutionConfig, app_name: &str, app_version: &str, ) -> Result> { @@ -45,6 +50,20 @@ pub fn create_app_catalog( { let flightsql_schema = Arc::new(MemorySchemaProvider::new()); catalog.register_schema("flightsql", flightsql_schema)?; + let db_path = config.db.path.to_file_path().map_err(|_| { + DataFusionError::External("error converting DB path to file path".to_string().into()) + })?; + let prepared_statements_file = db_path + .join(app_name) + .join("flightsql") + .join("prepared_statements"); + let prepared_statements = if let Ok(true) = prepared_statements_file.try_exists() { + let reader = File::open(prepared_statements_file) + .map_err(|e| DataFusionError::External(e.to_string().into()))?; + let vals: PreparedStatementsMap = serde_json::from_reader(reader) + .map_err(|e| DataFusionError::External(e.to_string().into()))?; + } else { + }; } Ok(Arc::new(catalog)) } diff --git a/crates/datafusion-app/src/tables/map_table.rs b/crates/datafusion-app/src/tables/map_table.rs index 359c276..5e08960 100644 --- a/crates/datafusion-app/src/tables/map_table.rs +++ b/crates/datafusion-app/src/tables/map_table.rs @@ -48,6 +48,8 @@ type ArrayBuilderRef = Box; // The first String key is meant to hold primary key and provide O(1) lookup. The inner HashMap is // for holding arbitrary column and value pairs - the key is the column name and we use DataFusions // scalar value to provide dynamic typing for the column values. +// +// Todo: Maybe the inner HashMap should be a Vec to make projecting easier type MapData = Arc>>>; #[derive(Debug)] diff --git a/src/config.rs b/src/config.rs index 4ce9749..978fa0e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -192,38 +192,6 @@ fn default_interaction_config() -> InteractionConfig { InteractionConfig::default() } -// #[derive(Debug, Clone, Deserialize)] -// pub struct DbConfig { -// #[serde(default = "default_db_path")] -// pub path: Url, -// } -// -// impl Default for DbConfig { -// fn default() -> Self { -// default_db_config() -// } -// } -// -// fn default_db_config() -> DbConfig { -// DbConfig { -// path: default_db_path(), -// } -// } -// -// #[allow(unused)] -// fn default_db_path() -> Url { -// let base = directories::BaseDirs::new().expect("Base directories should be available"); -// let path = base -// .data_dir() -// .to_path_buf() -// .join("dft/") -// .to_str() -// .unwrap() -// .to_string(); -// let with_schema = format!("file://{path}"); -// Url::parse(&with_schema).unwrap() -// } - #[derive(Clone, Debug, Deserialize)] pub struct DisplayConfig { #[serde(default = "default_frame_rate")]