Skip to content

Commit 08e8ef5

Browse files
Merge branch 'main' into feature/lazy-partitioned-hash-join
2 parents 420e207 + 79869a7 commit 08e8ef5

File tree

31 files changed

+950
-71
lines changed

31 files changed

+950
-71
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/catalog-listing/src/helpers.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool {
8383
| Expr::Exists(_)
8484
| Expr::InSubquery(_)
8585
| Expr::ScalarSubquery(_)
86+
| Expr::SetComparison(_)
8687
| Expr::GroupingSet(_)
8788
| Expr::Case(_) => Ok(TreeNodeRecursion::Continue),
8889

datafusion/core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ insta = { workspace = true }
177177
paste = { workspace = true }
178178
rand = { workspace = true, features = ["small_rng"] }
179179
rand_distr = "0.5"
180+
recursive = { workspace = true }
180181
regex = { workspace = true }
181182
rstest = { workspace = true }
182183
serde_json = { workspace = true }
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Int32Array, StringArray};
21+
use arrow::datatypes::{DataType, Field, Schema};
22+
use arrow::record_batch::RecordBatch;
23+
use datafusion::prelude::SessionContext;
24+
use datafusion_common::{Result, assert_batches_eq, assert_contains};
25+
26+
fn build_table(values: &[i32]) -> Result<RecordBatch> {
27+
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
28+
let array =
29+
Arc::new(Int32Array::from(values.to_vec())) as Arc<dyn arrow::array::Array>;
30+
RecordBatch::try_new(schema, vec![array]).map_err(Into::into)
31+
}
32+
33+
#[tokio::test]
34+
async fn set_comparison_any() -> Result<()> {
35+
let ctx = SessionContext::new();
36+
37+
ctx.register_batch("t", build_table(&[1, 6, 10])?)?;
38+
// Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly.
39+
ctx.register_batch("s", {
40+
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
41+
let array = Arc::new(Int32Array::from(vec![Some(5), None]))
42+
as Arc<dyn arrow::array::Array>;
43+
RecordBatch::try_new(schema, vec![array])?
44+
})?;
45+
46+
let df = ctx
47+
.sql("select v from t where v > any(select v from s)")
48+
.await?;
49+
let results = df.collect().await?;
50+
51+
assert_batches_eq!(
52+
&["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",],
53+
&results
54+
);
55+
Ok(())
56+
}
57+
58+
#[tokio::test]
59+
async fn set_comparison_any_aggregate_subquery() -> Result<()> {
60+
let ctx = SessionContext::new();
61+
62+
ctx.register_batch("t", build_table(&[1, 7])?)?;
63+
ctx.register_batch("s", build_table(&[1, 2, 3])?)?;
64+
65+
let df = ctx
66+
.sql(
67+
"select v from t where v > any(select sum(v) from s group by v % 2) order by v",
68+
)
69+
.await?;
70+
let results = df.collect().await?;
71+
72+
assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results);
73+
Ok(())
74+
}
75+
76+
#[tokio::test]
77+
async fn set_comparison_all_empty() -> Result<()> {
78+
let ctx = SessionContext::new();
79+
80+
ctx.register_batch("t", build_table(&[1, 6, 10])?)?;
81+
ctx.register_batch(
82+
"e",
83+
RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
84+
"v",
85+
DataType::Int32,
86+
true,
87+
)]))),
88+
)?;
89+
90+
let df = ctx
91+
.sql("select v from t where v < all(select v from e)")
92+
.await?;
93+
let results = df.collect().await?;
94+
95+
assert_batches_eq!(
96+
&[
97+
"+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+",
98+
],
99+
&results
100+
);
101+
Ok(())
102+
}
103+
104+
#[tokio::test]
105+
async fn set_comparison_type_mismatch() -> Result<()> {
106+
let ctx = SessionContext::new();
107+
108+
ctx.register_batch("t", build_table(&[1])?)?;
109+
ctx.register_batch("strings", {
110+
let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
111+
let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")]))
112+
as Arc<dyn arrow::array::Array>;
113+
RecordBatch::try_new(schema, vec![array])?
114+
})?;
115+
116+
let df = ctx
117+
.sql("select v from t where v > any(select s from strings)")
118+
.await?;
119+
let err = df.collect().await.unwrap_err();
120+
assert_contains!(
121+
err.to_string(),
122+
"expr type Int32 can't cast to Utf8 in SetComparison"
123+
);
124+
Ok(())
125+
}
126+
127+
#[tokio::test]
128+
async fn set_comparison_multiple_operators() -> Result<()> {
129+
let ctx = SessionContext::new();
130+
131+
ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?;
132+
ctx.register_batch("s", build_table(&[2, 3])?)?;
133+
134+
let df = ctx
135+
.sql("select v from t where v = any(select v from s) order by v")
136+
.await?;
137+
let results = df.collect().await?;
138+
assert_batches_eq!(
139+
&["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",],
140+
&results
141+
);
142+
143+
let df = ctx
144+
.sql("select v from t where v != all(select v from s) order by v")
145+
.await?;
146+
let results = df.collect().await?;
147+
assert_batches_eq!(
148+
&["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",],
149+
&results
150+
);
151+
152+
let df = ctx
153+
.sql("select v from t where v >= all(select v from s) order by v")
154+
.await?;
155+
let results = df.collect().await?;
156+
assert_batches_eq!(
157+
&["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",],
158+
&results
159+
);
160+
161+
let df = ctx
162+
.sql("select v from t where v <= any(select v from s) order by v")
163+
.await?;
164+
let results = df.collect().await?;
165+
assert_batches_eq!(
166+
&[
167+
"+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+",
168+
],
169+
&results
170+
);
171+
Ok(())
172+
}
173+
174+
#[tokio::test]
175+
async fn set_comparison_null_semantics_all() -> Result<()> {
176+
let ctx = SessionContext::new();
177+
178+
ctx.register_batch("t", build_table(&[5])?)?;
179+
ctx.register_batch("s", {
180+
let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
181+
let array = Arc::new(Int32Array::from(vec![Some(1), None]))
182+
as Arc<dyn arrow::array::Array>;
183+
RecordBatch::try_new(schema, vec![array])?
184+
})?;
185+
186+
let df = ctx
187+
.sql("select v from t where v != all(select v from s)")
188+
.await?;
189+
let results = df.collect().await?;
190+
let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
191+
assert_eq!(0, row_count);
192+
Ok(())
193+
}

datafusion/core/tests/sql/unparser.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use datafusion_physical_plan::ExecutionPlanProperties;
4747
use datafusion_sql::unparser::Unparser;
4848
use datafusion_sql::unparser::dialect::DefaultDialect;
4949
use itertools::Itertools;
50+
use recursive::{set_minimum_stack_size, set_stack_allocation_size};
5051

5152
/// Paths to benchmark query files (supports running from repo root or different working directories).
5253
const BENCHMARK_PATHS: &[&str] = &["../../benchmarks/", "./benchmarks/"];
@@ -458,5 +459,8 @@ async fn test_clickbench_unparser_roundtrip() {
458459

459460
#[tokio::test]
460461
async fn test_tpch_unparser_roundtrip() {
462+
// Grow stacker segments earlier to avoid deep unparser recursion overflow in q20.
463+
set_minimum_stack_size(512 * 1024);
464+
set_stack_allocation_size(8 * 1024 * 1024);
461465
run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await;
462466
}

datafusion/expr/src/expr.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,8 @@ pub enum Expr {
372372
Exists(Exists),
373373
/// IN subquery
374374
InSubquery(InSubquery),
375+
/// Set comparison subquery (e.g. `= ANY`, `> ALL`)
376+
SetComparison(SetComparison),
375377
/// Scalar subquery
376378
ScalarSubquery(Subquery),
377379
/// Represents a reference to all available fields in a specific schema,
@@ -1101,6 +1103,54 @@ impl Exists {
11011103
}
11021104
}
11031105

1106+
/// Whether the set comparison uses `ANY`/`SOME` or `ALL`
1107+
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)]
1108+
pub enum SetQuantifier {
1109+
/// `ANY` (or `SOME`)
1110+
Any,
1111+
/// `ALL`
1112+
All,
1113+
}
1114+
1115+
impl Display for SetQuantifier {
1116+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1117+
match self {
1118+
SetQuantifier::Any => write!(f, "ANY"),
1119+
SetQuantifier::All => write!(f, "ALL"),
1120+
}
1121+
}
1122+
}
1123+
1124+
/// Set comparison subquery (e.g. `= ANY`, `> ALL`)
1125+
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
1126+
pub struct SetComparison {
1127+
/// The expression to compare
1128+
pub expr: Box<Expr>,
1129+
/// Subquery that will produce a single column of data to compare against
1130+
pub subquery: Subquery,
1131+
/// Comparison operator (e.g. `=`, `>`, `<`)
1132+
pub op: Operator,
1133+
/// Quantifier (`ANY`/`ALL`)
1134+
pub quantifier: SetQuantifier,
1135+
}
1136+
1137+
impl SetComparison {
1138+
/// Create a new set comparison expression
1139+
pub fn new(
1140+
expr: Box<Expr>,
1141+
subquery: Subquery,
1142+
op: Operator,
1143+
quantifier: SetQuantifier,
1144+
) -> Self {
1145+
Self {
1146+
expr,
1147+
subquery,
1148+
op,
1149+
quantifier,
1150+
}
1151+
}
1152+
}
1153+
11041154
/// InList expression
11051155
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
11061156
pub struct InList {
@@ -1502,6 +1552,7 @@ impl Expr {
15021552
Expr::GroupingSet(..) => "GroupingSet",
15031553
Expr::InList { .. } => "InList",
15041554
Expr::InSubquery(..) => "InSubquery",
1555+
Expr::SetComparison(..) => "SetComparison",
15051556
Expr::IsNotNull(..) => "IsNotNull",
15061557
Expr::IsNull(..) => "IsNull",
15071558
Expr::Like { .. } => "Like",
@@ -2057,6 +2108,7 @@ impl Expr {
20572108
| Expr::GroupingSet(..)
20582109
| Expr::InList(..)
20592110
| Expr::InSubquery(..)
2111+
| Expr::SetComparison(..)
20602112
| Expr::IsFalse(..)
20612113
| Expr::IsNotFalse(..)
20622114
| Expr::IsNotNull(..)
@@ -2650,6 +2702,16 @@ impl HashNode for Expr {
26502702
subquery.hash(state);
26512703
negated.hash(state);
26522704
}
2705+
Expr::SetComparison(SetComparison {
2706+
expr: _,
2707+
subquery,
2708+
op,
2709+
quantifier,
2710+
}) => {
2711+
subquery.hash(state);
2712+
op.hash(state);
2713+
quantifier.hash(state);
2714+
}
26532715
Expr::ScalarSubquery(subquery) => {
26542716
subquery.hash(state);
26552717
}
@@ -2840,6 +2902,12 @@ impl Display for SchemaDisplay<'_> {
28402902
write!(f, "NOT IN")
28412903
}
28422904
Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"),
2905+
Expr::SetComparison(SetComparison {
2906+
expr,
2907+
op,
2908+
quantifier,
2909+
..
2910+
}) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())),
28432911
Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)),
28442912
Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)),
28452913
Expr::IsNotTrue(expr) => {
@@ -3315,6 +3383,12 @@ impl Display for Expr {
33153383
subquery,
33163384
negated: false,
33173385
}) => write!(f, "{expr} IN ({subquery:?})"),
3386+
Expr::SetComparison(SetComparison {
3387+
expr,
3388+
subquery,
3389+
op,
3390+
quantifier,
3391+
}) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"),
33183392
Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"),
33193393
Expr::BinaryExpr(expr) => write!(f, "{expr}"),
33203394
Expr::ScalarFunction(fun) => {
@@ -3798,6 +3872,7 @@ mod test {
37983872
}
37993873

38003874
use super::*;
3875+
use crate::logical_plan::{EmptyRelation, LogicalPlan};
38013876

38023877
#[test]
38033878
fn test_display_wildcard() {
@@ -3888,6 +3963,28 @@ mod test {
38883963
)
38893964
}
38903965

3966+
#[test]
3967+
fn test_display_set_comparison() {
3968+
let subquery = Subquery {
3969+
subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
3970+
produce_one_row: false,
3971+
schema: Arc::new(DFSchema::empty()),
3972+
})),
3973+
outer_ref_columns: vec![],
3974+
spans: Spans::new(),
3975+
};
3976+
3977+
let expr = Expr::SetComparison(SetComparison::new(
3978+
Box::new(Expr::Column(Column::from_name("a"))),
3979+
subquery,
3980+
Operator::Gt,
3981+
SetQuantifier::Any,
3982+
));
3983+
3984+
assert_eq!(format!("{expr}"), "a > ANY (<subquery>)");
3985+
assert_eq!(format!("{}", expr.human_display()), "a > ANY (<subquery>)");
3986+
}
3987+
38913988
#[test]
38923989
fn test_schema_display_alias_with_relation() {
38933990
assert_eq!(

0 commit comments

Comments
 (0)