Skip to content

Commit 23a80da

Browse files
authored
feat: Union types coercion (#87)
1 parent cb4e782 commit 23a80da

File tree

16 files changed

+262
-44
lines changed

16 files changed

+262
-44
lines changed

datafusion-cli/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ repository = "https://github.com/apache/arrow-datafusion"
2828
rust-version = "1.59"
2929

3030
[dependencies]
31-
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0bc721700352afe8267dba82388b81deffc24095" }
31+
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0a23deccf4e589768177fbaa7a3745c8b25f63c9" }
3232
clap = { version = "3", features = ["derive", "cargo"] }
3333
datafusion = { path = "../datafusion/core", version = "7.0.0" }
3434
dirs = "4.0.0"

datafusion-examples/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ path = "examples/avro_sql.rs"
3434
required-features = ["datafusion/avro"]
3535

3636
[dev-dependencies]
37-
arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0bc721700352afe8267dba82388b81deffc24095" }
37+
arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0a23deccf4e589768177fbaa7a3745c8b25f63c9" }
3838
async-trait = "0.1.41"
3939
datafusion = { path = "../datafusion/core" }
4040
futures = "0.3"

datafusion/common/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ jit = ["cranelift-module"]
3838
pyarrow = ["pyo3"]
3939

4040
[dependencies]
41-
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0bc721700352afe8267dba82388b81deffc24095", features = ["prettyprint"] }
41+
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0a23deccf4e589768177fbaa7a3745c8b25f63c9", features = ["prettyprint"] }
4242
avro-rs = { version = "0.13", features = ["snappy"], optional = true }
4343
cranelift-module = { version = "0.82.0", optional = true }
4444
ordered-float = "2.10"
45-
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0bc721700352afe8267dba82388b81deffc24095", features = ["arrow"], optional = true }
45+
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0a23deccf4e589768177fbaa7a3745c8b25f63c9", features = ["arrow"], optional = true }
4646
pyo3 = { version = "0.16", optional = true }
4747
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "4b67b56c590ce5d4b0a7706e28ffa4209dadc770" }

datafusion/core/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ unicode_expressions = ["datafusion-physical-expr/regex_expressions"]
5555

5656
[dependencies]
5757
ahash = { version = "0.7", default-features = false }
58-
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0bc721700352afe8267dba82388b81deffc24095", features = ["prettyprint"] }
58+
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0a23deccf4e589768177fbaa7a3745c8b25f63c9", features = ["prettyprint"] }
5959
async-trait = "0.1.41"
6060
avro-rs = { version = "0.13", features = ["snappy"], optional = true }
6161
chrono = { version = "0.4", default-features = false }
@@ -66,13 +66,14 @@ datafusion-jit = { path = "../jit", version = "7.0.0", optional = true }
6666
datafusion-physical-expr = { path = "../physical-expr", version = "7.0.0" }
6767
futures = "0.3"
6868
hashbrown = { version = "0.12", features = ["raw"] }
69+
itertools = "0.10"
6970
lazy_static = { version = "^1.4.0" }
7071
log = "^0.4"
7172
num-traits = { version = "0.2", optional = true }
7273
num_cpus = "1.13.0"
7374
ordered-float = "2.10"
7475
parking_lot = "0.12"
75-
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0bc721700352afe8267dba82388b81deffc24095", features = ["arrow"] }
76+
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0a23deccf4e589768177fbaa7a3745c8b25f63c9", features = ["arrow"] }
7677
paste = "^1.0"
7778
pin-project-lite= "^0.2.7"
7879
pyo3 = { version = "0.16", optional = true }

datafusion/core/fuzz-utils/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ edition = "2021"
2323
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
2424

2525
[dependencies]
26-
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0bc721700352afe8267dba82388b81deffc24095", features = ["prettyprint"] }
26+
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "0a23deccf4e589768177fbaa7a3745c8b25f63c9", features = ["prettyprint"] }
2727
env_logger = "0.9.0"
2828
rand = "0.8"

datafusion/core/src/logical_plan/builder.rs

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@ use arrow::{
3636
record_batch::RecordBatch,
3737
};
3838
use datafusion_data_access::object_store::ObjectStore;
39+
use datafusion_physical_expr::coercion_rule::binary_rule::comparison_eq_coercion;
3940
use std::convert::TryFrom;
4041
use std::iter;
4142
use std::{
4243
collections::{HashMap, HashSet},
4344
sync::Arc,
4445
};
4546

46-
use super::dfschema::ToDFSchema;
47+
use super::{dfschema::ToDFSchema, expr_rewriter::coerce_plan_expr_for_schema};
4748
use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
4849
use crate::logical_plan::{
4950
columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column,
@@ -1069,39 +1070,59 @@ pub fn union_with_alias(
10691070
right_plan: LogicalPlan,
10701071
alias: Option<String>,
10711072
) -> Result<LogicalPlan> {
1072-
let union_schema = left_plan.schema().clone();
1073-
let inputs_iter = vec![left_plan, right_plan]
1073+
let union_schema = (0..left_plan.schema().fields().len())
1074+
.map(|i| {
1075+
let left_field = left_plan.schema().field(i);
1076+
let right_field = right_plan.schema().field(i);
1077+
let nullable = left_field.is_nullable() || right_field.is_nullable();
1078+
let data_type =
1079+
comparison_eq_coercion(left_field.data_type(), right_field.data_type())
1080+
.ok_or_else(|| {
1081+
DataFusionError::Plan(format!(
1082+
"UNION Column {} (type: {}) is not compatible with column {} (type: {})",
1083+
right_field.name(),
1084+
right_field.data_type(),
1085+
left_field.name(),
1086+
left_field.data_type()
1087+
))
1088+
})?;
1089+
1090+
Ok(DFField::new(
1091+
alias.as_deref(),
1092+
left_field.name(),
1093+
data_type,
1094+
nullable,
1095+
))
1096+
})
1097+
.collect::<Result<Vec<_>>>()?
1098+
.to_dfschema()?;
1099+
1100+
let inputs = vec![left_plan, right_plan]
10741101
.into_iter()
10751102
.flat_map(|p| match p {
10761103
LogicalPlan::Union(Union { inputs, .. }) => inputs,
10771104
x => vec![x],
1078-
});
1079-
1080-
inputs_iter
1081-
.clone()
1082-
.skip(1)
1083-
.try_for_each(|input_plan| -> Result<()> {
1084-
union_schema.check_arrow_schema_type_compatible(
1085-
&((**input_plan.schema()).clone().into()),
1086-
)
1087-
})?;
1088-
1089-
let inputs = inputs_iter
1090-
.map(|p| match p {
1091-
LogicalPlan::Projection(Projection {
1092-
expr, input, alias, ..
1093-
}) => {
1094-
project_with_column_index_alias(expr, input, union_schema.clone(), alias)
1095-
.unwrap()
1105+
})
1106+
.map(|p| {
1107+
let plan = coerce_plan_expr_for_schema(&p, &union_schema)?;
1108+
match plan {
1109+
LogicalPlan::Projection(Projection {
1110+
expr, input, alias, ..
1111+
}) => Ok(project_with_column_index_alias(
1112+
expr.to_vec(),
1113+
input,
1114+
Arc::new(union_schema.clone()),
1115+
alias,
1116+
)?),
1117+
x => Ok(x),
10961118
}
1097-
x => x,
10981119
})
1099-
.collect::<Vec<_>>();
1120+
.collect::<Result<Vec<_>>>()?;
1121+
11001122
if inputs.is_empty() {
11011123
return Err(DataFusionError::Plan("Empty UNION".to_string()));
11021124
}
11031125

1104-
let union_schema = (**inputs[0].schema()).clone();
11051126
let union_schema = Arc::new(match alias {
11061127
Some(ref alias) => union_schema.replace_qualifier(alias.as_str()),
11071128
None => union_schema.strip_qualifiers(),

datafusion/core/src/logical_plan/expr_rewriter.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717

1818
//! Expression rewriter
1919
20-
use super::Expr;
20+
use super::{Expr, ExprSchemable};
2121
use crate::logical_plan::plan::{Aggregate, Projection};
2222
use crate::logical_plan::DFSchema;
2323
use crate::logical_plan::LogicalPlan;
24+
use crate::optimizer::utils::from_plan;
2425
use crate::sql::utils::{
2526
extract_aliased_expr_names, rebase_expr, resolve_exprs_to_aliases,
2627
};
@@ -491,6 +492,40 @@ pub fn rewrite_udtfs_to_columns(exprs: Vec<Expr>, schema: DFSchema) -> Vec<Expr>
491492
.collect::<Vec<_>>()
492493
}
493494

495+
/// Returns plan with expressions coerced to types compatible with
496+
/// schema types
497+
pub fn coerce_plan_expr_for_schema(
498+
plan: &LogicalPlan,
499+
schema: &DFSchema,
500+
) -> Result<LogicalPlan> {
501+
let new_expr = plan
502+
.expressions()
503+
.into_iter()
504+
.enumerate()
505+
.map(|(i, expr)| {
506+
let new_type = schema.field(i).data_type();
507+
if plan.schema().field(i).data_type() != schema.field(i).data_type() {
508+
match (plan, &expr) {
509+
(
510+
LogicalPlan::Projection(Projection { input, .. }),
511+
Expr::Alias(e, alias),
512+
) => Ok(Expr::Alias(
513+
Box::new(e.clone().cast_to(new_type, input.schema())?),
514+
alias.clone(),
515+
)),
516+
_ => expr.cast_to(new_type, plan.schema()),
517+
}
518+
} else {
519+
Ok(expr)
520+
}
521+
})
522+
.collect::<Result<Vec<_>>>()?;
523+
524+
let new_inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
525+
526+
from_plan(plan, &new_expr, &new_inputs)
527+
}
528+
494529
#[cfg(test)]
495530
mod test {
496531
use super::*;

datafusion/core/src/physical_plan/planner.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,8 +1687,6 @@ mod tests {
16871687
col("c1").and(col("c1")),
16881688
// u8 AND u8
16891689
col("c3").and(col("c3")),
1690-
// utf8 = u32
1691-
col("c1").eq(col("c2")),
16921690
// u32 AND bool
16931691
col("c2").and(bool_expr),
16941692
// utf8 LIKE u32

datafusion/core/src/physical_plan/union.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@
2323
2424
use std::{any::Any, sync::Arc};
2525

26-
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
26+
use arrow::{
27+
datatypes::{Field, Schema, SchemaRef},
28+
record_batch::RecordBatch,
29+
};
2730
use futures::StreamExt;
31+
use itertools::Itertools;
2832

2933
use super::{
3034
expressions::PhysicalSortExpr,
@@ -46,14 +50,38 @@ pub struct UnionExec {
4650
inputs: Vec<Arc<dyn ExecutionPlan>>,
4751
/// Execution metrics
4852
metrics: ExecutionPlanMetricsSet,
53+
/// Schema of Union
54+
schema: SchemaRef,
4955
}
5056

5157
impl UnionExec {
5258
/// Create a new UnionExec
5359
pub fn new(inputs: Vec<Arc<dyn ExecutionPlan>>) -> Self {
60+
let fields: Vec<Field> = (0..inputs[0].schema().fields().len())
61+
.map(|i| {
62+
inputs
63+
.iter()
64+
.filter_map(|input| {
65+
if input.schema().fields().len() > i {
66+
Some(input.schema().field(i).clone())
67+
} else {
68+
None
69+
}
70+
})
71+
.find_or_first(|f| f.is_nullable())
72+
.unwrap()
73+
})
74+
.collect();
75+
76+
let schema = Arc::new(Schema::new_with_metadata(
77+
fields,
78+
inputs[0].schema().metadata().clone(),
79+
));
80+
5481
UnionExec {
5582
inputs,
5683
metrics: ExecutionPlanMetricsSet::new(),
84+
schema,
5785
}
5886
}
5987

@@ -71,7 +99,7 @@ impl ExecutionPlan for UnionExec {
7199
}
72100

73101
fn schema(&self) -> SchemaRef {
74-
self.inputs[0].schema()
102+
self.schema.clone()
75103
}
76104

77105
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {

0 commit comments

Comments (0)