
Commit 00d6c09

Fixing bug with complex combinations of PARTITION and CROSS (#406)
The bug originated in a pattern such as `A.PARTITION(...).CROSS(B.PARTITION(...)).PARTITION(...)`, where the final `PARTITION` failed with the error: `PyDoughUnqualifiedException: Unsupported collection node for `split_partition_ancestry`: UnqualifiedCross`. This change fixes that bug so the aforementioned case works, and adds tests for it and for several other uses of `x.CROSS(y)` where `x` and `y` contain `PARTITION`.
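
For reference, the exact failing shape is exercised by the new `part_cross_part_c` test added below; its PyDough source (reformatted here for readability, with the previously failing step marked) is:

states_collection = customers.PARTITION(name='states', by=state).CALCULATE(original_state=state)
months_collection = (
    transactions
    .WHERE(YEAR(date_time) == 2023)
    .CALCULATE(month=DATETIME(date_time, 'start of month'))
    .PARTITION(name='months', by=month)
    .CALCULATE(month)
)
result = (
    states_collection
    .CROSS(months_collection)
    .CALCULATE(n=COUNT(transactions.WHERE(customer.state == original_state)))
    .PARTITION(name='s', by=original_state)  # this final PARTITION used to raise the exception
    .CALCULATE(state=original_state, max_n=MAX(months.n))
)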
1 parent 9a0126f commit 00d6c09

16 files changed (+472, -27 lines)

pydough/conversion/filter_pushdown.py

Lines changed: 9 additions & 0 deletions
@@ -187,6 +187,15 @@ def visit_join(self, join: Join) -> RelationalNode:
         cardinality: JoinCardinality = join.cardinality
         new_inputs: list[RelationalNode] = []
 
+        # If the join type is LEFT or SEMI but the condition is TRUE, convert it
+        # to an INNER join.
+        if (
+            isinstance(join.condition, LiteralExpression)
+            and bool(join.condition.value)
+            and join.join_type in (JoinType.LEFT, JoinType.SEMI)
+        ):
+            join_type = JoinType.INNER
+
         # If the join is a LEFT join, deduce whether the filters above it can
         # turn it into an INNER join.
         if join.join_type == JoinType.LEFT and len(self.filters) > 0:
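
The new check reads naturally as a standalone normalization step. Below is a minimal self-contained sketch of the same rule, using toy stand-ins rather than PyDough's real relational classes (the actual `Join`, `JoinType`, and `LiteralExpression` are richer than these dataclasses):

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any


class JoinType(Enum):
    INNER = auto()
    LEFT = auto()
    SEMI = auto()


@dataclass
class LiteralExpression:
    value: Any


@dataclass
class Join:
    condition: Any
    join_type: JoinType


def normalize_join_type(join: Join) -> JoinType:
    # Mirrors the check added in visit_join: a LEFT or SEMI join whose
    # condition is the literal TRUE is reclassified as INNER before the
    # rest of the filter-pushdown logic runs.
    if (
        isinstance(join.condition, LiteralExpression)
        and bool(join.condition.value)
        and join.join_type in (JoinType.LEFT, JoinType.SEMI)
    ):
        return JoinType.INNER
    return join.join_type


assert normalize_join_type(Join(LiteralExpression(True), JoinType.LEFT)) is JoinType.INNER
assert normalize_join_type(Join(LiteralExpression(True), JoinType.SEMI)) is JoinType.INNER
assert normalize_join_type(Join(LiteralExpression(False), JoinType.LEFT)) is JoinType.LEFT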

pydough/conversion/hybrid_decorrelater.py

Lines changed: 2 additions & 7 deletions
@@ -43,7 +43,7 @@ def __init__(self) -> None:
         self.children_indices: list[int] = []
 
     def make_decorrelate_parent(
-        self, hybrid: HybridTree, child_idx: int, min_steps: int, max_steps: int
+        self, hybrid: HybridTree, child_idx: int, max_steps: int
     ) -> tuple[HybridTree, int, int]:
         """
         Creates a snapshot of the ancestry of the hybrid tree that contains
@@ -55,9 +55,6 @@ def make_decorrelate_parent(
            in the de-correlation of a correlated child.
            `child_idx`: The index of the correlated child of hybrid that the
            snapshot is being created to aid in the de-correlation of.
-            `min_steps`: The index of the last pipeline operator that
-            needs to be included in the snapshot in order for the child to be
-            derivable.
            `max_steps`: The index of the first pipeline operator that cannot
            occur because it depends on the correlated child.
 
@@ -83,7 +80,6 @@ def make_decorrelate_parent(
            result = self.make_decorrelate_parent(
                hybrid.parent,
                -1,
-                len(hybrid.parent.pipeline) - 1,
                len(hybrid.parent.pipeline),
            )
            return result[0], result[1] + 1, 1
@@ -102,7 +98,7 @@ def make_decorrelate_parent(
        original_max_steps: int = max_steps
        new_correlated_children: set[int] = set()
        for idx, child in enumerate(new_hybrid.children):
-            if child.min_steps < min_steps and child.max_steps <= original_max_steps:
+            if child.max_steps < original_max_steps:
                new_children.append(child)
                if idx in hybrid.correlated_children:
                    new_correlated_children.add(len(new_children) - 1)
@@ -510,7 +506,6 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree:
            self.make_decorrelate_parent(
                original_parent,
                child_idx,
-                hybrid.children[child_idx].min_steps,
                hybrid.children[child_idx].max_steps,
            )
        )
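
The practical effect of dropping `min_steps` shows up in the child-retention check: a child is now copied into the snapshot parent purely based on its `max_steps` cutoff. A small illustrative sketch of that rule, where the `Child` dataclass and the numbers are made-up stand-ins for the real hybrid child connections:

from dataclasses import dataclass


@dataclass
class Child:
    max_steps: int  # analogous to the max_steps bound described in the docstring above


def snapshot_children(children: list[Child], original_max_steps: int) -> list[Child]:
    # New rule from the diff: keep a child iff its max_steps falls strictly
    # below the snapshot cutoff; the old min_steps bound is no longer consulted.
    return [child for child in children if child.max_steps < original_max_steps]


print(snapshot_children([Child(2), Child(5), Child(7)], original_max_steps=5))
# [Child(max_steps=2)]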

pydough/unqualified/qualification.py

Lines changed: 3 additions & 0 deletions
@@ -905,6 +905,7 @@ def split_partition_ancestry(
                | UnqualifiedSingular()
                | UnqualifiedPartition()
                | UnqualifiedBest()
+                | UnqualifiedCross()
            ):
                parent: UnqualifiedNode = node._parcel[0]
                new_ancestry, new_child, ancestry_names = self.split_partition_ancestry(
@@ -967,6 +968,8 @@ def split_partition_ancestry(
                build_node[0] = UnqualifiedSingular(build_node[0], *node._parcel[1:])
            case UnqualifiedBest():
                build_node[0] = UnqualifiedBest(build_node[0], *node._parcel[1:])
+            case UnqualifiedCross():
+                build_node[0] = UnqualifiedCross(build_node[0], *node._parcel[1:])
            case _:
                # Any other unqualified node would mean something is malformed.
                raise PyDoughUnqualifiedException(
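
The fix amounts to adding `UnqualifiedCross` to both `match` statements in `split_partition_ancestry`: the class-pattern union that decides which node kinds the ancestry walk may pass through, and the per-class arm that re-instantiates the node on top of the rebuilt ancestry. Here is a generic illustration of that match-and-rebuild idiom with toy node classes (not PyDough's real `UnqualifiedNode` hierarchy, which stores its operands in `_parcel`):

from dataclasses import dataclass


@dataclass
class Where:
    parent: object
    cond: str


@dataclass
class Cross:
    parent: object
    other: str


def rebuild_on(new_parent: object, node: object) -> object:
    # First match: the set of node kinds the walk is allowed to pass through.
    # Omitting a class here is what produced the original
    # "Unsupported collection node" error for CROSS.
    match node:
        case Where() | Cross():
            pass
        case _:
            raise TypeError(f"Unsupported node: {type(node).__name__}")
    # Second match: rebuild the node on top of the new parent, keeping its
    # remaining operands (cf. UnqualifiedCross(build_node[0], *node._parcel[1:])).
    match node:
        case Where():
            return Where(new_parent, node.cond)
        case Cross():
            return Cross(new_parent, node.other)


print(rebuild_on("NEW_PARENT", Cross("OLD_PARENT", "other_collection")))
# Cross(parent='NEW_PARENT', other='other_collection')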

tests/test_pipeline_defog_custom.py

Lines changed: 132 additions & 0 deletions
@@ -1480,6 +1480,138 @@ def get_day_of_week(
             ),
             id="window_sliding_frame_relsum",
         ),
+        pytest.param(
+            PyDoughPandasTest(
+                "exchanges = tickers.PARTITION(name='exchanges', by=exchange).CALCULATE(original_exchange=exchange)\n"
+                "states = customers.PARTITION(name='states', by=state).CALCULATE(state)\n"
+                "result = ("
+                " exchanges"
+                " .CROSS(states)"
+                " .CALCULATE("
+                " state,"
+                " exchange=original_exchange,"
+                " n=COUNT(customers.transactions_made.WHERE(ticker.exchange == original_exchange)),"
+                ")"
+                ".ORDER_BY(state.ASC(), exchange.ASC())"
+                ")",
+                "Broker",
+                lambda: pd.DataFrame(
+                    {
+                        "state": ["CA"] * 4
+                        + ["FL"] * 4
+                        + ["NJ"] * 4
+                        + ["NY"] * 4
+                        + ["TX"] * 4,
+                        "exchange": ["NASDAQ", "NYSE", "NYSE Arca", "Vanguard"] * 5,
+                        "n": [
+                            12,
+                            7,
+                            1,
+                            0,
+                            5,
+                            3,
+                            0,
+                            0,
+                            4,
+                            0,
+                            0,
+                            0,
+                            10,
+                            3,
+                            0,
+                            0,
+                            8,
+                            2,
+                            1,
+                            0,
+                        ],
+                    }
+                ),
+                "part_cross_part_a",
+            ),
+            id="part_cross_part_a",
+        ),
+        pytest.param(
+            PyDoughPandasTest(
+                "states = customers.PARTITION(name='states', by=state).CALCULATE(original_state=state)\n"
+                "months = transactions.WHERE((YEAR(date_time) == 2023)).CALCULATE(month=DATETIME(date_time, 'start of month')).PARTITION(name='months', by=month).CALCULATE(month)\n"
+                "result = ("
+                " states"
+                " .CROSS(months)"
+                " .CALCULATE("
+                " state=original_state,"
+                " month_of_year=month,"
+                " n=RELSUM(COUNT(transactions.WHERE(customer.state == original_state)), per='states', by=month.ASC(), cumulative=True),"
+                ")"
+                ".ORDER_BY(state.ASC(), month_of_year.ASC())"
+                ")",
+                "Broker",
+                lambda: pd.DataFrame(
+                    {
+                        "state": ["CA"] * 4
+                        + ["FL"] * 4
+                        + ["NJ"] * 4
+                        + ["NY"] * 4
+                        + ["TX"] * 4,
+                        "months": [
+                            "2023-01-01",
+                            "2023-02-01",
+                            "2023-03-01",
+                            "2023-04-01",
+                        ]
+                        * 5,
+                        "n": [
+                            0,
+                            0,
+                            2,
+                            12,
+                            0,
+                            0,
+                            0,
+                            5,
+                            3,
+                            3,
+                            3,
+                            3,
+                            0,
+                            2,
+                            2,
+                            9,
+                            0,
+                            0,
+                            0,
+                            8,
+                        ],
+                    }
+                ),
+                "part_cross_part_b",
+            ),
+            id="part_cross_part_b",
+        ),
+        pytest.param(
+            PyDoughPandasTest(
+                "states_collection = customers.PARTITION(name='states', by=state).CALCULATE(original_state=state)\n"
+                "months_collection = transactions.WHERE((YEAR(date_time) == 2023)).CALCULATE(month=DATETIME(date_time, 'start of month')).PARTITION(name='months', by=month).CALCULATE(month)\n"
+                "result = ("
+                " states_collection"
+                " .CROSS(months_collection)"
+                " .CALCULATE(n=COUNT(transactions.WHERE(customer.state == original_state)))"
+                " .PARTITION(name='s', by=original_state)"
+                " .CALCULATE("
+                " state=original_state,"
+                " max_n=MAX(months.n),"
+                "))",
+                "Broker",
+                lambda: pd.DataFrame(
+                    {
+                        "state": ["CA", "FL", "NJ", "NY", "TX"],
+                        "max_n": [10, 5, 3, 7, 8],
+                    }
+                ),
+                "part_cross_part_c",
+            ),
+            id="part_cross_part_c",
+        ),
         pytest.param(
             PyDoughPandasTest(
                 agg_simplification_1,
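
The expected `n` column in `part_cross_part_b` is a running total per state: `RELSUM(COUNT(...), per='states', by=month.ASC(), cumulative=True)` accumulates the monthly transaction counts within each state. A quick pandas cross-check for the CA rows; the monthly counts 0, 0, 2, 10 are back-derived from the expected cumulative values in the test, so treat them as illustrative:

import pandas as pd

ca = pd.DataFrame(
    {
        "state": ["CA"] * 4,
        "month": ["2023-01-01", "2023-02-01", "2023-03-01", "2023-04-01"],
        "monthly_count": [0, 0, 2, 10],
    }
)
# Cumulative sum within each state, ordered by month -- the same shape of
# computation the RELSUM call expresses in PyDough.
ca["n"] = ca.groupby("state")["monthly_count"].cumsum()
print(ca["n"].tolist())  # [0, 0, 2, 12] -- matches the first four expected values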

tests/test_plan_refsols/deep_best_analysis.txt

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
-ROOT(columns=[('r_name', r_name), ('n_name', n_name), ('c_key', c_custkey), ('c_bal', c_acctbal), ('cr_bal', account_balance_21), ('s_key', s_suppkey), ('p_key', ps_partkey), ('p_qty', ps_availqty), ('cg_key', key_54)], orderings=[(n_name):asc_first], limit=10:numeric)
-JOIN(condition=t0.n_nationkey == t1.c_nationkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'account_balance_21': t0.account_balance_21, 'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'key_54': t1.c_custkey, 'n_name': t0.n_name, 'ps_availqty': t0.ps_availqty, 'ps_partkey': t0.ps_partkey, 'r_name': t0.r_name, 's_suppkey': t0.s_suppkey})
-JOIN(condition=t0.r_regionkey == t1.r_regionkey & t0.n_nationkey == t1.n_nationkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'account_balance_21': t0.account_balance_21, 'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'ps_availqty': t1.ps_availqty, 'ps_partkey': t1.ps_partkey, 'r_name': t0.r_name, 's_suppkey': t1.s_suppkey})
-JOIN(condition=t0.r_regionkey == t1.r_regionkey & t0.n_nationkey == t1.n_nationkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'account_balance_21': t1.c_acctbal, 'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey})
+ROOT(columns=[('r_name', r_name), ('n_name', n_name), ('c_key', key_5), ('c_bal', c_acctbal), ('cr_bal', account_balance_13), ('s_key', s_suppkey), ('p_key', ps_partkey), ('p_qty', ps_availqty), ('cg_key', c_custkey)], orderings=[(n_name):asc_first], limit=10:numeric)
+JOIN(condition=t0.n_nationkey == t1.c_nationkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'account_balance_13': t0.account_balance_13, 'c_acctbal': t0.c_acctbal, 'c_custkey': t1.c_custkey, 'key_5': t0.c_custkey, 'n_name': t0.n_name, 'ps_availqty': t0.ps_availqty, 'ps_partkey': t0.ps_partkey, 'r_name': t0.r_name, 's_suppkey': t0.s_suppkey})
+JOIN(condition=t0.r_regionkey == t1.r_regionkey & t0.n_nationkey == t1.n_nationkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'account_balance_13': t0.account_balance_13, 'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'ps_availqty': t1.ps_availqty, 'ps_partkey': t1.ps_partkey, 'r_name': t0.r_name, 's_suppkey': t1.s_suppkey})
+JOIN(condition=t0.r_regionkey == t1.r_regionkey & t0.n_nationkey == t1.n_nationkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'account_balance_13': t1.c_acctbal, 'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey})
 JOIN(condition=t0.n_nationkey == t1.c_nationkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'c_acctbal': t1.c_acctbal, 'c_custkey': t1.c_custkey, 'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey})
 JOIN(condition=t0.r_regionkey == t1.n_regionkey, type=INNER, cardinality=PLURAL_ACCESS, columns={'n_name': t1.n_name, 'n_nationkey': t1.n_nationkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey})
 SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey})
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+ROOT(columns=[('state', sbCustState), ('exchange', sbTickerExchange), ('n', DEFAULT_TO(sum_n_rows, 0:numeric))], orderings=[(sbCustState):asc_first, (sbTickerExchange):asc_first])
+AGGREGATE(keys={'sbCustState': sbCustState, 'sbTickerExchange': sbTickerExchange}, aggregations={'sum_n_rows': SUM(n_rows)})
+JOIN(condition=t0.sbTickerExchange == t1.sbTickerExchange & t0.sbCustId == t1.sbCustId, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_rows': t1.n_rows, 'sbCustState': t0.sbCustState, 'sbTickerExchange': t0.sbTickerExchange})
+JOIN(condition=True:bool, type=INNER, cardinality=PLURAL_ACCESS, columns={'sbCustId': t1.sbCustId, 'sbCustState': t1.sbCustState, 'sbTickerExchange': t0.sbTickerExchange})
+AGGREGATE(keys={'sbTickerExchange': sbTickerExchange}, aggregations={})
+SCAN(table=main.sbTicker, columns={'sbTickerExchange': sbTickerExchange})
+SCAN(table=main.sbCustomer, columns={'sbCustId': sbCustId, 'sbCustState': sbCustState})
+AGGREGATE(keys={'sbCustId': sbCustId, 'sbTickerExchange': sbTickerExchange}, aggregations={'n_rows': COUNT()})
+JOIN(condition=t0.sbTxTickerId == t1.sbTickerId & t1.sbTickerExchange == t0.sbTickerExchange, type=INNER, cardinality=SINGULAR_FILTER, columns={'sbCustId': t0.sbCustId, 'sbTickerExchange': t0.sbTickerExchange})
+JOIN(condition=t0.sbCustId == t1.sbTxCustId, type=INNER, cardinality=PLURAL_FILTER, columns={'sbCustId': t0.sbCustId, 'sbTickerExchange': t0.sbTickerExchange, 'sbTxTickerId': t1.sbTxTickerId})
+JOIN(condition=True:bool, type=INNER, cardinality=PLURAL_ACCESS, columns={'sbCustId': t1.sbCustId, 'sbTickerExchange': t0.sbTickerExchange})
+AGGREGATE(keys={'sbTickerExchange': sbTickerExchange}, aggregations={})
+SCAN(table=main.sbTicker, columns={'sbTickerExchange': sbTickerExchange})
+SCAN(table=main.sbCustomer, columns={'sbCustId': sbCustId})
+SCAN(table=main.sbTransaction, columns={'sbTxCustId': sbTxCustId, 'sbTxTickerId': sbTxTickerId})
+SCAN(table=main.sbTicker, columns={'sbTickerExchange': sbTickerExchange, 'sbTickerId': sbTickerId})
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+ROOT(columns=[('state', sbCustState), ('month_of_year', month), ('n', RELSUM(args=[DEFAULT_TO(n_rows, 0:numeric)], partition=[sbCustState], order=[(month):asc_last], cumulative=True))], orderings=[(sbCustState):asc_first, (month):asc_first])
+JOIN(condition=t0.sbCustState == t1.sbCustState & t0.month == t1.month, type=LEFT, cardinality=SINGULAR_FILTER, columns={'month': t0.month, 'n_rows': t1.n_rows, 'sbCustState': t0.sbCustState})
+JOIN(condition=True:bool, type=INNER, cardinality=SINGULAR_FILTER, columns={'month': t1.month, 'sbCustState': t0.sbCustState})
+AGGREGATE(keys={'sbCustState': sbCustState}, aggregations={})
+SCAN(table=main.sbCustomer, columns={'sbCustState': sbCustState})
+AGGREGATE(keys={'month': DATETIME(sbTxDateTime, 'start of month':string)}, aggregations={})
+FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime})
+AGGREGATE(keys={'month': month, 'sbCustState': sbCustState}, aggregations={'n_rows': COUNT()})
+JOIN(condition=t0.sbTxCustId == t1.sbCustId & t1.sbCustState == t0.sbCustState, type=INNER, cardinality=SINGULAR_FILTER, columns={'month': t0.month, 'sbCustState': t0.sbCustState})
+JOIN(condition=t0.month == DATETIME(t1.sbTxDateTime, 'start of month':string), type=INNER, cardinality=PLURAL_FILTER, columns={'month': t0.month, 'sbCustState': t0.sbCustState, 'sbTxCustId': t1.sbTxCustId})
+JOIN(condition=True:bool, type=INNER, cardinality=SINGULAR_FILTER, columns={'month': t1.month, 'sbCustState': t0.sbCustState})
+AGGREGATE(keys={'sbCustState': sbCustState}, aggregations={})
+SCAN(table=main.sbCustomer, columns={'sbCustState': sbCustState})
+AGGREGATE(keys={'month': DATETIME(sbTxDateTime, 'start of month':string)}, aggregations={})
+FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime})
+FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxCustId': sbTxCustId, 'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbTransaction, columns={'sbTxCustId': sbTxCustId, 'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbCustomer, columns={'sbCustId': sbCustId, 'sbCustState': sbCustState})
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+ROOT(columns=[('state', sbCustState), ('max_n', max_n)], orderings=[])
+AGGREGATE(keys={'sbCustState': sbCustState}, aggregations={'max_n': MAX(DEFAULT_TO(n_rows, 0:numeric))})
+JOIN(condition=t0.sbCustState == t1.sbCustState & t0.month == t1.month, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_rows': t1.n_rows, 'sbCustState': t0.sbCustState})
+JOIN(condition=True:bool, type=INNER, cardinality=SINGULAR_FILTER, columns={'month': t1.month, 'sbCustState': t0.sbCustState})
+AGGREGATE(keys={'sbCustState': sbCustState}, aggregations={})
+SCAN(table=main.sbCustomer, columns={'sbCustState': sbCustState})
+AGGREGATE(keys={'month': DATETIME(sbTxDateTime, 'start of month':string)}, aggregations={})
+FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime})
+AGGREGATE(keys={'month': month, 'sbCustState': sbCustState}, aggregations={'n_rows': COUNT()})
+JOIN(condition=t0.sbTxCustId == t1.sbCustId & t1.sbCustState == t0.sbCustState, type=INNER, cardinality=SINGULAR_FILTER, columns={'month': t0.month, 'sbCustState': t0.sbCustState})
+JOIN(condition=t0.month == DATETIME(t1.sbTxDateTime, 'start of month':string), type=INNER, cardinality=PLURAL_FILTER, columns={'month': t0.month, 'sbCustState': t0.sbCustState, 'sbTxCustId': t1.sbTxCustId})
+JOIN(condition=True:bool, type=INNER, cardinality=SINGULAR_FILTER, columns={'month': t1.month, 'sbCustState': t0.sbCustState})
+AGGREGATE(keys={'sbCustState': sbCustState}, aggregations={})
+SCAN(table=main.sbCustomer, columns={'sbCustState': sbCustState})
+AGGREGATE(keys={'month': DATETIME(sbTxDateTime, 'start of month':string)}, aggregations={})
+FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime})
+FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxCustId': sbTxCustId, 'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbTransaction, columns={'sbTxCustId': sbTxCustId, 'sbTxDateTime': sbTxDateTime})
+SCAN(table=main.sbCustomer, columns={'sbCustId': sbCustId, 'sbCustState': sbCustState})
