7
7
// the Business Source License, use of this software will be governed
8
8
// by the Apache License, Version 2.0.
9
9
10
- //! Turns `FlatMap` into `Map` if only one row is produced by flatmap.
11
- //!
10
+ //! For a `FlatMap` whose args are all constants, turns it into `Map` if only 1 row is produced by
11
+ //! the table function, or turns it into an empty constant if 0 rows are produced by the table
12
+ //! function. Additionally, a `Wrap` whose width is larger than its number of arguments can be
13
+ //! removed.
12
14
15
+ use itertools:: Itertools ;
13
16
use mz_expr:: visit:: Visit ;
14
17
use mz_expr:: { MirRelationExpr , MirScalarExpr , TableFunc } ;
15
- use mz_repr:: { Datum , Diff , ScalarType } ;
18
+ use mz_repr:: { Diff , Row , RowArena } ;
16
19
17
20
use crate :: TransformCtx ;
18
21
19
- /// Turns `FlatMap` into `Map` if only one row is produced by flatmap .
22
+ /// See comment at the top of the file .
20
23
#[ derive( Debug ) ]
21
24
pub struct FlatMapElimination ;
22
25
@@ -42,70 +45,49 @@ impl crate::Transform for FlatMapElimination {
42
45
}
43
46
44
47
impl FlatMapElimination {
45
- /// Turns `FlatMap` into `Map` if only one row is produced by flatmap .
48
+ /// See comment at the top of the file .
46
49
pub fn action ( relation : & mut MirRelationExpr ) {
47
50
if let MirRelationExpr :: FlatMap { func, exprs, input } = relation {
48
- let ( func, with_ordinality) = if let TableFunc :: WithOrdinality ( with_ordinality) = func {
49
- // get to the actual function, but remember that we have a WITH ORDINALITY clause.
50
- ( & * with_ordinality. inner , true )
51
- } else {
52
- ( & * func, false )
53
- } ;
54
-
55
- if let TableFunc :: GuardSubquerySize { .. } = func {
56
- // (`with_ordinality` doesn't matter because this function never emits rows)
57
- if let Some ( 1 ) = exprs[ 0 ] . as_literal_int64 ( ) {
58
- relation. take_safely ( None ) ;
59
- }
60
- } else if let TableFunc :: Wrap { width, .. } = func {
51
+ // Treat Wrap specially.
52
+ if let TableFunc :: Wrap { width, .. } = func {
61
53
if * width >= exprs. len ( ) {
62
54
* relation = input. take_dangerous ( ) . map ( std:: mem:: take ( exprs) ) ;
63
- if with_ordinality {
64
- * relation = relation. take_dangerous ( ) . map_one ( MirScalarExpr :: literal (
65
- Ok ( Datum :: Int64 ( 1 ) ) ,
66
- ScalarType :: Int64 ,
67
- ) ) ;
68
- }
55
+ return ;
56
+ }
57
+ }
58
+ // For all other table functions, check for all arguments being literals.
59
+ let mut args = vec ! [ ] ;
60
+ for e in exprs {
61
+ match e. as_literal ( ) {
62
+ Some ( Ok ( datum) ) => args. push ( datum) ,
63
+ // Give up if any arg is not a literal, or if it's a literal error.
64
+ _ => return ,
69
65
}
70
- } else if is_supported_unnest ( func) {
71
- let func = func. clone ( ) ;
72
- let exprs = exprs. clone ( ) ;
73
- use mz_expr:: MirScalarExpr ;
74
- use mz_repr:: RowArena ;
75
- if let MirScalarExpr :: Literal ( Ok ( row) , ..) = & exprs[ 0 ] {
76
- let temp_storage = RowArena :: default ( ) ;
77
- if let Ok ( mut iter) = func. eval ( & [ row. iter ( ) . next ( ) . unwrap ( ) ] , & temp_storage) {
78
- match ( iter. next ( ) , iter. next ( ) ) {
79
- ( None , _) => {
80
- // If there are no elements in the literal argument, no output.
81
- relation. take_safely ( None ) ;
82
- }
83
- ( Some ( ( row, Diff :: ONE ) ) , None ) => {
84
- assert_eq ! ( func. output_type( ) . column_types. len( ) , 1 ) ;
85
- * relation =
86
- input. take_dangerous ( ) . map ( vec ! [ MirScalarExpr :: Literal (
87
- Ok ( row) ,
88
- func. output_type( ) . column_types[ 0 ] . clone( ) ,
89
- ) ] ) ;
90
- if with_ordinality {
91
- * relation =
92
- relation. take_dangerous ( ) . map_one ( MirScalarExpr :: literal (
93
- Ok ( Datum :: Int64 ( 1 ) ) ,
94
- ScalarType :: Int64 ,
95
- ) ) ;
96
- }
97
- }
98
- _ => { }
99
- }
100
- } ;
66
+ }
67
+ let temp_storage = RowArena :: new ( ) ;
68
+ let ( first, second) = match func. eval ( & args, & temp_storage) {
69
+ Ok ( mut r) => ( r. next ( ) , r. next ( ) ) ,
70
+ // don't play with errors
71
+ Err ( _) => return ,
72
+ } ;
73
+ match ( first, second) {
74
+ // The table function evaluated to an empty collection.
75
+ ( None , None ) => {
76
+ relation. take_safely ( None ) ;
101
77
}
78
+ // The table function evaluated to a collection with exactly 1 row.
79
+ ( Some ( ( first_row, Diff :: ONE ) ) , None ) => {
80
+ let types = func. output_type ( ) . column_types ;
81
+ let map_exprs = first_row
82
+ . into_iter ( )
83
+ . zip_eq ( types)
84
+ . map ( |( d, typ) | MirScalarExpr :: Literal ( Ok ( Row :: pack_slice ( & [ d] ) ) , typ) )
85
+ . collect ( ) ;
86
+ * relation = input. take_dangerous ( ) . map ( map_exprs) ;
87
+ }
88
+ // The table function evaluated to a collection with more than 1 row; nothing to do.
89
+ _ => { }
102
90
}
103
91
}
104
92
}
105
93
}
106
-
107
- /// Returns `true` for `unnest_~` variants supported by [`FlatMapElimination`].
108
- fn is_supported_unnest ( func : & TableFunc ) -> bool {
109
- use TableFunc :: * ;
110
- matches ! ( func, UnnestArray { .. } | UnnestList { .. } )
111
- }
0 commit comments