Skip to content

Commit 6118bb7

Browse files
committed
Rewrite FlatMapElimination to simplify and generalize it
From Petros
1 parent 80f3614 commit 6118bb7

File tree

3 files changed

+77
-68
lines changed

3 files changed

+77
-68
lines changed

src/expr/src/relation/func.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3792,19 +3792,16 @@ pub enum TableFunc {
37923792
WithOrdinality(WithOrdinality),
37933793
}
37943794

3795-
/// Private enum variant of `TableFunc`. Don't construct this directly, but use
3796-
/// `TableFunc::with_ordinality` instead.
3797-
///
37983795
/// Evaluates the inner table function, expands its results into unary (repeating each row as
37993796
/// many times as the diff indicates), and appends an integer corresponding to the ordinal
38003797
/// position (starting from 1). For example, it numbers the elements of a list when calling
38013798
/// `unnest_list`.
38023799
///
3803-
/// TODO(ggevay): This struct (and its field) is pub only temporarily, until we make
3804-
/// `FlatMapElimination` not dive into it.
3800+
/// Private enum variant of `TableFunc`. Don't construct this directly, but use
3801+
/// `TableFunc::with_ordinality` instead.
38053802
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
3806-
pub struct WithOrdinality {
3807-
pub inner: Box<TableFunc>,
3803+
struct WithOrdinality {
3804+
inner: Box<TableFunc>,
38083805
}
38093806

38103807
impl TableFunc {

src/transform/src/canonicalization/flat_map_elimination.rs

Lines changed: 43 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@
77
// the Business Source License, use of this software will be governed
88
// by the Apache License, Version 2.0.
99

10-
//! Turns `FlatMap` into `Map` if only one row is produced by flatmap.
11-
//!
10+
//! For a `FlatMap` whose args are all constants, turns it into `Map` if only 1 row is produced by
11+
//! the table function, or turns it into an empty constant if 0 rows are produced by the table
12+
//! function. Additionally, a `Wrap` whose width is larger than its number of arguments can be
13+
//! removed.
1214
15+
use itertools::Itertools;
1316
use mz_expr::visit::Visit;
1417
use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc};
15-
use mz_repr::{Datum, Diff, ScalarType};
18+
use mz_repr::{Diff, Row, RowArena};
1619

1720
use crate::TransformCtx;
1821

19-
/// Turns `FlatMap` into `Map` if only one row is produced by flatmap.
22+
/// See comment at the top of the file.
2023
#[derive(Debug)]
2124
pub struct FlatMapElimination;
2225

@@ -42,70 +45,49 @@ impl crate::Transform for FlatMapElimination {
4245
}
4346

4447
impl FlatMapElimination {
45-
/// Turns `FlatMap` into `Map` if only one row is produced by flatmap.
48+
/// See comment at the top of the file.
4649
pub fn action(relation: &mut MirRelationExpr) {
4750
if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
48-
let (func, with_ordinality) = if let TableFunc::WithOrdinality(with_ordinality) = func {
49-
// get to the actual function, but remember that we have a WITH ORDINALITY clause.
50-
(&*with_ordinality.inner, true)
51-
} else {
52-
(&*func, false)
53-
};
54-
55-
if let TableFunc::GuardSubquerySize { .. } = func {
56-
// (`with_ordinality` doesn't matter because this function never emits rows)
57-
if let Some(1) = exprs[0].as_literal_int64() {
58-
relation.take_safely(None);
59-
}
60-
} else if let TableFunc::Wrap { width, .. } = func {
51+
// Treat Wrap specially.
52+
if let TableFunc::Wrap { width, .. } = func {
6153
if *width >= exprs.len() {
6254
*relation = input.take_dangerous().map(std::mem::take(exprs));
63-
if with_ordinality {
64-
*relation = relation.take_dangerous().map_one(MirScalarExpr::literal(
65-
Ok(Datum::Int64(1)),
66-
ScalarType::Int64,
67-
));
68-
}
55+
return;
56+
}
57+
}
58+
// For all other table functions, check for all arguments being literals.
59+
let mut args = vec![];
60+
for e in exprs {
61+
match e.as_literal() {
62+
Some(Ok(datum)) => args.push(datum),
63+
// Give up if any arg is not a literal, or if it's a literal error.
64+
_ => return,
6965
}
70-
} else if is_supported_unnest(func) {
71-
let func = func.clone();
72-
let exprs = exprs.clone();
73-
use mz_expr::MirScalarExpr;
74-
use mz_repr::RowArena;
75-
if let MirScalarExpr::Literal(Ok(row), ..) = &exprs[0] {
76-
let temp_storage = RowArena::default();
77-
if let Ok(mut iter) = func.eval(&[row.iter().next().unwrap()], &temp_storage) {
78-
match (iter.next(), iter.next()) {
79-
(None, _) => {
80-
// If there are no elements in the literal argument, no output.
81-
relation.take_safely(None);
82-
}
83-
(Some((row, Diff::ONE)), None) => {
84-
assert_eq!(func.output_type().column_types.len(), 1);
85-
*relation =
86-
input.take_dangerous().map(vec![MirScalarExpr::Literal(
87-
Ok(row),
88-
func.output_type().column_types[0].clone(),
89-
)]);
90-
if with_ordinality {
91-
*relation =
92-
relation.take_dangerous().map_one(MirScalarExpr::literal(
93-
Ok(Datum::Int64(1)),
94-
ScalarType::Int64,
95-
));
96-
}
97-
}
98-
_ => {}
99-
}
100-
};
66+
}
67+
let temp_storage = RowArena::new();
68+
let (first, second) = match func.eval(&args, &temp_storage) {
69+
Ok(mut r) => (r.next(), r.next()),
70+
// don't play with errors
71+
Err(_) => return,
72+
};
73+
match (first, second) {
74+
// The table function evaluated to an empty collection.
75+
(None, None) => {
76+
relation.take_safely(None);
10177
}
78+
// The table function evaluated to a collection with exactly 1 row.
79+
(Some((first_row, Diff::ONE)), None) => {
80+
let types = func.output_type().column_types;
81+
let map_exprs = first_row
82+
.into_iter()
83+
.zip_eq(types)
84+
.map(|(d, typ)| MirScalarExpr::Literal(Ok(Row::pack_slice(&[d])), typ))
85+
.collect();
86+
*relation = input.take_dangerous().map(map_exprs);
87+
}
88+
// The table function evaluated to a collection with more than 1 row; nothing to do.
89+
_ => {}
10290
}
10391
}
10492
}
10593
}
106-
107-
/// Returns `true` for `unnest_~` variants supported by [`FlatMapElimination`].
108-
fn is_supported_unnest(func: &TableFunc) -> bool {
109-
use TableFunc::*;
110-
matches!(func, UnnestArray { .. } | UnnestList { .. })
111-
}

test/sqllogictest/table_func.slt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,3 +1710,33 @@ Source materialize.public.x
17101710
Target cluster: quickstart
17111711

17121712
EOF
1713+
1714+
query T multiline
1715+
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
1716+
SELECT *
1717+
FROM x, generate_series(1, x.a)
1718+
WHERE x.a = 1
1719+
----
1720+
Explained Query:
1721+
Project (#0, #1, #0)
1722+
Filter (#0{a} = 1)
1723+
ReadStorage materialize.public.x
1724+
1725+
Source materialize.public.x
1726+
1727+
Target cluster: quickstart
1728+
1729+
EOF
1730+
1731+
query T multiline
1732+
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
1733+
SELECT *
1734+
FROM x, generate_series(5, x.a)
1735+
WHERE x.a = 1
1736+
----
1737+
Explained Query:
1738+
Constant <empty>
1739+
1740+
Target cluster: quickstart
1741+
1742+
EOF

0 commit comments

Comments
 (0)