Skip to content

Commit bc34d85

Browse files
committed
Add comments and do a bit of refactoring in FlatMapElimination
1 parent 6118bb7 commit bc34d85

File tree

4 files changed

+77
-42
lines changed

4 files changed

+77
-42
lines changed

src/expr/src/scalar.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ impl MirScalarExpr {
648648
mem::replace(self, MirScalarExpr::literal_null(ScalarType::String))
649649
}
650650

651+
/// If the expression is a literal, this returns the literal's Datum or the literal's EvalError.
652+
/// Otherwise, it returns None.
651653
pub fn as_literal(&self) -> Option<Result<Datum<'_>, &EvalError>> {
652654
if let MirScalarExpr::Literal(lit, _column_type) = self {
653655
Some(lit.as_ref().map(|row| row.unpack_first()))
@@ -656,6 +658,12 @@ impl MirScalarExpr {
656658
}
657659
}
658660

661+
/// Flattens the two failure modes of `as_literal` into one layer of `Option`: returns the
662+
/// literal's `Datum` only if the expression is a literal and that literal is not an error.
663+
pub fn as_literal_non_error(&self) -> Option<Datum<'_>> {
664+
self.as_literal().and_then(|literal| literal.ok())
665+
}
666+
659667
pub fn as_literal_owned(&self) -> Option<Result<Row, EvalError>> {
660668
if let MirScalarExpr::Literal(lit, _column_type) = self {
661669
Some(lit.clone())

src/transform/src/canonicalization/flat_map_elimination.rs

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
// the Business Source License, use of this software will be governed
88
// by the Apache License, Version 2.0.
99

10-
//! For a `FlatMap` whose args are all constants, turns it into `Map` if only 1 row is produced by
11-
//! the table function, or turns it into an empty constant if 0 rows are produced by the table
12-
//! function. Additionally, a `Wrap` whose width is larger than its number of arguments can be
13-
//! removed.
10+
//! For a `FlatMap` where the table function's arguments are all constants, turns it into `Map` if
11+
//! only 1 row is produced by the table function, or turns it into an empty constant collection if 0
12+
//! rows are produced by the table function.
13+
//!
14+
//! It does an additional optimization on the `Wrap` table function: when `Wrap`'s width is
15+
//! at least its number of arguments, it removes the `FlatMap Wrap ...`, because such `Wrap`s would
16+
//! have no effect.
1417
1518
use itertools::Itertools;
1619
use mz_expr::visit::Visit;
@@ -19,7 +22,7 @@ use mz_repr::{Diff, Row, RowArena};
1922

2023
use crate::TransformCtx;
2124

22-
/// See comment at the top of the file.
25+
/// Attempts to eliminate FlatMaps that are sure to have 0 or 1 results on each input row.
2326
#[derive(Debug)]
2427
pub struct FlatMapElimination;
2528

@@ -45,48 +48,55 @@ impl crate::Transform for FlatMapElimination {
4548
}
4649

4750
impl FlatMapElimination {
48-
/// See comment at the top of the file.
51+
/// Apply `FlatMapElimination` to the root of the given `MirRelationExpr`.
4952
pub fn action(relation: &mut MirRelationExpr) {
53+
// Treat Wrap specially: we can sometimes optimize it out even when it has non-literal
54+
// arguments.
55+
//
56+
// (No need to look for WithOrdinality here, as that never occurs with Wrap: users can't
57+
// call Wrap directly; we only create calls to Wrap ourselves, and we don't use
58+
// WithOrdinality on it.)
5059
if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
51-
// Treat Wrap specially.
5260
if let TableFunc::Wrap { width, .. } = func {
5361
if *width >= exprs.len() {
5462
*relation = input.take_dangerous().map(std::mem::take(exprs));
55-
return;
5663
}
5764
}
58-
// For all other table functions, check for all arguments being literals.
59-
let mut args = vec![];
60-
for e in exprs {
61-
match e.as_literal() {
62-
Some(Ok(datum)) => args.push(datum),
63-
// Give up if any arg is not a literal, or if it's a literal error.
64-
_ => return,
65-
}
66-
}
67-
let temp_storage = RowArena::new();
68-
let (first, second) = match func.eval(&args, &temp_storage) {
69-
Ok(mut r) => (r.next(), r.next()),
70-
// don't play with errors
71-
Err(_) => return,
72-
};
73-
match (first, second) {
74-
// The table function evaluated to an empty collection.
75-
(None, None) => {
76-
relation.take_safely(None);
77-
}
78-
// The table function evaluated to a collection with exactly 1 row.
79-
(Some((first_row, Diff::ONE)), None) => {
80-
let types = func.output_type().column_types;
81-
let map_exprs = first_row
82-
.into_iter()
83-
.zip_eq(types)
84-
.map(|(d, typ)| MirScalarExpr::Literal(Ok(Row::pack_slice(&[d])), typ))
85-
.collect();
86-
*relation = input.take_dangerous().map(map_exprs);
65+
}
66+
// For all other table functions (and Wraps that are not covered by the above), check
67+
// whether all arguments are literals (with no errors), in which case we'll evaluate the
68+
// table function and check how many output rows it has, and maybe turn the FlatMap into
69+
// something simpler.
70+
if let MirRelationExpr::FlatMap { func, exprs, input } = relation {
71+
if let Some(args) = exprs
72+
.iter()
73+
.map(|e| e.as_literal_non_error())
74+
.collect::<Option<Vec<_>>>()
75+
{
76+
let temp_storage = RowArena::new();
77+
let (first, second) = match func.eval(&args, &temp_storage) {
78+
Ok(mut r) => (r.next(), r.next()),
79+
// don't play with errors
80+
Err(_) => return,
81+
};
82+
match (first, second) {
83+
// The table function evaluated to an empty collection.
84+
(None, _) => {
85+
relation.take_safely(None);
86+
}
87+
// The table function evaluated to a collection with exactly 1 row.
88+
(Some((first_row, Diff::ONE)), None) => {
89+
let types = func.output_type().column_types;
90+
let map_exprs = first_row
91+
.into_iter()
92+
.zip_eq(types)
93+
.map(|(d, typ)| MirScalarExpr::Literal(Ok(Row::pack_slice(&[d])), typ))
94+
.collect();
95+
*relation = input.take_dangerous().map(map_exprs);
96+
}
97+
// Either more than 1 row was produced, or the single row's diff is not 1; nothing to do.
98+
_ => {}
8799
}
88-
// The table function evaluated to a collection with more than 1 row; nothing to do.
89-
_ => {}
90100
}
91101
}
92102
}

src/transform/tests/test_transforms/flat_map_elimination.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,4 @@ FlatMap generate_series(5, 6, 1)
152152
Get t0
153153
----
154154
FlatMap generate_series(5, 6, 1)
155-
Get
155+
Get t0

test/sqllogictest/table_func.slt

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1715,7 +1715,7 @@ query T multiline
17151715
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
17161716
SELECT *
17171717
FROM x, generate_series(1, x.a)
1718-
WHERE x.a = 1
1718+
WHERE x.a = 1;
17191719
----
17201720
Explained Query:
17211721
Project (#0, #1, #0)
@@ -1732,11 +1732,28 @@ query T multiline
17321732
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
17331733
SELECT *
17341734
FROM x, generate_series(5, x.a)
1735-
WHERE x.a = 1
1735+
WHERE x.a = 1;
17361736
----
17371737
Explained Query:
17381738
Constant <empty>
17391739

17401740
Target cluster: quickstart
17411741

17421742
EOF
1743+
1744+
query T multiline
1745+
EXPLAIN OPTIMIZED PLAN WITH (NO FAST PATH) FOR
1746+
SELECT *
1747+
FROM x, LATERAL (VALUES ((1), (x.a)))
1748+
WHERE x.a = 1;
1749+
----
1750+
Explained Query:
1751+
Project (#0, #1, #0, #0)
1752+
Filter (#0{a} = 1)
1753+
ReadStorage materialize.public.x
1754+
1755+
Source materialize.public.x
1756+
1757+
Target cluster: quickstart
1758+
1759+
EOF

0 commit comments

Comments
 (0)