|
18 | 18 | import functools |
19 | 19 | import itertools |
20 | 20 | from typing import ( |
| 21 | + Callable, |
21 | 22 | cast, |
| 23 | + Dict, |
| 24 | + Generator, |
22 | 25 | Hashable, |
23 | 26 | Iterable, |
24 | 27 | Iterator, |
|
40 | 43 |
|
41 | 44 | _MAX_INLINE_COMPLEXITY = 10 |
42 | 45 |
|
| 46 | +T = TypeVar("T") |
| 47 | + |
| 48 | + |
| 49 | +def unique_nodes( |
| 50 | + roots: Sequence[expression.Expression], |
| 51 | +) -> Generator[expression.Expression, None, None]: |
| 52 | + """Walks the tree for unique nodes""" |
| 53 | + seen = set() |
| 54 | + stack: list[expression.Expression] = list(roots) |
| 55 | + while stack: |
| 56 | + item = stack.pop() |
| 57 | + if item not in seen: |
| 58 | + yield item |
| 59 | + seen.add(item) |
| 60 | + stack.extend(item.children) |
| 61 | + |
| 62 | + |
| 63 | +def iter_nodes_topo( |
| 64 | + roots: Sequence[expression.Expression], |
| 65 | +) -> Generator[expression.Expression, None, None]: |
| 66 | + """Returns nodes in reverse topological order, using Kahn's algorithm.""" |
| 67 | + child_to_parents: Dict[ |
| 68 | + expression.Expression, list[expression.Expression] |
| 69 | + ] = collections.defaultdict(list) |
| 70 | + out_degree: Dict[expression.Expression, int] = collections.defaultdict(int) |
| 71 | + |
| 72 | + queue: collections.deque[expression.Expression] = collections.deque() |
| 73 | + for node in unique_nodes(roots): |
| 74 | + num_children = len(node.children) |
| 75 | + out_degree[node] = num_children |
| 76 | + if num_children == 0: |
| 77 | + queue.append(node) |
| 78 | + for child in node.children: |
| 79 | + child_to_parents[child].append(node) |
| 80 | + |
| 81 | + while queue: |
| 82 | + item = queue.popleft() |
| 83 | + yield item |
| 84 | + parents = child_to_parents.get(item, []) |
| 85 | + for parent in parents: |
| 86 | + out_degree[parent] -= 1 |
| 87 | + if out_degree[parent] == 0: |
| 88 | + queue.append(parent) |
| 89 | + |
| 90 | + |
| 91 | +def reduce_up( |
| 92 | + roots: Sequence[expression.Expression], |
| 93 | + reduction: Callable[[expression.Expression, Tuple[T, ...]], T], |
| 94 | +) -> Tuple[T, ...]: |
| 95 | + """Apply a bottom-up reduction to the forest.""" |
| 96 | + results: dict[expression.Expression, T] = {} |
| 97 | + for node in list(iter_nodes_topo(roots)): |
| 98 | + # child nodes have already been transformed |
| 99 | + child_results = tuple(results[child] for child in node.children) |
| 100 | + result = reduction(node, child_results) |
| 101 | + results[node] = result |
| 102 | + |
| 103 | + return tuple(results[root] for root in roots) |
| 104 | + |
43 | 105 |
|
44 | 106 | def apply_col_exprs_to_plan( |
45 | 107 | plan: nodes.BigFrameNode, col_exprs: Sequence[nodes.ColumnDef] |
46 | 108 | ) -> nodes.BigFrameNode: |
47 | | - # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions |
48 | 109 | target_ids = tuple(named_expr.id for named_expr in col_exprs) |
49 | 110 |
|
50 | | - fragments = tuple( |
51 | | - itertools.chain.from_iterable( |
52 | | - fragmentize_expression(expr) for expr in col_exprs |
53 | | - ) |
54 | | - ) |
| 111 | + fragments = fragmentize_expression(col_exprs) |
55 | 112 | return push_into_tree(plan, fragments, target_ids) |
56 | 113 |
|
57 | 114 |
|
@@ -101,14 +158,26 @@ class FactoredExpression: |
101 | 158 | sub_exprs: Tuple[nodes.ColumnDef, ...] |
102 | 159 |
|
103 | 160 |
|
104 | | -def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]: |
| 161 | +def fragmentize_expression( |
| 162 | + roots: Sequence[nodes.ColumnDef], |
| 163 | +) -> Sequence[nodes.ColumnDef]: |
105 | 164 | """ |
106 | 165 | The goal of this functions is to factor out an expression into multiple sub-expressions. |
107 | 166 | """ |
108 | | - |
109 | | - factored_expr = root.expression.reduce_up(gather_fragments) |
110 | | - root_expr = nodes.ColumnDef(factored_expr.root_expr, root.id) |
111 | | - return (root_expr, *factored_expr.sub_exprs) |
| 167 | + # TODO: Fragmentize a bit less aggressively |
| 168 | + factored_exprs = reduce_up([root.expression for root in roots], gather_fragments) |
| 169 | + root_exprs = ( |
| 170 | + nodes.ColumnDef(factored.root_expr, root.id) |
| 171 | + for factored, root in zip(factored_exprs, roots) |
| 172 | + ) |
| 173 | + return ( |
| 174 | + *root_exprs, |
| 175 | + *dedupe( |
| 176 | + itertools.chain.from_iterable( |
| 177 | + factored_expr.sub_exprs for factored_expr in factored_exprs |
| 178 | + ) |
| 179 | + ), |
| 180 | + ) |
112 | 181 |
|
113 | 182 |
|
114 | 183 | @dataclasses.dataclass(frozen=True, eq=False) |
|
0 commit comments