
Commit c2405ac

Adjusting how hybrid backrefs are derived to prevent bugs with partitions (#429)
Adjusting how back references are converted from QDAG to Hybrid. These changes eliminate the recursive procedure that relied on the `back_levels` of the QDAG expression, and replace it with an iterative solution that only uses the `ancestral_mapping` dictionary, without converting hybrid nodes back into faux QDAG nodes just so they can be recursively re-converted. This required adjusting several qdag->relational tests whose QDAG structures were malformed (they contained back references to names that had not been pinned with a CALCULATE). Also, after a HybridPartition is added to the pipeline of a HybridTree, the conversion now appends a HybridNoop, which guarantees a buffer in the `max_steps` field between the child being partitioned by the `HybridPartition` and any other children.
1 parent 414d538 commit c2405ac

20 files changed: +1403 -142 lines
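For context, a BACK reference is only well-formed when the referenced name has been pinned by a CALCULATE on an ancestor context, which is what places it in the `ancestral_mapping`. A minimal sketch of the valid pattern, reusing the example from the translator's docstring (the elided arguments are left as-is):

TPCH.CALCULATE(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...)

Here `x` is pinned on the ancestor by the CALCULATE, so `BACK(1).x` inside the partition's data argument resolves through the ancestral mapping; the adjusted tests previously used back references to names with no such CALCULATE, which is the malformed shape that motivated fixing them.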

pydough/conversion/hybrid_decorrelater.py

Lines changed: 4 additions & 1 deletion
@@ -67,7 +67,10 @@ def make_decorrelate_parent(
         to the PARTITION edge case. The third entry is how many operators
         in the pipeline were copied over from the root.
         """
-        if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0:
+        if (
+            isinstance(hybrid.pipeline[0], HybridPartition)
+            and hybrid.children[child_idx].max_steps == 1
+        ):
             # Special case: if the correlated child is the data argument of a
             # partition operation, then the parent to snapshot is actually the
             # parent of the level containing the partition operation. In this

pydough/conversion/hybrid_translator.py

Lines changed: 84 additions & 116 deletions
@@ -66,6 +66,7 @@
     HybridCollectionAccess,
     HybridFilter,
     HybridLimit,
+    HybridNoop,
     HybridOperation,
     HybridPartition,
     HybridPartitionChild,
@@ -421,6 +422,7 @@ def populate_children(
             if (
                 name in hybrid.ancestral_mapping
                 or name in hybrid.pipeline[-1].terms
+                or subtree.ancestral_mapping[name] == 0
             ):
                 continue
             hybrid_back_expr = self.make_hybrid_expr(
@@ -867,97 +869,6 @@ def rewrite_quantile_call(
 
         return max_call
 
-    def make_hybrid_correl_expr(
-        self,
-        back_expr: BackReferenceExpression,
-        collection: PyDoughCollectionQDAG,
-        steps_taken_so_far: int,
-        down_shift: int,
-    ) -> HybridCorrelExpr:
-        """
-        Converts a BACK reference into a correlated reference when the number
-        of BACK levels exceeds the height of the current subtree.
-
-        Args:
-            `back_expr`: the original BACK reference to be converted.
-            `collection`: the collection at the top of the current subtree,
-            before we have run out of BACK levels to step up out of.
-            `steps_taken_so_far`: the number of steps already taken to step
-            up from the BACK node. This is needed so we know how many steps
-            still need to be taken upward once we have stepped out of the child
-            subtree back into the parent subtree.
-            `down_shift`: a factor that should be subtracted from the final
-            back shift when creating a term in the form CORREL(BACK(n).x), to
-            account for edge cases involving PARTITION nodes. Starts as 0, and
-            is incremented as-needed.
-        """
-        if len(self.stack) == 0:
-            raise ValueError("Back reference steps too far back")
-        # Identify the parent subtree that the BACK reference is stepping back
-        # into, out of the child.
-        parent_tree: HybridTree = self.stack.pop()
-        remaining_steps_back: int = back_expr.back_levels - steps_taken_so_far - 1
-        parent_result: HybridExpr
-        new_expr: PyDoughExpressionQDAG
-        # Special case: stepping out of the data argument of PARTITION back
-        # into its ancestor. For example:
-        # TPCH.CALCULATE(x=...).PARTITION(data.WHERE(y > BACK(1).x), ...)
-        partition_edge_case: bool = len(parent_tree.pipeline) == 1 and isinstance(
-            parent_tree.pipeline[0], HybridPartition
-        )
-        if partition_edge_case:
-            next_hybrid: HybridTree
-            if parent_tree.parent is not None:
-                # If the parent tree has a parent, then we can step back
-                # into the parent tree's parent, which is the context for
-                # the partition.
-                next_hybrid = parent_tree.parent
-            else:
-                assert len(self.stack) > 0, "Back reference steps too far back"
-                next_hybrid = self.stack[-1]
-            # Treat the partition's parent as the context for the back
-            # to step into, as opposed to the partition itself (so the back
-            # levels are consistent)
-            self.stack.append(next_hybrid)
-            parent_result = self.make_hybrid_correl_expr(
-                back_expr, collection, steps_taken_so_far, down_shift + 1
-            ).expr
-            self.stack.pop()
-        elif remaining_steps_back == 0:
-            # If there are no more steps back to be made, then the correlated
-            # reference is to a reference from the current context.
-            if back_expr.term_name in parent_tree.ancestral_mapping:
-                new_expr = BackReferenceExpression(
-                    collection,
-                    back_expr.term_name,
-                    parent_tree.ancestral_mapping[back_expr.term_name] - down_shift,
-                )
-                parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False)
-            elif back_expr.term_name in parent_tree.pipeline[-1].terms:
-                parent_name: str = parent_tree.pipeline[-1].renamings.get(
-                    back_expr.term_name, back_expr.term_name
-                )
-                parent_result = HybridRefExpr(parent_name, back_expr.pydough_type)
-            else:
-                raise ValueError(
-                    f"Back reference to {back_expr.term_name} not found in parent"
-                )
-        else:
-            # Otherwise, a back reference needs to be made from the current
-            # collection a number of steps back based on how many steps still
-            # need to be taken, and it must be recursively converted to a
-            # hybrid expression that gets wrapped in a correlated reference.
-            new_expr = BackReferenceExpression(
-                collection, back_expr.term_name, remaining_steps_back
-            )
-            parent_result = self.make_hybrid_expr(parent_tree, new_expr, {}, False)
-        # Restore parent_tree back onto the stack, since evaluating `back_expr`
-        # does not change the program's current placement in the subtrees.
-        self.stack.append(parent_tree)
-        # Create the correlated reference to the expression with regards to
-        # the parent tree, which could also be a correlated expression.
-        return HybridCorrelExpr(parent_result)
-
     def add_unique_terms(
         self,
         hybrid: HybridTree,
@@ -1046,6 +957,74 @@ def add_unique_terms(
             child_idx,
         )
 
+    def translate_back_reference(
+        self, hybrid: HybridTree, expr: BackReferenceExpression
+    ) -> HybridExpr:
+        """
+        Perform the logic used to translate a BACK reference in QDAG into a
+        back reference in hybrid, or a correlated reference if the back
+        reference steps back further than the height of the current hybrid
+        tree.
+
+        Args:
+            `hybrid`: the hybrid tree that should be used to derive the
+            translation of `expr`, as it is the context in which the `expr`
+            will live.
+            `expr`: the BACK reference to be converted.
+
+        Returns:
+            The HybridExpr node corresponding to `expr`.
+        """
+        back_levels: int = 0
+        correl_levels: int = 0
+        new_stack: list[HybridTree] = []
+        ancestor_tree: HybridTree = hybrid
+        expr_name: str = expr.term_name
+        # Start with the current context and hunt for an ancestor with
+        # that name pinned by a CALCULATE (so it is in the ancestral
+        # mapping), and make sure it is within the height bounds of the
+        # current tree. If not, then pop the previous tree from the
+        # stack and look there, repeating until one is found or the
+        # stack is exhausted. Keep track of how many times we step
+        # outward, since this is how many CORREL() layers we need to
+        # wrap the final expression in.
+        while True:
+            if (
+                expr.term_name in ancestor_tree.ancestral_mapping
+                and ancestor_tree.ancestral_mapping[expr.term_name]
+                < ancestor_tree.get_tree_height()
+            ):
+                back_levels = ancestor_tree.ancestral_mapping[expr.term_name]
+                for _ in range(back_levels):
+                    assert ancestor_tree.parent is not None
+                    ancestor_tree = ancestor_tree.parent
+                expr_name = ancestor_tree.pipeline[-1].renamings.get(
+                    expr_name, expr_name
+                )
+                break
+            elif len(self.stack) > 0:
+                ancestor_tree = self.stack.pop()
+                new_stack.append(ancestor_tree)
+                correl_levels += 1
+            else:
+                raise ValueError("Cannot find ancestor with name " + str(expr))
+        for tree in reversed(new_stack):
+            self.stack.append(tree)
+
+        # The final expression is a regular or back reference depending
+        # on how many back levels it is from the identified ancestor.
+        result: HybridExpr
+        if back_levels == 0:
+            result = HybridRefExpr(expr_name, expr.pydough_type)
+        else:
+            result = HybridBackRefExpr(expr_name, back_levels, expr.pydough_type)
+
+        # Then, wrap it in the necessary number of CORREL() layers.
+        for _ in range(correl_levels):
+            result = HybridCorrelExpr(result)
+
+        return result
+
     def make_hybrid_expr(
         self,
         hybrid: HybridTree,
@@ -1074,7 +1053,6 @@ def make_hybrid_expr(
         child_connection: HybridConnection
         args: list[HybridExpr] = []
         hybrid_arg: HybridExpr
-        ancestor_tree: HybridTree
         collection: PyDoughCollectionQDAG
         match expr:
             case PartitionKey():
@@ -1105,30 +1083,8 @@
                 else:
                     return HybridRefExpr(expr.term_name, expr.pydough_type)
             case BackReferenceExpression():
-                # A reference to an expression from an ancestor becomes a
-                # reference to one of the terms of a parent level of the hybrid
-                # tree. If the BACK goes far enough that it must step outside
-                # a child subtree into the parent, a correlated reference is
-                # created.
-                ancestor_tree = hybrid
-                true_steps_back: int = 0
-                # Keep stepping backward until `expr.back_levels` non-hidden
-                # steps have been taken.
-                collection = expr.collection
-                while true_steps_back < expr.back_levels:
-                    assert collection.ancestor_context is not None
-                    collection = collection.ancestor_context
-                    if ancestor_tree.parent is None:
-                        return self.make_hybrid_correl_expr(
-                            expr, collection, true_steps_back, 0
-                        )
-                    ancestor_tree = ancestor_tree.parent
-                    if not ancestor_tree.is_hidden_level:
-                        true_steps_back += 1
-                expr_name = ancestor_tree.pipeline[-1].renamings.get(
-                    expr.term_name, expr.term_name
-                )
-                return HybridBackRefExpr(expr_name, expr.back_levels, expr.pydough_type)
+                return self.translate_back_reference(hybrid, expr)
+
             case Reference():
                 if hybrid.ancestral_mapping.get(expr.term_name, 0) > 0:
                     collection = expr.collection
@@ -1497,6 +1453,8 @@ def make_hybrid_tree(
                         hybrid.pipeline[-1].orderings,
                     )
                 )
+                for name in new_expressions:
+                    hybrid.ancestral_mapping[name] = 0
                 return hybrid
             case Singular():
                 # a Singular node is just used to annotate the preceding context
@@ -1551,6 +1509,11 @@
                ].subtree
                partition_child.agg_keys = key_exprs
                partition_child.join_keys = [(k, k) for k in key_exprs]
+               # Add a dummy no-op after the partition to ensure a
+               # buffer in the max_steps for other children.
+               successor_hybrid.add_operation(
+                   HybridNoop(successor_hybrid.pipeline[-1])
+               )
                return successor_hybrid
            case OrderBy() | TopK():
                hybrid = self.make_hybrid_tree(
@@ -1630,6 +1593,11 @@
                successor_hybrid.children[
                    partition_child_idx
                ].subtree.agg_keys = key_exprs
+               # Add a dummy no-op after the partition to ensure a
+               # buffer in the max_steps for other children.
+               successor_hybrid.add_operation(
+                   HybridNoop(successor_hybrid.pipeline[-1])
+               )
            case GlobalContext():
                # This is a special case where the child access
                # is a global context, which means that the child is
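To make the new control flow easier to follow, here is a stripped-down, self-contained sketch of the same search; the `Node` class and `resolve` helper are hypothetical stand-ins for `HybridTree` and the translator state, not the real PyDough classes:

# Hypothetical, simplified illustration of the iterative lookup performed by
# translate_back_reference; Node stands in for HybridTree and `stack` for the
# translator's stack of enclosing subtrees.
from dataclasses import dataclass


@dataclass
class Node:
    ancestral_mapping: dict[str, int]
    parent: "Node | None" = None

    def height(self) -> int:
        # Mirrors HybridTree.get_tree_height: count levels up to the root.
        return 1 if self.parent is None else 1 + self.parent.height()


def resolve(name: str, current: Node, stack: list[Node]) -> tuple[int, int]:
    """Return (back_levels, correl_levels) for `name`: how many levels up the
    defining ancestor sits, and how many enclosing subtrees had to be popped
    (each pop corresponds to one CORREL() wrapper on the final expression)."""
    correl_levels = 0
    tree = current
    while True:
        levels = tree.ancestral_mapping.get(name)
        if levels is not None and levels < tree.height():
            return levels, correl_levels
        if stack:
            tree = stack.pop()   # step out into the enclosing subtree
            correl_levels += 1
        else:
            raise ValueError(f"cannot find an ancestor defining {name!r}")


# Example: `x` is pinned one level above the current context in the same
# subtree, so it resolves to BACK(1) with no correlation needed.
root = Node({"x": 0})
child = Node({"x": 1}, parent=root)
print(resolve("x", child, []))  # -> (1, 0)

The key point mirrored from `translate_back_reference` is that each failed lookup pops one enclosing subtree, and each pop costs exactly one `HybridCorrelExpr` wrapper around the final reference.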

pydough/conversion/hybrid_tree.py

Lines changed: 8 additions & 0 deletions
@@ -244,6 +244,14 @@ def general_join_condition(self, condition: HybridExpr | None) -> None:
         """
         self._general_join_condition = condition
 
+    def get_tree_height(self) -> int:
+        """
+        Returns the number of levels in the hybrid tree, starting from the
+        current level and counting upward.
+        """
+        parent_height: int = 0 if self.parent is None else self.parent.get_tree_height()
+        return parent_height + 1
+
     def add_operation(self, operation: HybridOperation) -> None:
         """
         Appends a new hybrid operation to the end of the hybrid tree's pipeline.

pydough/conversion/relational_converter.py

Lines changed: 4 additions & 1 deletion
@@ -750,7 +750,10 @@ def handle_children(
         # If handling the data for a partition, pull every aggregation key
         # into the current context since it is now accessible as a normal
         # ref instead of a child ref.
-        if isinstance(hybrid.pipeline[pipeline_idx], HybridPartition):
+        if (
+            isinstance(hybrid.pipeline[pipeline_idx], HybridPartition)
+            and child_idx == 0
+        ):
             partition = hybrid.pipeline[pipeline_idx]
             assert isinstance(partition, HybridPartition)
             for key_name in partition.key_names:
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+ROOT(columns=[('quarter', quarter), ('n_incidents', DEFAULT_TO(ndistinct_in_device_id, 0:numeric)), ('n_sold', DEFAULT_TO(n_rows, 0:numeric)), ('quarter_cum', ROUND(RELSUM(args=[DEFAULT_TO(ndistinct_in_device_id, 0:numeric)], partition=[], order=[(quarter):asc_last], cumulative=True) / RELSUM(args=[DEFAULT_TO(n_rows, 0:numeric)], partition=[], order=[(quarter):asc_last], cumulative=True), 2:numeric))], orderings=[(quarter):asc_first])
+JOIN(condition=t0.quarter == t1.quarter, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_rows': t1.n_rows, 'ndistinct_in_device_id': t0.ndistinct_in_device_id, 'quarter': t0.quarter})
+JOIN(condition=t0.quarter == t1.quarter, type=LEFT, cardinality=SINGULAR_FILTER, columns={'ndistinct_in_device_id': t1.ndistinct_in_device_id, 'quarter': t0.quarter})
+AGGREGATE(keys={'quarter': DATETIME(ca_dt, 'start of quarter':string)}, aggregations={})
+JOIN(condition=t0.ca_dt < DATETIME(t1.pr_release, '+2 years':string, 'start of quarter':string) & t0.ca_dt >= t1.pr_release, type=INNER, cardinality=SINGULAR_FILTER, columns={'ca_dt': t0.ca_dt})
+SCAN(table=main.CALENDAR, columns={'ca_dt': ca_dt})
+FILTER(condition=pr_name == 'RubyCopper-Star':string, columns={'pr_release': pr_release})
+SCAN(table=main.PRODUCTS, columns={'pr_name': pr_name, 'pr_release': pr_release})
+AGGREGATE(keys={'quarter': DATETIME(ca_dt, 'start of quarter':string)}, aggregations={'ndistinct_in_device_id': NDISTINCT(in_device_id)})
+JOIN(condition=t0.in_device_id == t1.de_id & t1.de_product_id == t0.pr_id, type=INNER, cardinality=SINGULAR_FILTER, columns={'ca_dt': t0.ca_dt, 'in_device_id': t0.in_device_id})
+JOIN(condition=t0.in_repair_country_id == t1.co_id, type=INNER, cardinality=SINGULAR_FILTER, columns={'ca_dt': t0.ca_dt, 'in_device_id': t0.in_device_id, 'pr_id': t1.pr_id})
+JOIN(condition=t0.ca_dt == DATETIME(t1.in_error_report_ts, 'start of day':string), type=INNER, cardinality=PLURAL_FILTER, columns={'ca_dt': t0.ca_dt, 'in_device_id': t1.in_device_id, 'in_repair_country_id': t1.in_repair_country_id})
+JOIN(condition=t0.ca_dt < DATETIME(t1.pr_release, '+2 years':string, 'start of quarter':string) & t0.ca_dt >= t1.pr_release, type=INNER, cardinality=SINGULAR_FILTER, columns={'ca_dt': t0.ca_dt})
+SCAN(table=main.CALENDAR, columns={'ca_dt': ca_dt})
+FILTER(condition=pr_name == 'RubyCopper-Star':string, columns={'pr_release': pr_release})
+SCAN(table=main.PRODUCTS, columns={'pr_name': pr_name, 'pr_release': pr_release})
+SCAN(table=main.INCIDENTS, columns={'in_device_id': in_device_id, 'in_error_report_ts': in_error_report_ts, 'in_repair_country_id': in_repair_country_id})
+JOIN(condition=True:bool, type=INNER, cardinality=SINGULAR_FILTER, columns={'co_id': t1.co_id, 'pr_id': t0.pr_id})
+FILTER(condition=pr_name == 'RubyCopper-Star':string, columns={'pr_id': pr_id})
+SCAN(table=main.PRODUCTS, columns={'pr_id': pr_id, 'pr_name': pr_name})
+FILTER(condition=co_name == 'CN':string, columns={'co_id': co_id})
+SCAN(table=main.COUNTRIES, columns={'co_id': co_id, 'co_name': co_name})
+SCAN(table=main.DEVICES, columns={'de_id': de_id, 'de_product_id': de_product_id})
+AGGREGATE(keys={'quarter': DATETIME(ca_dt, 'start of quarter':string)}, aggregations={'n_rows': COUNT()})
+JOIN(condition=t0.de_product_id == t1.pr_id, type=INNER, cardinality=SINGULAR_FILTER, columns={'ca_dt': t0.ca_dt})
+JOIN(condition=t0.ca_dt == DATETIME(t1.de_purchase_ts, 'start of day':string), type=INNER, cardinality=PLURAL_FILTER, columns={'ca_dt': t0.ca_dt, 'de_product_id': t1.de_product_id})
+JOIN(condition=t0.ca_dt < DATETIME(t1.pr_release, '+2 years':string, 'start of quarter':string) & t0.ca_dt >= t1.pr_release, type=INNER, cardinality=SINGULAR_FILTER, columns={'ca_dt': t0.ca_dt})
+SCAN(table=main.CALENDAR, columns={'ca_dt': ca_dt})
+FILTER(condition=pr_name == 'RubyCopper-Star':string, columns={'pr_release': pr_release})
+SCAN(table=main.PRODUCTS, columns={'pr_name': pr_name, 'pr_release': pr_release})
+SCAN(table=main.DEVICES, columns={'de_product_id': de_product_id, 'de_purchase_ts': de_purchase_ts})
+FILTER(condition=pr_name == 'RubyCopper-Star':string, columns={'pr_id': pr_id})
+SCAN(table=main.PRODUCTS, columns={'pr_id': pr_id, 'pr_name': pr_name})

tests/test_plan_refsols/lines_shipping_vs_customer_region.txt

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-ROOT(columns=[('order_year', YEAR(o_orderdate)), ('customer_region_name', r_name), ('customer_nation_name', name_3), ('supplier_region_name', name_16), ('nation_name', n_name)], orderings=[])
-JOIN(condition=t0.l_partkey == t1.ps_partkey & t0.l_suppkey == t1.ps_suppkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'n_name': t1.n_name, 'name_16': t1.r_name, 'name_3': t0.n_name, 'o_orderdate': t0.o_orderdate, 'r_name': t0.r_name})
+ROOT(columns=[('order_year', YEAR(o_orderdate)), ('customer_region_name', r_name), ('customer_nation_name', nname), ('supplier_region_name', name_16), ('nation_name', n_name)], orderings=[])
+JOIN(condition=t0.l_partkey == t1.ps_partkey & t0.l_suppkey == t1.ps_suppkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'n_name': t1.n_name, 'name_16': t1.r_name, 'nname': t0.n_name, 'o_orderdate': t0.o_orderdate, 'r_name': t0.r_name})
 JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=INNER, cardinality=PLURAL_ACCESS, columns={'l_partkey': t1.l_partkey, 'l_suppkey': t1.l_suppkey, 'n_name': t0.n_name, 'o_orderdate': t0.o_orderdate, 'r_name': t0.r_name})
 JOIN(condition=t0.c_custkey == t1.o_custkey, type=INNER, cardinality=PLURAL_FILTER, columns={'n_name': t0.n_name, 'o_orderdate': t1.o_orderdate, 'o_orderkey': t1.o_orderkey, 'r_name': t0.r_name})
 JOIN(condition=t0.n_nationkey == t1.c_nationkey, type=INNER, cardinality=PLURAL_ACCESS, columns={'c_custkey': t1.c_custkey, 'n_name': t0.n_name, 'r_name': t0.r_name})
