Skip to content

Commit 479807e

Browse files
authored
Merge pull request #33276 from ggevay/flatmap-elim-refactor
Rewrite `FlatMapElimination`
2 parents 7611585 + bc34d85 commit 479807e

File tree

6 files changed

+140
-74
lines changed

6 files changed

+140
-74
lines changed

src/expr/src/relation/func.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3792,19 +3792,16 @@ pub enum TableFunc {
37923792
WithOrdinality(WithOrdinality),
37933793
}
37943794

3795-
/// Private enum variant of `TableFunc`. Don't construct this directly, but use
3796-
/// `TableFunc::with_ordinality` instead.
3797-
///
37983795
/// Evaluates the inner table function, expands its results into unary (repeating each row as
37993796
/// many times as the diff indicates), and appends an integer corresponding to the ordinal
38003797
/// position (starting from 1). For example, it numbers the elements of a list when calling
38013798
/// `unnest_list`.
38023799
///
3803-
/// TODO(ggevay): This struct (and its field) is pub only temporarily, until we make
3804-
/// `FlatMapElimination` not dive into it.
3800+
/// Private enum variant of `TableFunc`. Don't construct this directly, but use
3801+
/// `TableFunc::with_ordinality` instead.
38053802
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
3806-
pub struct WithOrdinality {
3807-
pub inner: Box<TableFunc>,
3803+
struct WithOrdinality {
3804+
inner: Box<TableFunc>,
38083805
}
38093806

38103807
impl TableFunc {

src/expr/src/scalar.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ impl MirScalarExpr {
648648
mem::replace(self, MirScalarExpr::literal_null(ScalarType::String))
649649
}
650650

651+
/// If the expression is a literal, this returns the literal's Datum or the literal's EvalError.
652+
/// Otherwise, it returns None.
651653
pub fn as_literal(&self) -> Option<Result<Datum<'_>, &EvalError>> {
652654
if let MirScalarExpr::Literal(lit, _column_type) = self {
653655
Some(lit.as_ref().map(|row| row.unpack_first()))
@@ -656,6 +658,12 @@ impl MirScalarExpr {
656658
}
657659
}
658660

661+
/// Flattens the two failure modes of `as_literal` into one layer of Option: returns the
662+
/// literal's Datum only if the expression is a literal, and it's not a literal error.
663+
pub fn as_literal_non_error(&self) -> Option<Datum<'_>> {
664+
self.as_literal().map(|eval_err| eval_err.ok()).flatten()
665+
}
666+
659667
pub fn as_literal_owned(&self) -> Option<Result<Row, EvalError>> {
660668
if let MirScalarExpr::Literal(lit, _column_type) = self {
661669
Some(lit.clone())

src/transform/src/canonicalization/flat_map_elimination.rs

Lines changed: 52 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,22 @@
77
// the Business Source License, use of this software will be governed
88
// by the Apache License, Version 2.0.
99

10-
//! Turns `FlatMap` into `Map` if only one row is produced by flatmap.
10+
//! For a `FlatMap` where the table function's arguments are all constants, turns it into `Map` if
11+
//! only 1 row is produced by the table function, or turns it into an empty constant collection if 0
12+
//! rows are produced by the table function.
1113
//!
14+
//! It does an additional optimization on the `Wrap` table function: when `Wrap`'s width is at least
15+
//! its number of arguments, it removes the `FlatMap Wrap ...`, because such `Wrap`s would have no
16+
//! effect.
1217
18+
use itertools::Itertools;
1319
use mz_expr::visit::Visit;
1420
use mz_expr::{MirRelationExpr, MirScalarExpr, TableFunc};
15-
use mz_repr::{Datum, Diff, ScalarType};
21+
use mz_repr::{Diff, Row, RowArena};
1622

1723
use crate::TransformCtx;
1824

19-
/// Turns `FlatMap` into `Map` if only one row is produced by flatmap.
25+
/// Attempts to eliminate FlatMaps that are sure to have 0 or 1 results on each input row.
2026
#[derive(Debug)]
2127
pub struct FlatMapElimination;
2228

@@ -42,70 +48,56 @@ impl crate::Transform for FlatMapElimination {
4248
}
4349

4450
impl FlatMapElimination {
45-
/// Turns `FlatMap` into `Map` if only one row is produced by flatmap.
51+
/// Apply `FlatMapElimination` to the root of the given `MirRelationExpr`.
4652
pub fn action(relation: &mut MirRelationExpr) {
53+
// Treat Wrap specially: we can sometimes optimize it out even when it has non-literal
54+
// arguments.
55+
//
56+
// (No need to look for WithOrdinality here, as that never occurs with Wrap: users can't
57+
// call Wrap directly; we only create calls to Wrap ourselves, and we don't use
58+
// WithOrdinality on it.)
4759
if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
48-
let (func, with_ordinality) = if let TableFunc::WithOrdinality(with_ordinality) = func {
49-
// get to the actual function, but remember that we have a WITH ORDINALITY clause.
50-
(&*with_ordinality.inner, true)
51-
} else {
52-
(&*func, false)
53-
};
54-
55-
if let TableFunc::GuardSubquerySize { .. } = func {
56-
// (`with_ordinality` doesn't matter because this function never emits rows)
57-
if let Some(1) = exprs[0].as_literal_int64() {
58-
relation.take_safely(None);
59-
}
60-
} else if let TableFunc::Wrap { width, .. } = func {
60+
if let TableFunc::Wrap { width, .. } = func {
6161
if *width >= exprs.len() {
6262
*relation = input.take_dangerous().map(std::mem::take(exprs));
63-
if with_ordinality {
64-
*relation = relation.take_dangerous().map_one(MirScalarExpr::literal(
65-
Ok(Datum::Int64(1)),
66-
ScalarType::Int64,
67-
));
68-
}
6963
}
70-
} else if is_supported_unnest(func) {
71-
let func = func.clone();
72-
let exprs = exprs.clone();
73-
use mz_expr::MirScalarExpr;
74-
use mz_repr::RowArena;
75-
if let MirScalarExpr::Literal(Ok(row), ..) = &exprs[0] {
76-
let temp_storage = RowArena::default();
77-
if let Ok(mut iter) = func.eval(&[row.iter().next().unwrap()], &temp_storage) {
78-
match (iter.next(), iter.next()) {
79-
(None, _) => {
80-
// If there are no elements in the literal argument, no output.
81-
relation.take_safely(None);
82-
}
83-
(Some((row, Diff::ONE)), None) => {
84-
assert_eq!(func.output_type().column_types.len(), 1);
85-
*relation =
86-
input.take_dangerous().map(vec![MirScalarExpr::Literal(
87-
Ok(row),
88-
func.output_type().column_types[0].clone(),
89-
)]);
90-
if with_ordinality {
91-
*relation =
92-
relation.take_dangerous().map_one(MirScalarExpr::literal(
93-
Ok(Datum::Int64(1)),
94-
ScalarType::Int64,
95-
));
96-
}
97-
}
98-
_ => {}
99-
}
100-
};
64+
}
65+
}
66+
// For all other table functions (and Wraps that are not covered by the above), check
67+
// whether all arguments are literals (with no errors), in which case we'll evaluate the
68+
// table function and check how many output rows it has, and maybe turn the FlatMap into
69+
// something simpler.
70+
if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
71+
if let Some(args) = exprs
72+
.iter()
73+
.map(|e| e.as_literal_non_error())
74+
.collect::<Option<Vec<_>>>()
75+
{
76+
let temp_storage = RowArena::new();
77+
let (first, second) = match func.eval(&args, &temp_storage) {
78+
Ok(mut r) => (r.next(), r.next()),
79+
// Evaluation failed: leave the `FlatMap` untouched rather than folding the error here.
80+
Err(_) => return,
81+
};
82+
match (first, second) {
83+
// The table function evaluated to an empty collection.
84+
(None, _) => {
85+
relation.take_safely(None);
86+
}
87+
// The table function evaluated to a collection with exactly 1 row.
88+
(Some((first_row, Diff::ONE)), None) => {
89+
let types = func.output_type().column_types;
90+
let map_exprs = first_row
91+
.into_iter()
92+
.zip_eq(types)
93+
.map(|(d, typ)| MirScalarExpr::Literal(Ok(Row::pack_slice(&[d])), typ))
94+
.collect();
95+
*relation = input.take_dangerous().map(map_exprs);
96+
}
97+
// More than 1 row, or a single row whose diff is not 1; nothing to do.
98+
_ => {}
10199
}
102100
}
103101
}
104102
}
105103
}
106-
107-
/// Returns `true` for `unnest_~` variants supported by [`FlatMapElimination`].
108-
fn is_supported_unnest(func: &TableFunc) -> bool {
109-
use TableFunc::*;
110-
matches!(func, UnnestArray { .. } | UnnestList { .. })
111-
}

src/transform/tests/test_transforms/flatmap_to_map.spec renamed to src/transform/tests/test_transforms/flat_map_elimination.spec

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ FlatMap wrap3(0, 1, 2, 3)
8383
## Support for unnest_~ calls
8484
## --------------------------
8585

86-
# Rewrite possible for `unnset_array`
86+
# Rewrite possible for `unnest_array`
8787
# Example SQL: select unnest(array[f1]) from t1 where f1 = 5;
8888
apply pipeline=flat_map_elimination
8989
FlatMap unnest_array({5})
@@ -101,15 +101,15 @@ FlatMap unnest_list([5])
101101
Map (5)
102102
Get t0
103103

104-
# Rewrite not possible: unnest_array(-) argument is not resuced
104+
# Rewrite not possible: unnest_array(-) argument is not reduced to a literal
105105
apply pipeline=flat_map_elimination
106106
FlatMap unnest_array(array[5])
107107
Get t0
108108
----
109109
FlatMap unnest_array(array[5])
110110
Get t0
111111

112-
# Rewrite not possible: unnest_list(-) argument is not resuced
112+
# Rewrite not possible: unnest_list(-) argument is not reduced to a literal
113113
apply pipeline=flat_map_elimination
114114
FlatMap unnest_list(list[5])
115115
Get t0
@@ -119,16 +119,37 @@ FlatMap unnest_list(list[5])
119119

120120
# Rewrite not possible: unnest_array(-) argument is not a singleton
121121
apply pipeline=flat_map_elimination
122+
FlatMap unnest_array({5, 6})
123+
Get t0
124+
----
125+
FlatMap unnest_array({5, 6})
126+
Get t0
127+
128+
# Rewrite not possible: unnest_list(-) argument is not a singleton
129+
apply pipeline=flat_map_elimination
122130
FlatMap unnest_list([5, 6])
123131
Get t0
124132
----
125133
FlatMap unnest_list([5, 6])
126134
Get t0
127135

128-
# Rewrite not possible: unnest_list(-) argument is not a singleton
136+
# generate_series can produce 0, 1, or more rows, based on its arguments
129137
apply pipeline=flat_map_elimination
130-
FlatMap unnest_list(list[5])
138+
FlatMap generate_series(5, 2, 1)
131139
Get t0
132140
----
133-
FlatMap unnest_list(list[5])
141+
Constant <empty>
142+
143+
apply pipeline=flat_map_elimination
144+
FlatMap generate_series(5, 5, 1)
145+
Get t0
146+
----
147+
Map (5)
148+
Get t0
149+
150+
apply pipeline=flat_map_elimination
151+
FlatMap generate_series(5, 6, 1)
152+
Get t0
153+
----
154+
FlatMap generate_series(5, 6, 1)
134155
Get t0

test/sqllogictest/table_func.slt

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,3 +1710,50 @@ Source materialize.public.x
17101710
Target cluster: quickstart
17111711

17121712
EOF
1713+
1714+
query T multiline
1715+
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
1716+
SELECT *
1717+
FROM x, generate_series(1, x.a)
1718+
WHERE x.a = 1;
1719+
----
1720+
Explained Query:
1721+
Project (#0, #1, #0)
1722+
Filter (#0{a} = 1)
1723+
ReadStorage materialize.public.x
1724+
1725+
Source materialize.public.x
1726+
1727+
Target cluster: quickstart
1728+
1729+
EOF
1730+
1731+
query T multiline
1732+
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
1733+
SELECT *
1734+
FROM x, generate_series(5, x.a)
1735+
WHERE x.a = 1;
1736+
----
1737+
Explained Query:
1738+
Constant <empty>
1739+
1740+
Target cluster: quickstart
1741+
1742+
EOF
1743+
1744+
query T multiline
1745+
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
1746+
SELECT *
1747+
FROM x, LATERAL (VALUES ((1), (x.a)))
1748+
WHERE x.a = 1;
1749+
----
1750+
Explained Query:
1751+
Project (#0, #1, #0, #0)
1752+
Filter (#0{a} = 1)
1753+
ReadStorage materialize.public.x
1754+
1755+
Source materialize.public.x
1756+
1757+
Target cluster: quickstart
1758+
1759+
EOF

test/sqllogictest/window_funcs.slt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6874,7 +6874,8 @@ NULL 11 NULL 22 22 22 {22}
68746874
# than one window function call in one test, because we currently forget the key information after a window function
68756875
# call (even when `ReduceElision` simplifies the window function call).
68766876
# TODO: Add an optimization that eliminates a Map-FlatMap pair where the Map is just creating a 1-element list on which
6877-
# the FlatMap is immediately calling `unnest_list`.
6877+
# the FlatMap is immediately calling `unnest_list`. We could use the `Equivalences` analysis for this, which would tell
6878+
# us that the column reference in `unnest_list` is equal to a `list_create` with 1 argument.
68786879
query T multiline
68796880
EXPLAIN OPTIMIZED PLAN WITH(keys, humanized expressions) AS VERBOSE TEXT FOR
68806881
SELECT

0 commit comments

Comments
 (0)