Commit 1f160d6

feat: add support for insert into

1 parent 0823267

6 files changed, +273 -70 lines changed

ballista/client/tests/context_checks.rs

Lines changed: 65 additions & 0 deletions
@@ -399,4 +399,69 @@ mod supported {
 
         assert_batches_eq!(expected, &result);
     }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    #[cfg(not(windows))] // test is failing at windows, can't debug it
+    async fn should_support_sql_insert_into(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+        test_data: String,
+    ) {
+        ctx.register_parquet(
+            "test",
+            &format!("{test_data}/alltypes_plain.parquet"),
+            Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let write_dir = tempfile::tempdir().expect("temporary directory to be created");
+        let write_dir_path = write_dir
+            .path()
+            .to_str()
+            .expect("path to be converted to str");
+
+        ctx.sql("select * from test")
+            .await
+            .unwrap()
+            .write_parquet(write_dir_path, Default::default(), Default::default())
+            .await
+            .unwrap();
+
+        ctx.register_parquet("written_table", write_dir_path, Default::default())
+            .await
+            .unwrap();
+
+        ctx.sql("INSERT INTO written_table select * from test")
+            .await
+            .unwrap()
+            .show()
+            .await
+            .unwrap();
+
+        let result = ctx
+            .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
+            .await.unwrap()
+            .collect()
+            .await.unwrap();
+
+        let expected = [
+            "+----+------------+---------------------+",
+            "| id | string_col | timestamp_col       |",
+            "+----+------------+---------------------+",
+            "| 5  | 31         | 2009-03-01T00:01:00 |",
+            "| 5  | 31         | 2009-03-01T00:01:00 |",
+            "| 6  | 30         | 2009-04-01T00:00:00 |",
+            "| 6  | 30         | 2009-04-01T00:00:00 |",
+            "| 7  | 31         | 2009-04-01T00:01:00 |",
+            "| 7  | 31         | 2009-04-01T00:01:00 |",
+            "+----+------------+---------------------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+    }
 }
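The new test above exercises the whole path end to end through the rstest fixtures. For a sense of how the feature looks from a plain client program, here is a minimal sketch; it assumes ballista's `SessionContextExt` (which provides `SessionContext::standalone()`) is available via the prelude, and the parquet path is a placeholder.

// Sketch only: SessionContextExt::standalone() and the parquet path are assumptions.
use ballista::prelude::*;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // Embedded scheduler + executor, as the standalone test case uses.
    let ctx = SessionContext::standalone().await?;

    // Source table backed by an existing parquet file.
    ctx.register_parquet("test", "testdata/alltypes_plain.parquet", Default::default())
        .await?;

    // Seed a writable directory, then register it as the insert target.
    let target_dir = tempfile::tempdir()?.into_path();
    let target = target_dir.to_str().expect("utf-8 path");
    ctx.sql("SELECT * FROM test")
        .await?
        .write_parquet(target, Default::default(), Default::default())
        .await?;
    ctx.register_parquet("written_table", target, Default::default())
        .await?;

    // The statement this commit enables: INSERT INTO executed through Ballista.
    ctx.sql("INSERT INTO written_table SELECT * FROM test")
        .await?
        .show()
        .await?;
    Ok(())
}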

ballista/client/tests/context_unsupported.rs

Lines changed: 0 additions & 65 deletions
@@ -144,71 +144,6 @@ mod unsupported {
             "+----+----------+---------------------+",
         ];
 
-        assert_batches_eq!(expected, &result);
-    }
-    #[rstest]
-    #[case::standalone(standalone_context())]
-    #[case::remote(remote_context())]
-    #[tokio::test]
-    #[should_panic]
-    // "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"
-    async fn should_support_sql_insert_into(
-        #[future(awt)]
-        #[case]
-        ctx: SessionContext,
-        test_data: String,
-    ) {
-        ctx.register_parquet(
-            "test",
-            &format!("{test_data}/alltypes_plain.parquet"),
-            Default::default(),
-        )
-        .await
-        .unwrap();
-        let write_dir = tempfile::tempdir().expect("temporary directory to be created");
-        let write_dir_path = write_dir
-            .path()
-            .to_str()
-            .expect("path to be converted to str");
-
-        ctx.sql("select * from test")
-            .await
-            .unwrap()
-            .write_parquet(write_dir_path, Default::default(), Default::default())
-            .await
-            .unwrap();
-
-        ctx.register_parquet("written_table", write_dir_path, Default::default())
-            .await
-            .unwrap();
-
-        let _ = ctx
-            .sql("INSERT INTO written_table select * from written_table")
-            .await
-            .unwrap()
-            .collect()
-            .await
-            .unwrap();
-
-        let result = ctx
-            .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id")
-            .await.unwrap()
-            .collect()
-            .await.unwrap();
-
-        let expected = [
-            "+----+------------+---------------------+",
-            "| id | string_col | timestamp_col       |",
-            "+----+------------+---------------------+",
-            "| 5  | 31         | 2009-03-01T00:01:00 |",
-            "| 5  | 31         | 2009-03-01T00:01:00 |",
-            "| 6  | 30         | 2009-04-01T00:00:00 |",
-            "| 6  | 30         | 2009-04-01T00:00:00 |",
-            "| 7  | 31         | 2009-04-01T00:01:00 |",
-            "| 7  | 31         | 2009-04-01T00:01:00 |",
-            "+----+------------+---------------------+",
-        ];
-
         assert_batches_eq!(expected, &result);
     }
 }

ballista/core/src/config.rs

Lines changed: 15 additions & 0 deletions
@@ -32,6 +32,13 @@ pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.paralleli
 /// max message size for gRPC clients
 pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
     "ballista.grpc_client_max_message_size";
+/// enable or disable ballista dml planner extension.
+/// when enabled planner will use custom logical planner DML
+/// extension which will serialize table provider used in DML
+///
+/// this configuration should be disabled if using remote schema
+/// registries.
+pub const BALLISTA_PLANNER_DML_EXTENSION: &str = "ballista.planner.dml_extension";
 
 pub type ParseResult<T> = result::Result<T, String>;
 use std::sync::LazyLock;

@@ -48,6 +55,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
             "Configuration for max message size in gRPC clients".to_string(),
             DataType::UInt64,
             Some((16 * 1024 * 1024).to_string())),
+        ConfigEntry::new(BALLISTA_PLANNER_DML_EXTENSION.to_string(),
+            "Enable ballista planner DML extension".to_string(),
+            DataType::Boolean,
+            Some((true).to_string())),
     ];
     entries
         .into_iter()

@@ -165,6 +176,10 @@ impl BallistaConfig {
         self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM)
     }
 
+    pub fn planner_dml_extension(&self) -> bool {
+        self.get_bool_setting(BALLISTA_PLANNER_DML_EXTENSION)
+    }
+
     fn get_usize_setting(&self, key: &str) -> usize {
         if let Some(v) = self.settings.get(key) {
             // infallible because we validate all configs in the constructor
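The new key defaults to true, so existing deployments get the behaviour added in this commit. When table providers live in a remote schema registry (the case the doc comment above warns about), the extension can be switched off like any other setting. A minimal sketch, assuming the builder-style `BallistaConfig::builder()` API and the `ballista_core` crate path:

// Sketch: disabling the DML planner extension. BallistaConfig::builder()
// and the crate path are assumptions; the key and getter come from this diff.
use ballista_core::config::{BallistaConfig, BALLISTA_PLANNER_DML_EXTENSION};

fn config_without_dml_extension() -> BallistaConfig {
    let config = BallistaConfig::builder()
        .set(BALLISTA_PLANNER_DML_EXTENSION, "false")
        .build()
        .expect("valid ballista configuration");

    // The getter added above reflects the override.
    assert!(!config.planner_dml_extension());
    config
}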

ballista/core/src/planner.rs

Lines changed: 39 additions & 2 deletions
@@ -17,14 +17,16 @@
 
 use crate::config::BallistaConfig;
 use crate::execution_plans::DistributedQueryExec;
-use crate::serde::BallistaLogicalExtensionCodec;
+use crate::serde::{BallistaDmlExtension, BallistaLogicalExtensionCodec};
 
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::Schema;
+use datafusion::common::plan_err;
 use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor};
+use datafusion::datasource::DefaultTableSource;
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{QueryPlanner, SessionState};
-use datafusion::logical_expr::{LogicalPlan, TableScan};
+use datafusion::logical_expr::{DmlStatement, Extension, LogicalPlan, TableScan};
 use datafusion::physical_plan::empty::EmptyExec;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};

@@ -125,6 +127,41 @@ impl<T: 'static + AsLogicalPlan> QueryPlanner for BallistaQueryPlanner<T> {
                 log::debug!("create_physical_plan - handling empty exec");
                 Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))
             }
+            // At the moment DML statement uses TableReference instead of TableProvider.
+            // As ballista has two contexts (client and scheduler) scheduler context may not
+            // know table provider for given table reference, thus we need to attach
+            // table provider to this DML statement.
+            LogicalPlan::Dml(DmlStatement { table_name, .. })
+                if self.config.planner_dml_extension() =>
+            {
+                let table_name = table_name.to_owned();
+                let table = table_name.table().to_string();
+                let schema = session_state.schema_for_ref(table_name.clone())?;
+                let table_provider = match schema.table(&table).await? {
+                    Some(ref provider) => Ok(Arc::clone(provider)),
+                    _ => plan_err!("No table named '{table}'"),
+                }?;
+
+                let table_source = Arc::new(DefaultTableSource::new(table_provider));
+                let table =
+                    TableScan::try_new(table_name, table_source, None, vec![], None)?;
+
+                // custom made logical extension node is used to attach table reference
+                let node = Arc::new(BallistaDmlExtension {
+                    dml: logical_plan.clone(),
+                    table,
+                });
+                let plan = LogicalPlan::Extension(Extension { node });
+                log::debug!("create_physical_plan - handling DML statement");
+
+                Ok(Arc::new(DistributedQueryExec::<T>::with_extension(
+                    self.scheduler_url.clone(),
+                    self.config.clone(),
+                    plan.clone(),
+                    self.extension_codec.clone(),
+                    session_state.session_id().to_string(),
+                )))
+            }
             _ => {
                 log::debug!("create_physical_plan - handling general statement");
 
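The key step in the new match arm is turning a resolved `TableProvider` into a `TableScan` that can ride along in the extension node, since the DML statement itself only carries a table name. The isolated sketch below shows that wrapping with an in-memory table; the planner above obtains the provider from the session state instead, and the table name here is only illustrative.

// Sketch: provider -> DefaultTableSource -> TableScan, as the planner arm does.
// MemTable and the table name are illustrative stand-ins.
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int32Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::{DefaultTableSource, MemTable};
use datafusion::error::Result;
use datafusion::logical_expr::TableScan;

fn table_scan_for_provider() -> Result<TableScan> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;
    let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);

    // Wrap the provider so the logical plan can reference it directly,
    // instead of carrying only a table name.
    let source = Arc::new(DefaultTableSource::new(provider));
    TableScan::try_new("written_table", source, None, vec![], None)
}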

ballista/core/src/serde/mod.rs

Lines changed: 123 additions & 3 deletions
@@ -22,8 +22,11 @@ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
 
 use arrow_flight::sql::ProstMessageExt;
 use datafusion::arrow::datatypes::SchemaRef;
-use datafusion::common::{DataFusionError, Result};
+use datafusion::common::{plan_err, DataFusionError, Result};
 use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::{
+    Extension, LogicalPlan, TableScan, UserDefinedLogicalNodeCore,
+};
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
 use datafusion_proto::logical_plan::file_formats::{
     ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, CsvLogicalExtensionCodec,

@@ -179,15 +182,64 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
         inputs: &[datafusion::logical_expr::LogicalPlan],
         ctx: &datafusion::prelude::SessionContext,
     ) -> Result<datafusion::logical_expr::Extension> {
-        self.default_codec.try_decode(buf, inputs, ctx)
+        match BallistaExtensionProto::decode(buf) {
+            Ok(extension) => match extension.extension {
+                Some(BallistaExtensionType::Dml(BallistaDmlExtensionProto {
+                    dml: Some(dml),
+                    table: Some(table),
+                })) => {
+                    let table = table.try_into_logical_plan(ctx, self)?;
+                    match table {
+                        LogicalPlan::TableScan(scan) => {
+                            let dml = dml.try_into_logical_plan(ctx, self)?;
+                            Ok(Extension {
+                                node: Arc::new(BallistaDmlExtension { dml, table: scan }),
+                            })
+                        }
+                        _ => plan_err!(
+                            "TableScan expected in ballista DML extension definition"
+                        ),
+                    }
+                }
+                None => plan_err!("Ballista extension can't be None"),
+                _ => plan_err!("Ballista extension not supported"),
+            },
+
+            Err(_e) => self.default_codec.try_decode(buf, inputs, ctx),
+        }
     }
 
     fn try_encode(
         &self,
         node: &datafusion::logical_expr::Extension,
         buf: &mut Vec<u8>,
     ) -> Result<()> {
-        self.default_codec.try_encode(node, buf)
+        if let Some(BallistaDmlExtension { dml: input, table }) =
+            node.node.as_any().downcast_ref::<BallistaDmlExtension>()
+        {
+            let input = LogicalPlanNode::try_from_logical_plan(input, self)?;
+
+            let table = LogicalPlanNode::try_from_logical_plan(
+                &LogicalPlan::TableScan(table.clone()),
+                self,
+            )?;
+            let extension = BallistaDmlExtensionProto {
+                dml: Some(input),
+                table: Some(table),
+            };
+
+            let extension = BallistaExtensionProto {
+                extension: Some(BallistaExtensionType::Dml(extension)),
+            };
+
+            extension
+                .encode(buf)
+                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+            Ok(())
+        } else {
+            self.default_codec.try_encode(node, buf)
+        }
     }
 
     fn try_decode_table_provider(

@@ -487,6 +539,74 @@ struct FileFormatProto {
     pub blob: Vec<u8>,
 }
 
+#[derive(Clone, PartialEq, prost::Message)]
+struct BallistaExtensionProto {
+    #[prost(oneof = "BallistaExtensionType", tags = "1")]
+    extension: Option<BallistaExtensionType>,
+}
+
+#[derive(Clone, PartialEq, ::prost::Oneof)]
+enum BallistaExtensionType {
+    #[prost(message, tag = "1")]
+    Dml(BallistaDmlExtensionProto),
+}
+
+#[derive(Clone, PartialEq, prost::Message)]
+struct BallistaDmlExtensionProto {
+    #[prost(message, tag = 1)]
+    pub dml: Option<LogicalPlanNode>,
+    #[prost(message, tag = 2)]
+    pub table: Option<LogicalPlanNode>,
+}
+
+#[derive(Debug, Hash, PartialEq, Eq, Clone)]
+pub struct BallistaDmlExtension {
+    /// LogicalPlan::DML
+    /// DMLStatement is expected
+    pub dml: LogicalPlan,
+    /// Table provider which is referenced
+    /// from LogicalPlan::DML
+    pub table: TableScan,
+}
+
+impl std::cmp::PartialOrd for BallistaDmlExtension {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        self.dml.partial_cmp(&other.dml)
+    }
+}
+impl UserDefinedLogicalNodeCore for BallistaDmlExtension {
+    fn name(&self) -> &str {
+        "BallistaDmlExtension"
+    }
+
+    fn inputs(&self) -> Vec<&datafusion::logical_expr::LogicalPlan> {
+        vec![&self.dml]
+    }
+
+    fn schema(&self) -> &datafusion::common::DFSchemaRef {
+        self.dml.schema()
+    }
+
+    fn expressions(&self) -> Vec<datafusion::prelude::Expr> {
+        self.dml.expressions()
+    }
+
+    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        self.dml.fmt(f)
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<datafusion::prelude::Expr>,
+        inputs: Vec<datafusion::logical_expr::LogicalPlan>,
+    ) -> Result<Self> {
+        Ok(Self {
+            dml: inputs[0].clone(),
+            table: self.table.clone(),
+        })
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
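Because the new envelope types derive `prost::Message`, their wire format can be checked on its own, independent of plan serialization. A small sketch that could live inside this module (the proto structs are private to it); the function name is hypothetical and the empty payloads just keep the example short.

// Sketch: round-tripping the extension envelope with prost only.
// Must live inside this module since BallistaExtensionProto et al. are private.
use prost::Message;

fn roundtrip_dml_envelope() {
    let original = BallistaExtensionProto {
        extension: Some(BallistaExtensionType::Dml(BallistaDmlExtensionProto {
            // Plan payloads elided; real use fills these via LogicalPlanNode.
            dml: None,
            table: None,
        })),
    };

    let mut buf = Vec::new();
    original.encode(&mut buf).expect("prost encode");
    let decoded =
        BallistaExtensionProto::decode(buf.as_slice()).expect("prost decode");
    assert_eq!(original, decoded);
}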
