Skip to content

Commit 5be84e0

Browse files
authored
Adjusting relational optimization pipeline (#422)
Series of (mostly cosmetic) changes to the relational/sql plans generated by adjusting the various IRs, especially the order of relational optimizations. Some of the changes are required to simplify the end-result of #423, and to prevent other regressions to the quality of the final SQL. The changes include: 1. Changing the order of columns in an intermediary `SELECT` clause to have grouping keys _before_ aggregation calls. 2. Moving around the order of the relational optimization pipeline so bubbling & projection pullup happen _very early_, before the first round of filter pushdown, to incentivize getting rid of joins where the RHS is just a values clause to generate scalars. 3. Fixing bugs in projection pullup that were not exposed until (2) was done 4. Making minor improvements to the column bubbler exposed by (2) 5. Removing `@cache` decorators in QDAG when it was discovered some of the test suites are faster without them 6. Adding new `simple_cross` tests
1 parent a594547 commit 5be84e0

File tree

598 files changed

+3264
-3069
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

598 files changed

+3264
-3069
lines changed

pydough/conversion/agg_split.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
CallExpression,
1414
ColumnReference,
1515
ColumnReferenceFinder,
16+
ColumnReferenceInputNameRemover,
1617
Join,
1718
JoinType,
1819
LiteralExpression,
@@ -128,7 +129,7 @@ def decompose_aggregations(node: Aggregate, config: PyDoughConfigs) -> Relationa
128129
new_aggregate: Aggregate = Aggregate(node.input, node.keys, aggs)
129130
project_columns: dict[str, RelationalExpression] = {}
130131
for name, expr in node.keys.items():
131-
project_columns[name] = expr
132+
project_columns[name] = ColumnReference(name, expr.data_type)
132133
project_columns.update(
133134
{name: final_agg_columns[name] for name in node.aggregations}
134135
)
@@ -171,6 +172,8 @@ def transpose_aggregate_join(
171172
agg_input_name: str | None = join.default_input_aliases[agg_side]
172173
need_projection: bool = False
173174

175+
finder: ColumnReferenceFinder = ColumnReferenceFinder()
176+
alias_remover: ColumnReferenceInputNameRemover = ColumnReferenceInputNameRemover()
174177
transposer: ExpressionTranspositionShuttle = ExpressionTranspositionShuttle(
175178
join, False
176179
)
@@ -234,19 +237,44 @@ def transpose_aggregate_join(
234237
for ref in side_keys:
235238
input_keys[ref.name] = ref.with_input(None)
236239
transposer.toggle_keep_input_names(True)
237-
for agg_key in node.keys.values():
240+
for agg_key_name, agg_key in node.keys.items():
241+
finder.reset()
238242
transposed_agg_key = agg_key.accept_shuttle(transposer)
239-
assert isinstance(transposed_agg_key, ColumnReference)
240-
if transposed_agg_key.input_name == agg_input_name:
241-
input_keys[transposed_agg_key.name] = transposed_agg_key.with_input(None)
243+
transposed_agg_key.accept(finder)
244+
if {col.input_name for col in finder.get_column_references()} == {
245+
agg_input_name
246+
}:
247+
if isinstance(transposed_agg_key, ColumnReference):
248+
input_keys[transposed_agg_key.name] = transposed_agg_key.accept_shuttle(
249+
alias_remover
250+
)
251+
else:
252+
if agg_key_name in join.columns and (
253+
agg_key_name in input_keys or agg_key_name in input_aggs
254+
):
255+
# An edge case that is theoretically possible but never
256+
# encountered so far, and where the behavior is undefined.
257+
raise NotImplementedError("Undefined behavior")
258+
input_keys[agg_key_name] = transposed_agg_key.accept_shuttle(
259+
alias_remover
260+
)
261+
join.columns[agg_key_name] = ColumnReference(
262+
agg_key_name, agg_key.data_type, agg_input_name
263+
)
264+
node.keys[agg_key_name] = ColumnReference(
265+
agg_key_name, agg_key.data_type
266+
)
267+
projection_columns[agg_key_name] = ColumnReference(
268+
agg_key_name, agg_key.data_type
269+
)
242270

243271
# Push the bottom-aggregate beneath the join
244272
join.inputs[agg_side] = Aggregate(agg_input, input_keys, input_aggs)
245273

246274
# Replace the aggregation above the join with the top
247275
# side of the aggregations
248276
node._aggregations = top_aggs
249-
node._columns = {**node.columns, **top_aggs}
277+
node._columns = {**node.keys, **top_aggs}
250278

251279
return need_projection, count_ref
252280

@@ -276,11 +304,19 @@ def attempt_join_aggregate_transpose(
276304
# push the aggregate down.
277305
return node, True
278306

307+
# Verify that all of the aggregation keys strictly come from one side of the
308+
# join.
309+
finder: ColumnReferenceFinder = ColumnReferenceFinder()
310+
for key_expr in node.keys.values():
311+
finder.reset()
312+
key_expr.accept(finder)
313+
if len({ref.input_name for ref in finder.get_column_references()}) > 1:
314+
return node, True
315+
279316
# Break down the aggregation calls by which input they refer to.
280317
lhs_aggs: list[str] = []
281318
rhs_aggs: list[str] = []
282319
count_aggs: list[str] = []
283-
finder: ColumnReferenceFinder = ColumnReferenceFinder()
284320
transposer: ExpressionTranspositionShuttle = ExpressionTranspositionShuttle(
285321
join, True
286322
)
@@ -341,7 +377,10 @@ def attempt_join_aggregate_transpose(
341377
# If we cannot push the aggregate down into either side, we cannot
342378
# perform the transpose.
343379
return node, True
380+
344381
if need_count_aggs and not (can_push_left and can_push_right):
382+
# If we need to push down COUNT(*) aggregates, but cannot push into
383+
# both sides of the join, we cannot perform the transpose.
345384
return node, True
346385

347386
# Parse the join condition to identify the lists of equi-join keys
@@ -365,7 +404,7 @@ def attempt_join_aggregate_transpose(
365404

366405
# Keep a dictionary for the projection columns that will be used to post-process
367406
# the output of the aggregates, if needed.
368-
projection_columns: dict[str, RelationalExpression] = {**node.keys}
407+
projection_columns: dict[str, RelationalExpression] = {}
369408
need_projection: bool = False
370409

371410
# If we need count aggregates, add one to each side of the join.

pydough/conversion/column_bubbler.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
__all__ = ["bubble_column_names"]
88

99

10+
import re
11+
1012
from pydough.relational import (
1113
Aggregate,
1214
CallExpression,
@@ -46,36 +48,42 @@ def name_sort_key(name: str) -> tuple[bool, bool, str]:
4648
)
4749

4850

49-
def generate_agg_name(agg_expr: CallExpression) -> str | None:
51+
def generate_cleaner_names(expr: RelationalExpression, current_name: str) -> list[str]:
5052
"""
51-
Generates a more readable name for an aggregation expression based on its
52-
function name and input column, if applicable. The two patterns of name
53-
generation are:
53+
Generates more readable names for an expression, if applicable.
54+
The patterns of name generation are:
5455
55-
- If the aggregation has a single input that is a column reference, the
56+
- If a function has a single input that is a column reference, the
5657
name is generated as `<function_name>_<column_name>`. For example,
5758
`SUM(sales)` would become `sum_sales`, and `AVG(num_cars_owned)`
5859
would become `avg_num_cars_owned`.
59-
- If the aggregation is a `COUNT` with no inputs, the name is simply
60+
- If an aggregation is a `COUNT` with no inputs, the name is simply
6061
`n_rows`, indicating the number of rows counted.
62+
- If the current name is in the form `name_idx`, try suggesting just `name`.
6163
62-
If neither of these conditions are met, the function returns `None`.
64+
If none of these conditions are met, the function returns an empty list.
6365
6466
Args:
65-
`agg_expr`: The function call expression for which to generate a name,
66-
which is presumed to be an aggregation call.
67+
`expr`: The function call expression for which to generate
68+
alternative names.
69+
`current_name`: The current name of the expression.
6770
6871
Returns:
69-
A string representing the generated name, or `None` if no suitable
70-
name can be generated based on the provided conditions.
72+
A list of strings representing the candidate generated names.
7173
"""
72-
if len(agg_expr.inputs) == 1:
73-
input_expr = agg_expr.inputs[0]
74-
if isinstance(input_expr, ColumnReference):
75-
return f"{agg_expr.op.function_name.lower()}_{input_expr.name}"
76-
if len(agg_expr.inputs) == 0 and agg_expr.op.function_name.lower() == "count":
77-
return "n_rows"
78-
return None
74+
result: list[str] = []
75+
if isinstance(expr, CallExpression):
76+
if len(expr.inputs) == 1:
77+
input_expr = expr.inputs[0]
78+
if isinstance(input_expr, ColumnReference):
79+
result.append(f"{expr.op.function_name.lower()}_{input_expr.name}")
80+
if len(expr.inputs) == 0 and expr.op.function_name.lower() == "count":
81+
result.append("n_rows")
82+
83+
if not (current_name.startswith("agg") or current_name.startswith("expr")):
84+
if re.match(r"^(.*)_[0-9]+$", current_name):
85+
result.append(re.findall(r"^(.*)_[0-9]+$", current_name)[0])
86+
return result
7987

8088

8189
def run_column_bubbling(
@@ -160,6 +168,17 @@ def run_column_bubbling(
160168
new_ref = remapping[new_ref]
161169
name = new_expr.name
162170
used_names.add(name)
171+
# Try the same thing with generated alternative names
172+
else:
173+
for alt_name in generate_cleaner_names(new_expr, name):
174+
if alt_name not in used_names:
175+
remapping[new_ref] = ColumnReference(
176+
alt_name, new_expr.data_type
177+
)
178+
new_ref = remapping[new_ref]
179+
name = alt_name
180+
used_names.add(name)
181+
break
163182
aliases[new_expr] = new_ref
164183
output_columns[name] = new_expr
165184
# For limit, also transform the orderings if they exist.
@@ -218,14 +237,14 @@ def run_column_bubbling(
218237
# Special case for aggregations: if the existing name is
219238
# bad, try to replace it with a better name based on the
220239
# function name and input column, if applicable.
221-
if name.startswith("agg") or name.startswith("expr"):
222-
alt_name: str | None = generate_agg_name(new_expr)
223-
if alt_name is not None and alt_name not in used_names:
240+
for alt_name in generate_cleaner_names(new_expr, name):
241+
if alt_name not in used_names:
224242
used_names.add(alt_name)
225243
alt_ref = ColumnReference(alt_name, call_expr.data_type)
226244
remapping[new_ref] = alt_ref
227245
new_ref = alt_ref
228246
name = alt_name
247+
break
229248
aliases[new_expr] = new_ref
230249
new_aggs[name] = new_expr
231250
return Aggregate(new_input, new_keys, new_aggs), remapping

pydough/conversion/projection_pullup.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ def widen_columns(
5555
# to the calling site.
5656
substitutions: dict[RelationalExpression, RelationalExpression] = {}
5757

58-
# Mapping of every expression in the node's columns to a reference to the
59-
# column of the node that points to it. This is used to keep track of which
60-
# expressions are already present in the node's columns versus the ones that
61-
# should be added to un-prune the node.
58+
# Mapping of every expression in the input nodes columns to a reference to
59+
# the column of the node that points to it. This is used to keep track of
60+
# which expressions are already present in the node's columns versus the
61+
# ones that should be added to un-prune the node.
6262
existing_vals: dict[RelationalExpression, RelationalExpression] = {
6363
expr: ColumnReference(name, expr.data_type)
6464
for name, expr in node.columns.items()
@@ -71,28 +71,25 @@ def widen_columns(
7171
input_alias: str | None = node.default_input_aliases[input_idx]
7272
input_node: RelationalNode = node.inputs[input_idx]
7373
for name, expr in input_node.columns.items():
74-
# If the current node is a Join, add input names to the expression.
75-
if isinstance(node, Join):
76-
expr = add_input_name(expr, input_alias)
77-
ref_expr: ColumnReference = ColumnReference(
74+
ref_expr: RelationalExpression = ColumnReference(
7875
name, expr.data_type, input_name=input_alias
7976
)
77+
8078
# If the expression is not already in the node's columns, then
8179
# inject it so the node can use it later if a pull-up occurs that
8280
# would need to reference this expression.
83-
if expr not in existing_vals:
81+
if ref_expr not in existing_vals:
8482
new_name: str = name
8583
idx: int = 0
8684
while new_name in node.columns:
8785
idx += 1
8886
new_name = f"{name}_{idx}"
8987
new_ref: ColumnReference = ColumnReference(new_name, expr.data_type)
9088
node.columns[new_name] = ref_expr
91-
existing_vals[expr] = ref_expr
92-
if ref_expr != new_ref:
93-
substitutions[ref_expr] = new_ref
94-
elif ref_expr != existing_vals[expr]:
95-
substitutions[ref_expr] = existing_vals[expr]
89+
existing_vals[ref_expr] = new_ref
90+
substitutions[ref_expr] = new_ref
91+
else:
92+
substitutions[ref_expr] = existing_vals[ref_expr]
9693

9794
# Return the substitution mapping
9895
return substitutions

pydough/conversion/relational_converter.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,44 +1443,59 @@ def optimize_relational_tree(
14431443
The optimized relational root.
14441444
"""
14451445

1446-
# Step 0: prune unused columns. This is done early to remove as many dead
1446+
# Start by pruning unused columns. This is done early to remove as many dead
14471447
# names as possible so that steps that require generating column names can
14481448
# use nicer names instead of generating nastier ones to avoid collisions.
14491449
# It also speeds up all subsequent steps by reducing the total number of
14501450
# objects inside the plan.
1451-
root = ColumnPruner().prune_unused_columns(root)
1451+
pruner: ColumnPruner = ColumnPruner()
1452+
root = pruner.prune_unused_columns(root)
1453+
1454+
# Bubble up names from the leaf nodes to further encourage simpler naming
1455+
# without aliases, and also to delete duplicate columns where possible.
1456+
# This is done early to maximize the chances that a nicer name will be used
1457+
# for aggregations before projection pullup eliminates many of those names
1458+
# by pulling the aggregated expression inputs into the aggregate call.
1459+
root = bubble_column_names(root)
1460+
1461+
# Run projection pullup to move projections as far up the tree as possible.
1462+
# This is done as soon as possible to make joins redundant if they only
1463+
# exist to compute a scalar projection and then link it with the data.
1464+
# print()
1465+
# print(root.to_tree_string())
1466+
root = confirm_root(pullup_projections(root))
1467+
# print()
1468+
# print(root.to_tree_string())
14521469

1453-
# Step 1: push filters down as far as possible
1470+
# Push filters down as far as possible
14541471
root = confirm_root(push_filters(root, configs))
14551472

1456-
# Step 2: merge adjacent projections, unless it would result in excessive
1457-
# duplicate subexpression computations.
1473+
# Merge adjacent projections, unless it would result in excessive duplicate
1474+
# subexpression computations.
14581475
root = confirm_root(merge_projects(root))
14591476

1460-
# Step 3: split aggregations on top of joins so part of the aggregate
1461-
# happens underneath the join.
1477+
# Split aggregations on top of joins so part of the aggregate happens
1478+
# underneath the join.
14621479
root = confirm_root(split_partial_aggregates(root, configs))
14631480

1464-
# Step 4: delete aggregations that are inferred to be redundant due to
1465-
# operating on already unique data.
1481+
# Delete aggregations that are inferred to be redundant due to operating on
1482+
# already unique data.
14661483
root = remove_redundant_aggs(root)
14671484

1468-
# Step 5: re-run projection merging since the removal of redundant
1469-
# aggregations may have created redundant projections that can be deleted.
1485+
# Re-run projection merging since the removal of redundant aggregations may
1486+
# have created redundant projections that can be deleted.
14701487
root = confirm_root(merge_projects(root))
14711488

1472-
# Step 6: re-run column pruning after the various steps, which may have
1473-
# rendered more columns unused. This is done befre the next step to remove
1474-
# as many column names as possible so the column bubbling step can try to
1475-
# use nicer names without worrying about collisions.
1476-
root = ColumnPruner().prune_unused_columns(root)
1489+
# Re-run column pruning after the various steps, which may have rendered
1490+
# more columns unused. This is done before the next step to remove as many
1491+
# column names as possible so the column bubbling step can try to use nicer
1492+
# names without worrying about collisions.
1493+
root = pruner.prune_unused_columns(root)
14771494

1478-
# Step 7: bubble up names from the leaf nodes to further encourage simpler
1479-
# naming without aliases, and also to delete duplicate columns where
1480-
# possible.
1495+
# Re-run column bubbling now that the columns have been pruned again.
14811496
root = bubble_column_names(root)
14821497

1483-
# Step 8: the following pipeline twice:
1498+
# Run the following pipeline twice:
14841499
# A: projection pullup
14851500
# B: expression simplification
14861501
# C: filter pushdown
@@ -1494,21 +1509,20 @@ def optimize_relational_tree(
14941509
root = confirm_root(pullup_projections(root))
14951510
simplify_expressions(root, configs, additional_shuttles)
14961511
root = confirm_root(push_filters(root, configs))
1497-
root = ColumnPruner().prune_unused_columns(root)
1512+
root = pruner.prune_unused_columns(root)
14981513

1499-
# Step 9: re-run projection merging, without pushing into joins. This
1500-
# will allow some redundant projections created by pullup to be removed
1501-
# entirely.
1514+
# Re-run projection merging, without pushing into joins. This will allow
1515+
# some redundant projections created by pullup to be removed entirely.
15021516
root = confirm_root(merge_projects(root, push_into_joins=False))
15031517

1504-
# Step 10: re-run column bubbling to further simplify the final names of
1505-
# columns in the output now that more columns have been pruned, and delete
1506-
# any new duplicate columns that were created during the pullup step.
1518+
# Re-run column bubbling to further simplify the final names of columns in
1519+
# the output now that more columns have been pruned, and delete any new
1520+
# duplicate columns that were created during the pullup step.
15071521
root = bubble_column_names(root)
15081522

1509-
# Step 11: re-run column pruning one last time to remove any columns that
1510-
# are no longer used after the final round of transformations.
1511-
root = ColumnPruner().prune_unused_columns(root)
1523+
# Re-run column pruning one last time to remove any columns that are no
1524+
# longer used after the final round of transformations.
1525+
root = pruner.prune_unused_columns(root)
15121526

15131527
return root
15141528

0 commit comments

Comments
 (0)