Skip to content

Commit cb6e86a

Browse files
authored
Fixing CROSS bug with partition edge cases (#423)
Fixing recently exposed bugs in how `CROSS` handles certain edge cases involving partitions. Resolutions for these bugs:
- Adjusted the `all_terms` property of a `GlobalContext` to include ancestrally inherited terms.
- Adjusted qualification to remove the `is_cross` term and instead qualify the RHS of the cross using an _augmented_ version of the qualified LHS answer as the context, with a new `GlobalContext` appended to it.
- Adjusted the relational conversion of the children of a partition clause to avoid deleting join keys that are NOT original grouping keys.

Also added some new tests that trigger patterns related to these bugs or their fixes.
1 parent 5be84e0 commit cb6e86a

17 files changed

+532
-100
lines changed

pydough/conversion/relational_converter.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -669,11 +669,20 @@ def handle_children(
669669
join_keys: list[tuple[HybridExpr, HybridExpr]] | None = (
670670
child.subtree.join_keys
671671
)
672+
673+
# If handling the child of a partition, remove any join keys
674+
# where the LHS is a simple reference, since those keys are
675+
# guaranteed to be present due to the partitioning.
672676
if (
673677
isinstance(hybrid.pipeline[pipeline_idx], HybridPartition)
674678
and child_idx == 0
675679
):
676-
join_keys = None
680+
new_join_keys: list[tuple[HybridExpr, HybridExpr]] = []
681+
if join_keys is not None:
682+
for lhs_key, rhs_key in join_keys:
683+
if not isinstance(lhs_key, HybridRefExpr):
684+
new_join_keys.append((lhs_key, rhs_key))
685+
join_keys = new_join_keys if len(new_join_keys) > 0 else None
677686
child_expr: HybridExpr
678687
match child.connection_type:
679688
case (
@@ -1461,11 +1470,7 @@ def optimize_relational_tree(
14611470
# Run projection pullup to move projections as far up the tree as possible.
14621471
# This is done as soon as possible to make joins redundant if they only
14631472
# exist to compute a scalar projection and then link it with the data.
1464-
# print()
1465-
# print(root.to_tree_string())
14661473
root = confirm_root(pullup_projections(root))
1467-
# print()
1468-
# print(root.to_tree_string())
14691474

14701475
# Push filters down as far as possible
14711476
root = confirm_root(push_filters(root, configs))

pydough/qdag/collections/global_context.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG
1515
from pydough.qdag.errors import PyDoughQDAGException
16-
from pydough.qdag.expressions import CollationExpression
16+
from pydough.qdag.expressions import BackReferenceExpression, CollationExpression
1717

1818
from .collection_qdag import PyDoughCollectionQDAG
1919
from .collection_tree_form import CollectionTreeForm
@@ -33,6 +33,7 @@ def __init__(
3333
self._graph = graph
3434
self._collections: dict[str, PyDoughCollectionQDAG] = {}
3535
self._ancestral_mapping: dict[str, int] = {}
36+
self._all_terms: set[str] = set()
3637
# If this collection has an ancestor, inherit its ancestors
3738
# and update depths, to preserve sub-collection context.
3839
# This ensures that downstream operations have the correct hierarchy.
@@ -44,11 +45,13 @@ def __init__(
4445
f"Name {name!r} conflicts with a collection name in the graph {graph.name!r}"
4546
)
4647
else:
48+
self._all_terms.add(name)
4749
self._ancestral_mapping[name] = level + 1
4850
for collection_name in graph.get_collection_names():
4951
meta = graph.get_collection(collection_name)
5052
assert isinstance(meta, CollectionMetadata)
5153
self._collections[collection_name] = TableCollection(meta, self)
54+
self._all_terms.add(collection_name)
5255

5356
@property
5457
def graph(self) -> GraphMetadata:
@@ -98,7 +101,7 @@ def inherited_downstreamed_terms(self) -> set[str]:
98101

99102
@property
100103
def all_terms(self) -> set[str]:
101-
return set(self.collections)
104+
return self._all_terms
102105

103106
@property
104107
def ordering(self) -> list[CollationExpression] | None:
@@ -115,11 +118,22 @@ def get_expression_position(self, expr_name: str) -> int:
115118
raise PyDoughQDAGException(f"Cannot call get_expression_position on {self!r}")
116119

117120
def get_term(self, term_name: str) -> PyDoughQDAG:
118-
if term_name not in self.collections:
121+
if term_name in self.ancestral_mapping:
122+
# Verify that the ancestor name is not also a name in the current
123+
# context.
124+
if term_name in self.collections:
125+
raise PyDoughQDAGException(
126+
f"Cannot have term name {term_name!r} used in an ancestor of collection {self!r}"
127+
)
128+
# Create a back-reference to the ancestor term.
129+
return BackReferenceExpression(
130+
self, term_name, self.ancestral_mapping[term_name]
131+
)
132+
elif term_name in self.collections:
133+
return self.collections[term_name]
134+
else:
119135
raise PyDoughQDAGException(self.name_mismatch_error(term_name))
120136

121-
return self.collections[term_name]
122-
123137
@property
124138
def standalone_string(self) -> str:
125139
return self.graph.name
@@ -157,4 +171,8 @@ def to_tree_form(self, is_last: bool) -> CollectionTreeForm:
157171
return self.to_tree_form_isolated(is_last)
158172

159173
def equals(self, other: object) -> bool:
160-
return isinstance(other, GlobalContext) and self.graph == other.graph
174+
return (
175+
isinstance(other, GlobalContext)
176+
and self.graph == other.graph
177+
and self.ancestor_context == other.ancestor_context
178+
)

0 commit comments

Comments
 (0)