From b48d3f5d96560820e41afcd0097f701835f0c45e Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Thu, 14 Aug 2025 16:39:24 -0400 Subject: [PATCH 01/16] Initial quarter simplification --- .../conversion/relational_simplification.py | 222 ++++++++++++++++++ tests/test_plan_refsols/common_prefix_n.txt | 2 +- tests/test_plan_refsols/common_prefix_o.txt | 2 +- tests/test_plan_refsols/correl_35.txt | 2 +- tests/test_plan_refsols/correl_36.txt | 2 +- tests/test_plan_refsols/tpch_q10.txt | 2 +- tests/test_plan_refsols/tpch_q4.txt | 2 +- tests/test_sql_refsols/correl_35_sqlite.sql | 17 +- tests/test_sql_refsols/correl_36_sqlite.sql | 17 +- .../defog_broker_gen5_ansi.sql | 2 +- .../defog_broker_gen5_sqlite.sql | 17 +- tests/test_sql_refsols/tpch_q10_ansi.sql | 2 +- tests/test_sql_refsols/tpch_q10_sqlite.sql | 17 +- tests/test_sql_refsols/tpch_q4_ansi.sql | 2 +- tests/test_sql_refsols/tpch_q4_sqlite.sql | 17 +- 15 files changed, 241 insertions(+), 84 deletions(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 10cecd60e..d2905d729 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -35,6 +35,7 @@ from pydough.relational.rel_util import ( add_input_name, ) +from pydough.types import ArrayType, NumericType @dataclass @@ -279,6 +280,200 @@ def visit_window_expression( arg_predicates.reverse() return self.simplify_window_call(new_window, arg_predicates) + def quarter_month_array(self, quarter: int) -> RelationalExpression: + """ + TODO + """ + assert 1 <= quarter <= 4 + month_arr: list[int] = [3 * (quarter - 1) + i + 1 for i in range(3)] + return LiteralExpression(month_arr, ArrayType(NumericType())) + + def switch_operator( + self, expr: CallExpression, op: pydop.PyDoughExpressionOperator + ) -> RelationalExpression: + """ + TODO + """ + return CallExpression(op, expr.data_type, expr.inputs) + + def keep_if_not_null( + self, source: 
RelationalExpression, expr: RelationalExpression + ) -> RelationalExpression: + """ + TODO + """ + source_not_null: RelationalExpression = CallExpression( + pydop.PRESENT, source.data_type, [source] + ) + return CallExpression(pydop.KEEP_IF, expr.data_type, [expr, source_not_null]) + + def simplify_function_literal_comparison( + self, + expr: RelationalExpression, + op: pydop.PyDoughOperator, + func_expr: CallExpression, + lit_expr: LiteralExpression, + ) -> RelationalExpression: + """ + TODO + """ + assert op in (pydop.EQU, pydop.NEQ, pydop.GRT, pydop.GEQ, pydop.LET, pydop.LEQ) + result: RelationalExpression = expr + conditional_true: RelationalExpression = self.keep_if_not_null( + func_expr.inputs[0], LiteralExpression(True, expr.data_type) + ) + conditional_false: RelationalExpression = self.keep_if_not_null( + func_expr.inputs[0], LiteralExpression(False, expr.data_type) + ) + match (op, func_expr.op, lit_expr.data_type): + case (pydop.EQU, pydop.QUARTER, NumericType()) if isinstance( + lit_expr.value, int + ): + # QUARTER(x) == 1 <=> ISIN(MONTH(x), [1, 2, 3]) + if lit_expr.value in (1, 2, 3, 4): + result = CallExpression( + pydop.ISIN, + expr.data_type, + [ + self.switch_operator(func_expr, pydop.MONTH), + self.quarter_month_array(lit_expr.value), + ], + ) + # QUARTER(x) == 3 <=> KEEP_IF(False, PRESENT(x)) + else: + result = conditional_false + case (pydop.NEQ, pydop.QUARTER, NumericType()) if isinstance( + lit_expr.value, int + ): + # QUARTER(x) == 4 <=> NOT(ISIN(MONTH(x), [10, 11, 12])) + if lit_expr.value in (1, 2, 3, 4): + result = CallExpression( + pydop.NOT, + expr.data_type, + [ + CallExpression( + pydop.ISIN, + expr.data_type, + [ + self.switch_operator(func_expr, pydop.MONTH), + self.quarter_month_array(lit_expr.value), + ], + ) + ], + ) + # QUARTER(x) != 0 <=> KEEP_IF(True, PRESENT(x)) + else: + result = conditional_true + case (pydop.LET, pydop.QUARTER, NumericType()) if isinstance( + lit_expr.value, int + ): + # QUARTER(x) < 4 <=> MONTH(X) < 9 + if 
lit_expr.value in (2, 3, 4): + result = CallExpression( + pydop.LET, + expr.data_type, + [ + func_expr.inputs[0], + LiteralExpression((lit_expr.value * 3) - 2, NumericType()), + ], + ) + # QUARTER(x) < 1 <=> KEEP_IF(False, PRESENT(x)) + elif lit_expr.value < 2: + result = conditional_false + # QUARTER(x) < 5 <=> KEEP_IF(True, PRESENT(x)) + elif lit_expr.value > 4: + result = conditional_true + case (pydop.LEQ, pydop.QUARTER, NumericType()) if isinstance( + lit_expr.value, int + ): + # QUARTER(x) <= 2 <=> MONTH(X) <= 6 + if lit_expr.value in (1, 2, 3): + result = CallExpression( + pydop.LEQ, + expr.data_type, + [ + func_expr.inputs[0], + LiteralExpression(lit_expr.value * 3, NumericType()), + ], + ) + # QUARTER(x) <= 0 <=> KEEP_IF(False, PRESENT(x)) + elif lit_expr.value < 1: + result = conditional_false + # QUARTER(x) <= 4 <=> KEEP_IF(True, PRESENT(x)) + elif lit_expr.value > 3: + result = conditional_true + case (pydop.GRT, pydop.QUARTER, NumericType()) if isinstance( + lit_expr.value, int + ): + # QUARTER(x) > 1 <=> MONTH(X) > 3 + if lit_expr.value in (1, 2, 3): + result = CallExpression( + pydop.LET, + expr.data_type, + [ + func_expr.inputs[0], + LiteralExpression(lit_expr.value * 3, NumericType()), + ], + ) + # QUARTER(x) > 0 <=> KEEP_IF(True, PRESENT(x)) + elif lit_expr.value < 1: + result = conditional_true + # QUARTER(x) > 4 <=> KEEP_IF(False, PRESENT(x)) + elif lit_expr.value > 3: + result = conditional_false + case (pydop.GEQ, pydop.QUARTER, NumericType()) if isinstance( + lit_expr.value, int + ): + # QUARTER(x) >= 3 <=> MONTH(X) >= 7 + if lit_expr.value in (2, 3, 4): + result = CallExpression( + pydop.GEQ, + expr.data_type, + [ + func_expr.inputs[0], + LiteralExpression((lit_expr.value * 3) - 2, NumericType()), + ], + ) + # QUARTER(x) >= 1 <=> KEEP_IF(True, PRESENT(x)) + elif lit_expr.value < 2: + result = conditional_true + # QUARTER(x) >= 6 <=> KEEP_IF(False, PRESENT(x)) + elif lit_expr.value > 4: + result = conditional_false + # MONTH(x) > 0 <=> 
KEEP_IF(True, PRESENT(x)) (same for other units) + # MONTH(x) != -3 <=> KEEP_IF(True, PRESENT(x)) (same for other units) + case ( + pydop.GRT | pydop.NEQ, + pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value < 1: + result = conditional_true + # MONTH(x) >= 1 <=> KEEP_IF(True, PRESENT(x)) (same for other units) + case ( + pydop.GEQ, + pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value <= 1: + result = conditional_true + # MONTH(x) < 1 <=> KEEP_IF(False, PRESENT(x)) (same for other units) + case ( + pydop.LET, + pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value <= 1: + result = conditional_false + # MONTH(x) <= 0 <=> KEEP_IF(False, PRESENT(x)) (same for other units) + case ( + pydop.LEQ, + pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value < 1: + result = conditional_false + case _: + # Fall back to the original expression by default. + pass + return result + def simplify_function_call( self, expr: CallExpression, @@ -673,6 +868,33 @@ def simplify_function_call( ) and isinstance(y, (int, float, str, bool)): output_expr = LiteralExpression(x >= y, expr.data_type) # type: ignore + # In cases where we do FUNC(x) cmp LIT, attempt additional + # simplifications. 
+ case (CallExpression(), _, LiteralExpression()): + output_expr = self.simplify_function_literal_comparison( + expr, expr.op, expr.inputs[0], expr.inputs[1] + ) + case (LiteralExpression(), pydop.EQU | pydop.NEQ, CallExpression()): + output_expr = self.simplify_function_literal_comparison( + expr, expr.op, expr.inputs[1], expr.inputs[0] + ) + case (LiteralExpression(), pydop.GRT, CallExpression()): + output_expr = self.simplify_function_literal_comparison( + expr, pydop.LET, expr.inputs[1], expr.inputs[0] + ) + case (LiteralExpression(), pydop.GEQ, CallExpression()): + output_expr = self.simplify_function_literal_comparison( + expr, pydop.LEQ, expr.inputs[1], expr.inputs[0] + ) + case (LiteralExpression(), pydop.LET, CallExpression()): + output_expr = self.simplify_function_literal_comparison( + expr, pydop.GRT, expr.inputs[1], expr.inputs[0] + ) + case (LiteralExpression(), pydop.LEQ, CallExpression()): + output_expr = self.simplify_function_literal_comparison( + expr, pydop.GEQ, expr.inputs[1], expr.inputs[0] + ) + case _: # All other cases remain non-simplified. 
pass diff --git a/tests/test_plan_refsols/common_prefix_n.txt b/tests/test_plan_refsols/common_prefix_n.txt index 3faf6e877..c5669c068 100644 --- a/tests/test_plan_refsols/common_prefix_n.txt +++ b/tests/test_plan_refsols/common_prefix_n.txt @@ -2,7 +2,7 @@ ROOT(columns=[('key', o_orderkey), ('order_date', o_orderdate), ('n_elements', D FILTER(condition=DEFAULT_TO(n_rows, 0:numeric) > DEFAULT_TO(ndistinct_n_name, 0:numeric), columns={'max_s_acctbal': max_s_acctbal, 'n_rows': n_rows, 'ndistinct_n_name': ndistinct_n_name, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 'sum_agg_11': sum_agg_11, 'sum_p_retailprice': sum_p_retailprice}) JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'max_s_acctbal': t0.max_s_acctbal, 'n_rows': t0.n_rows, 'ndistinct_n_name': t1.ndistinct_n_name, 'o_orderdate': t0.o_orderdate, 'o_orderkey': t0.o_orderkey, 'sum_agg_11': t0.sum_agg_11, 'sum_p_retailprice': t0.sum_p_retailprice}) JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=INNER, cardinality=SINGULAR_FILTER, columns={'max_s_acctbal': t1.max_s_acctbal, 'n_rows': t1.n_rows, 'o_orderdate': t0.o_orderdate, 'o_orderkey': t0.o_orderkey, 'sum_agg_11': t1.sum_agg_11, 'sum_p_retailprice': t1.sum_p_retailprice}) - FILTER(condition=QUARTER(o_orderdate) == 4:numeric & YEAR(o_orderdate) == 1996:numeric, columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) + FILTER(condition=YEAR(o_orderdate) == 1996:numeric & ISIN(MONTH(o_orderdate), [10, 11, 12]:array[numeric]), columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) SCAN(table=tpch.ORDERS, columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) AGGREGATE(keys={'l_orderkey': l_orderkey}, aggregations={'max_s_acctbal': MAX(s_acctbal), 'n_rows': COUNT(), 'sum_agg_11': SUM(agg_11), 'sum_p_retailprice': SUM(p_retailprice)}) JOIN(condition=t0.l_partkey == t1.p_partkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'agg_11': t1.agg_11, 'l_orderkey': t0.l_orderkey, 
'p_retailprice': t0.p_retailprice, 's_acctbal': t0.s_acctbal}) diff --git a/tests/test_plan_refsols/common_prefix_o.txt b/tests/test_plan_refsols/common_prefix_o.txt index 42ae08339..8d808b53a 100644 --- a/tests/test_plan_refsols/common_prefix_o.txt +++ b/tests/test_plan_refsols/common_prefix_o.txt @@ -2,7 +2,7 @@ ROOT(columns=[('key', o_orderkey), ('order_date', o_orderdate), ('n_elements', D FILTER(condition=DEFAULT_TO(sum_n_rows, 0:numeric) > DEFAULT_TO(ndistinct_n_name, 0:numeric), columns={'max_s_acctbal': max_s_acctbal, 'n_small_parts': sum_sum_agg_5, 'ndistinct_n_name': ndistinct_n_name, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 'sum_n_rows': sum_n_rows, 'sum_sum_p_retailprice': sum_sum_p_retailprice}) JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'max_s_acctbal': t0.max_s_acctbal, 'ndistinct_n_name': t1.ndistinct_n_name, 'o_orderdate': t0.o_orderdate, 'o_orderkey': t0.o_orderkey, 'sum_n_rows': t0.sum_n_rows, 'sum_sum_agg_5': t0.sum_sum_agg_5, 'sum_sum_p_retailprice': t0.sum_sum_p_retailprice}) JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=INNER, cardinality=SINGULAR_FILTER, columns={'max_s_acctbal': t1.max_s_acctbal, 'o_orderdate': t0.o_orderdate, 'o_orderkey': t0.o_orderkey, 'sum_n_rows': t1.sum_n_rows, 'sum_sum_agg_5': t1.sum_sum_agg_5, 'sum_sum_p_retailprice': t1.sum_sum_p_retailprice}) - FILTER(condition=QUARTER(o_orderdate) == 4:numeric & YEAR(o_orderdate) == 1996:numeric, columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) + FILTER(condition=YEAR(o_orderdate) == 1996:numeric & ISIN(MONTH(o_orderdate), [10, 11, 12]:array[numeric]), columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) SCAN(table=tpch.ORDERS, columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) FILTER(condition=sum_sum_agg_5 > 0:numeric, columns={'l_orderkey': l_orderkey, 'max_s_acctbal': max_s_acctbal, 'sum_n_rows': sum_n_rows, 'sum_sum_agg_5': sum_sum_agg_5, 'sum_sum_p_retailprice': 
sum_sum_p_retailprice}) AGGREGATE(keys={'l_orderkey': l_orderkey}, aggregations={'max_s_acctbal': MAX(s_acctbal), 'sum_n_rows': SUM(n_rows), 'sum_sum_agg_5': SUM(sum_agg_5), 'sum_sum_p_retailprice': SUM(sum_p_retailprice)}) diff --git a/tests/test_plan_refsols/correl_35.txt b/tests/test_plan_refsols/correl_35.txt index ec4994e6d..406802431 100644 --- a/tests/test_plan_refsols/correl_35.txt +++ b/tests/test_plan_refsols/correl_35.txt @@ -19,6 +19,6 @@ ROOT(columns=[('n', n)], orderings=[]) SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) FILTER(condition=YEAR(o_orderdate) == 1997:numeric, columns={'o_custkey': o_custkey, 'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) SCAN(table=tpch.ORDERS, columns={'o_custkey': o_custkey, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) - FILTER(condition=QUARTER(l_shipdate) == 1:numeric & YEAR(l_shipdate) == 1997:numeric, columns={'l_orderkey': l_orderkey, 'l_partkey': l_partkey}) + FILTER(condition=YEAR(l_shipdate) == 1997:numeric & ISIN(MONTH(l_shipdate), [1, 2, 3]:array[numeric]), columns={'l_orderkey': l_orderkey, 'l_partkey': l_partkey}) SCAN(table=tpch.LINEITEM, columns={'l_orderkey': l_orderkey, 'l_partkey': l_partkey, 'l_shipdate': l_shipdate}) SCAN(table=tpch.PART, columns={'p_partkey': p_partkey, 'p_type': p_type}) diff --git a/tests/test_plan_refsols/correl_36.txt b/tests/test_plan_refsols/correl_36.txt index fd87197e6..e81a5fd2b 100644 --- a/tests/test_plan_refsols/correl_36.txt +++ b/tests/test_plan_refsols/correl_36.txt @@ -28,6 +28,6 @@ ROOT(columns=[('n', n)], orderings=[]) SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey}) FILTER(condition=YEAR(o_orderdate) == 1997:numeric, columns={'o_custkey': o_custkey, 'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) SCAN(table=tpch.ORDERS, columns={'o_custkey': o_custkey, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 
'o_orderpriority': o_orderpriority}) - FILTER(condition=QUARTER(l_shipdate) == 1:numeric & YEAR(l_shipdate) == 1997:numeric, columns={'l_orderkey': l_orderkey, 'l_partkey': l_partkey}) + FILTER(condition=YEAR(l_shipdate) == 1997:numeric & ISIN(MONTH(l_shipdate), [1, 2, 3]:array[numeric]), columns={'l_orderkey': l_orderkey, 'l_partkey': l_partkey}) SCAN(table=tpch.LINEITEM, columns={'l_orderkey': l_orderkey, 'l_partkey': l_partkey, 'l_shipdate': l_shipdate}) SCAN(table=tpch.PART, columns={'p_partkey': p_partkey, 'p_type': p_type}) diff --git a/tests/test_plan_refsols/tpch_q10.txt b/tests/test_plan_refsols/tpch_q10.txt index 71b4e801c..6f3f32609 100644 --- a/tests/test_plan_refsols/tpch_q10.txt +++ b/tests/test_plan_refsols/tpch_q10.txt @@ -4,7 +4,7 @@ ROOT(columns=[('C_CUSTKEY', c_custkey), ('C_NAME', c_name), ('REVENUE', DEFAULT_ SCAN(table=tpch.CUSTOMER, columns={'c_acctbal': c_acctbal, 'c_address': c_address, 'c_comment': c_comment, 'c_custkey': c_custkey, 'c_name': c_name, 'c_nationkey': c_nationkey, 'c_phone': c_phone}) AGGREGATE(keys={'o_custkey': o_custkey}, aggregations={'sum_expr_1': SUM(l_extendedprice * 1:numeric - l_discount)}) JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=INNER, cardinality=PLURAL_FILTER, columns={'l_discount': t1.l_discount, 'l_extendedprice': t1.l_extendedprice, 'o_custkey': t0.o_custkey}) - FILTER(condition=QUARTER(o_orderdate) == 4:numeric & YEAR(o_orderdate) == 1993:numeric, columns={'o_custkey': o_custkey, 'o_orderkey': o_orderkey}) + FILTER(condition=YEAR(o_orderdate) == 1993:numeric & ISIN(MONTH(o_orderdate), [10, 11, 12]:array[numeric]), columns={'o_custkey': o_custkey, 'o_orderkey': o_orderkey}) SCAN(table=tpch.ORDERS, columns={'o_custkey': o_custkey, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) FILTER(condition=l_returnflag == 'R':string, columns={'l_discount': l_discount, 'l_extendedprice': l_extendedprice, 'l_orderkey': l_orderkey}) SCAN(table=tpch.LINEITEM, columns={'l_discount': l_discount, 
'l_extendedprice': l_extendedprice, 'l_orderkey': l_orderkey, 'l_returnflag': l_returnflag}) diff --git a/tests/test_plan_refsols/tpch_q4.txt b/tests/test_plan_refsols/tpch_q4.txt index 633e0d21e..43e9cd50e 100644 --- a/tests/test_plan_refsols/tpch_q4.txt +++ b/tests/test_plan_refsols/tpch_q4.txt @@ -1,7 +1,7 @@ ROOT(columns=[('O_ORDERPRIORITY', o_orderpriority), ('ORDER_COUNT', ORDER_COUNT)], orderings=[(o_orderpriority):asc_first]) AGGREGATE(keys={'o_orderpriority': o_orderpriority}, aggregations={'ORDER_COUNT': COUNT()}) JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=SEMI, cardinality=SINGULAR_FILTER, columns={'o_orderpriority': t0.o_orderpriority}) - FILTER(condition=QUARTER(o_orderdate) == 3:numeric & YEAR(o_orderdate) == 1993:numeric, columns={'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) + FILTER(condition=YEAR(o_orderdate) == 1993:numeric & ISIN(MONTH(o_orderdate), [7, 8, 9]:array[numeric]), columns={'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) SCAN(table=tpch.ORDERS, columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) FILTER(condition=l_commitdate < l_receiptdate, columns={'l_orderkey': l_orderkey}) SCAN(table=tpch.LINEITEM, columns={'l_commitdate': l_commitdate, 'l_orderkey': l_orderkey, 'l_receiptdate': l_receiptdate}) diff --git a/tests/test_sql_refsols/correl_35_sqlite.sql b/tests/test_sql_refsols/correl_35_sqlite.sql index bda581395..523f30624 100644 --- a/tests/test_sql_refsols/correl_35_sqlite.sql +++ b/tests/test_sql_refsols/correl_35_sqlite.sql @@ -15,21 +15,8 @@ WITH _s1 AS ( ON CAST(STRFTIME('%Y', orders.o_orderdate) AS INTEGER) = 1997 AND customer.c_custkey = orders.o_custkey JOIN tpch.lineitem AS lineitem - ON CASE - WHEN CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) <= 3 - AND CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) >= 1 - THEN 1 - WHEN CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) <= 6 - AND CAST(STRFTIME('%m', 
lineitem.l_shipdate) AS INTEGER) >= 4 - THEN 2 - WHEN CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) <= 9 - AND CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) >= 7 - THEN 3 - WHEN CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) <= 12 - AND CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) >= 10 - THEN 4 - END = 1 - AND CAST(STRFTIME('%Y', lineitem.l_shipdate) AS INTEGER) = 1997 + ON CAST(STRFTIME('%Y', lineitem.l_shipdate) AS INTEGER) = 1997 + AND CAST(STRFTIME('%m', lineitem.l_shipdate) AS INTEGER) IN (1, 2, 3) AND lineitem.l_orderkey = orders.o_orderkey GROUP BY customer.c_custkey, diff --git a/tests/test_sql_refsols/correl_36_sqlite.sql b/tests/test_sql_refsols/correl_36_sqlite.sql index 22d824e86..a02563e86 100644 --- a/tests/test_sql_refsols/correl_36_sqlite.sql +++ b/tests/test_sql_refsols/correl_36_sqlite.sql @@ -28,21 +28,8 @@ WITH _s3 AS ( AND customer.c_custkey = orders_2.o_custkey AND orders.o_orderpriority = orders_2.o_orderpriority JOIN tpch.lineitem AS lineitem_2 - ON CASE - WHEN CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) <= 3 - AND CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) >= 1 - THEN 1 - WHEN CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) <= 6 - AND CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) >= 4 - THEN 2 - WHEN CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) <= 9 - AND CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) >= 7 - THEN 3 - WHEN CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) <= 12 - AND CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) >= 10 - THEN 4 - END = 1 - AND CAST(STRFTIME('%Y', lineitem_2.l_shipdate) AS INTEGER) = 1997 + ON CAST(STRFTIME('%Y', lineitem_2.l_shipdate) AS INTEGER) = 1997 + AND CAST(STRFTIME('%m', lineitem_2.l_shipdate) AS INTEGER) IN (1, 2, 3) AND lineitem_2.l_orderkey = orders_2.o_orderkey JOIN _s3 AS _s19 ON _s19.p_partkey = lineitem_2.l_partkey AND _s19.p_type = _s3.p_type diff --git 
a/tests/test_sql_refsols/defog_broker_gen5_ansi.sql b/tests/test_sql_refsols/defog_broker_gen5_ansi.sql index fd67ee3ec..56503c77a 100644 --- a/tests/test_sql_refsols/defog_broker_gen5_ansi.sql +++ b/tests/test_sql_refsols/defog_broker_gen5_ansi.sql @@ -3,7 +3,7 @@ SELECT AVG(sbtxprice) AS avg_price FROM main.sbtransaction WHERE - EXTRACT(QUARTER FROM CAST(sbtxdatetime AS DATETIME)) = 1 + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (1, 2, 3) AND EXTRACT(YEAR FROM CAST(sbtxdatetime AS DATETIME)) = 2023 AND sbtxstatus = 'success' GROUP BY diff --git a/tests/test_sql_refsols/defog_broker_gen5_sqlite.sql b/tests/test_sql_refsols/defog_broker_gen5_sqlite.sql index b74c46dea..b4ada8be4 100644 --- a/tests/test_sql_refsols/defog_broker_gen5_sqlite.sql +++ b/tests/test_sql_refsols/defog_broker_gen5_sqlite.sql @@ -3,21 +3,8 @@ SELECT AVG(sbtxprice) AS avg_price FROM main.sbtransaction WHERE - CASE - WHEN CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 3 - AND CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 1 - THEN 1 - WHEN CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 6 - AND CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 4 - THEN 2 - WHEN CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 9 - AND CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 7 - THEN 3 - WHEN CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 12 - AND CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 10 - THEN 4 - END = 1 - AND CAST(STRFTIME('%Y', sbtxdatetime) AS INTEGER) = 2023 + CAST(STRFTIME('%Y', sbtxdatetime) AS INTEGER) = 2023 + AND CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (1, 2, 3) AND sbtxstatus = 'success' GROUP BY DATE(sbtxdatetime, 'start of month') diff --git a/tests/test_sql_refsols/tpch_q10_ansi.sql b/tests/test_sql_refsols/tpch_q10_ansi.sql index f0a2699a0..c9b93a958 100644 --- a/tests/test_sql_refsols/tpch_q10_ansi.sql +++ b/tests/test_sql_refsols/tpch_q10_ansi.sql @@ -8,7 +8,7 @@ WITH _s3 AS ( JOIN tpch.lineitem AS lineitem ON lineitem.l_orderkey = orders.o_orderkey 
AND lineitem.l_returnflag = 'R' WHERE - EXTRACT(QUARTER FROM CAST(orders.o_orderdate AS DATETIME)) = 4 + EXTRACT(MONTH FROM CAST(orders.o_orderdate AS DATETIME)) IN (10, 11, 12) AND EXTRACT(YEAR FROM CAST(orders.o_orderdate AS DATETIME)) = 1993 GROUP BY orders.o_custkey diff --git a/tests/test_sql_refsols/tpch_q10_sqlite.sql b/tests/test_sql_refsols/tpch_q10_sqlite.sql index 7e5943713..5ba3d17f4 100644 --- a/tests/test_sql_refsols/tpch_q10_sqlite.sql +++ b/tests/test_sql_refsols/tpch_q10_sqlite.sql @@ -8,21 +8,8 @@ WITH _s3 AS ( JOIN tpch.lineitem AS lineitem ON lineitem.l_orderkey = orders.o_orderkey AND lineitem.l_returnflag = 'R' WHERE - CASE - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 3 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 1 - THEN 1 - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 6 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 4 - THEN 2 - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 9 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 7 - THEN 3 - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 12 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 10 - THEN 4 - END = 4 - AND CAST(STRFTIME('%Y', orders.o_orderdate) AS INTEGER) = 1993 + CAST(STRFTIME('%Y', orders.o_orderdate) AS INTEGER) = 1993 + AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) IN (10, 11, 12) GROUP BY orders.o_custkey ) diff --git a/tests/test_sql_refsols/tpch_q4_ansi.sql b/tests/test_sql_refsols/tpch_q4_ansi.sql index a861ccf44..a03e15af0 100644 --- a/tests/test_sql_refsols/tpch_q4_ansi.sql +++ b/tests/test_sql_refsols/tpch_q4_ansi.sql @@ -6,7 +6,7 @@ JOIN tpch.lineitem AS lineitem ON lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_orderkey = orders.o_orderkey WHERE - EXTRACT(QUARTER FROM CAST(orders.o_orderdate AS DATETIME)) = 3 + EXTRACT(MONTH FROM CAST(orders.o_orderdate AS DATETIME)) IN (7, 8, 9) AND EXTRACT(YEAR FROM CAST(orders.o_orderdate AS 
DATETIME)) = 1993 GROUP BY orders.o_orderpriority diff --git a/tests/test_sql_refsols/tpch_q4_sqlite.sql b/tests/test_sql_refsols/tpch_q4_sqlite.sql index f928bbd52..b416a167f 100644 --- a/tests/test_sql_refsols/tpch_q4_sqlite.sql +++ b/tests/test_sql_refsols/tpch_q4_sqlite.sql @@ -14,21 +14,8 @@ FROM tpch.orders AS orders LEFT JOIN _u_0 AS _u_0 ON _u_0._u_1 = orders.o_orderkey WHERE - CASE - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 3 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 1 - THEN 1 - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 6 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 4 - THEN 2 - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 9 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 7 - THEN 3 - WHEN CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) <= 12 - AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) >= 10 - THEN 4 - END = 3 - AND CAST(STRFTIME('%Y', orders.o_orderdate) AS INTEGER) = 1993 + CAST(STRFTIME('%Y', orders.o_orderdate) AS INTEGER) = 1993 + AND CAST(STRFTIME('%m', orders.o_orderdate) AS INTEGER) IN (7, 8, 9) AND NOT _u_0._u_1 IS NULL GROUP BY orders.o_orderpriority From 948bf6fd77c4dfc930f60a993ca2e54a3a8ae1a7 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 15 Aug 2025 12:12:29 -0400 Subject: [PATCH 02/16] Added initial simplification tests for new patterns --- .../conversion/relational_simplification.py | 104 +++++++++++++++--- tests/test_pipeline_defog_custom.py | 85 ++++++++++++++ tests/test_plan_refsols/simplification_4.txt | 4 + .../simplification_4_ansi.sql | 46 ++++++++ .../simplification_4_sqlite.sql | 54 +++++++++ 5 files changed, 278 insertions(+), 15 deletions(-) create mode 100644 tests/test_plan_refsols/simplification_4.txt create mode 100644 tests/test_sql_refsols/simplification_4_ansi.sql create mode 100644 tests/test_sql_refsols/simplification_4_sqlite.sql diff --git a/pydough/conversion/relational_simplification.py 
b/pydough/conversion/relational_simplification.py index d2905d729..e9fc5dde8 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -373,7 +373,7 @@ def simplify_function_literal_comparison( pydop.LET, expr.data_type, [ - func_expr.inputs[0], + self.switch_operator(func_expr, pydop.MONTH), LiteralExpression((lit_expr.value * 3) - 2, NumericType()), ], ) @@ -392,7 +392,7 @@ def simplify_function_literal_comparison( pydop.LEQ, expr.data_type, [ - func_expr.inputs[0], + self.switch_operator(func_expr, pydop.MONTH), LiteralExpression(lit_expr.value * 3, NumericType()), ], ) @@ -408,10 +408,10 @@ def simplify_function_literal_comparison( # QUARTER(x) > 1 <=> MONTH(X) > 3 if lit_expr.value in (1, 2, 3): result = CallExpression( - pydop.LET, + pydop.GRT, expr.data_type, [ - func_expr.inputs[0], + self.switch_operator(func_expr, pydop.MONTH), LiteralExpression(lit_expr.value * 3, NumericType()), ], ) @@ -430,7 +430,7 @@ def simplify_function_literal_comparison( pydop.GEQ, expr.data_type, [ - func_expr.inputs[0], + self.switch_operator(func_expr, pydop.MONTH), LiteralExpression((lit_expr.value * 3) - 2, NumericType()), ], ) @@ -440,35 +440,96 @@ def simplify_function_literal_comparison( # QUARTER(x) >= 6 <=> KEEP_IF(False, PRESENT(x)) elif lit_expr.value > 4: result = conditional_false - # MONTH(x) > 0 <=> KEEP_IF(True, PRESENT(x)) (same for other units) - # MONTH(x) != -3 <=> KEEP_IF(True, PRESENT(x)) (same for other units) + # MONTH(x) > 0 <=> KEEP_IF(True, PRESENT(x)) (same for day) + # MONTH(x) != -3 <=> KEEP_IF(True, PRESENT(x)) (same for day) case ( pydop.GRT | pydop.NEQ, - pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + pydop.MONTH | pydop.DAY, NumericType(), ) if isinstance(lit_expr.value, int) and lit_expr.value < 1: result = conditional_true - # MONTH(x) >= 1 <=> KEEP_IF(True, PRESENT(x)) (same for other units) + # MONTH(x) >= 1 <=> KEEP_IF(True, PRESENT(x)) (same for day) case ( 
pydop.GEQ, - pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + pydop.MONTH | pydop.DAY, NumericType(), ) if isinstance(lit_expr.value, int) and lit_expr.value <= 1: result = conditional_true - # MONTH(x) < 1 <=> KEEP_IF(False, PRESENT(x)) (same for other units) + # MONTH(x) < 1 <=> KEEP_IF(False, PRESENT(x)) (same for day) case ( pydop.LET, - pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + pydop.MONTH | pydop.DAY, NumericType(), ) if isinstance(lit_expr.value, int) and lit_expr.value <= 1: result = conditional_false - # MONTH(x) <= 0 <=> KEEP_IF(False, PRESENT(x)) (same for other units) + # MONTH(x) <= 0 <=> KEEP_IF(False, PRESENT(x)) (same for day) + # MONTH(x) == 0 <=> KEEP_IF(False, PRESENT(x)) (same for day) case ( - pydop.LEQ, - pydop.MONTH | pydop.DAY | pydop.HOUR | pydop.MINUTE | pydop.SECOND, + pydop.LEQ | pydop.EQU, + pydop.MONTH | pydop.DAY, NumericType(), ) if isinstance(lit_expr.value, int) and lit_expr.value < 1: result = conditional_false + # HOUR(x) <= -1 <=> KEEP_IF(False, PRESENT(x)) (same for minute/second) + # HOUR(x) == -1 <=> KEEP_IF(False, PRESENT(x)) (same for minute/second) + case ( + pydop.LEQ | pydop.EQU, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value < 0: + result = conditional_false + # HOUR(x) < 0 <=> KEEP_IF(False, PRESENT(x)) (same for minute/second) + case ( + pydop.LET, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value <= 0: + result = conditional_false + # HOUR(x) > -1 <=> KEEP_IF(True, PRESENT(x)) (same for minute/second) + # HOUR(x) != -1 <=> KEEP_IF(True, PRESENT(x)) (same for minute/second) + case ( + pydop.GRT | pydop.NEQ, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value < 0: + result = conditional_true + # HOUR(x) >= 0 <=> KEEP_IF(True, PRESENT(x)) (same for minute/second) + 
case ( + pydop.GEQ, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value <= 0: + result = conditional_true + # HOUR(x) > 60 <=> KEEP_IF(False, PRESENT(x)) (same for minute/second) + # HOUR(x) == 60 <=> KEEP_IF(False, PRESENT(x)) (same for minute/second) + case ( + pydop.GRT | pydop.EQU, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value >= 60: + result = conditional_false + # HOUR(x) < 61 <=> KEEP_IF(True, PRESENT(x)) (same for minute/second) + # HOUR(x) != 61 <=> KEEP_IF(True, PRESENT(x)) (same for minute/second) + case ( + pydop.LET | pydop.NEQ, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value > 60: + result = conditional_true + # HOUR(x) <= 60 <=> KEEP_IF(True, PRESENT(x)) (same for minute/second) + case ( + pydop.LEQ, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value >= 60: + result = conditional_true + # HOUR(x) >= 61 <=> KEEP_IF(False, PRESENT(x)) (same for minute/second) + case ( + pydop.GEQ, + pydop.HOUR | pydop.MINUTE | pydop.SECOND, + NumericType(), + ) if isinstance(lit_expr.value, int) and lit_expr.value > 60: + result = conditional_false case _: # Fall back to the original expression by default. pass @@ -955,6 +1016,19 @@ def simplify_function_call( output_predicates |= arg_predicates[0] & PredicateSet( not_null=True, not_negative=True ) + + # DATETIME(DATETIME(u, v, w), x, y, z) -> DATETIME(u, v, w, x, y, z) + case pydop.DATETIME: + if ( + isinstance(expr.inputs[0], CallExpression) + and expr.inputs[0].op == pydop.DATETIME + ): + output_expr = CallExpression( + pydop.DATETIME, + expr.data_type, + expr.inputs[0].inputs + expr.inputs[1:], + ) + case _: # All other operators remain non-simplified. 
pass diff --git a/tests/test_pipeline_defog_custom.py b/tests/test_pipeline_defog_custom.py index f85afa9c2..109bf3151 100644 --- a/tests/test_pipeline_defog_custom.py +++ b/tests/test_pipeline_defog_custom.py @@ -1879,6 +1879,91 @@ def get_day_of_week( ), id="simplification_3", ), + pytest.param( + PyDoughPandasTest( + "result = (" + " transactions" + " .WHERE(YEAR(date_time) == 2023)" + " .WHERE((RANKING(by=date_time.ASC()) == 1) | (RANKING(by=date_time.DESC()) == 1))" + " .CALCULATE(" + " date_time," + " s00 = DATETIME(DATETIME(date_time, 'start of week'), '-8 weeks')," # -> DATETIME(date_time, 'start of week', '-8 weeks') + " s01 = QUARTER(date_time) == 0," # KEEP_IF(False, PRESENT(date_time)) + " s02 = 1 == QUARTER(date_time)," # ISIN(MONTH(date_time), [1,2,3]) + " s03 = QUARTER(date_time) == 2," # ISIN(MONTH(date_time), [4,5,6]) + " s04 = 3 == QUARTER(date_time)," # ISIN(MONTH(date_time), [7,8,9]) + " s05 = QUARTER(date_time) == 4," # ISIN(MONTH(date_time), [10,11,12]) + " s06 = 5 == QUARTER(date_time)," # KEEP_IF(False, PRESENT(date_time)) + " s07 = 1 > QUARTER(date_time)," # KEEP_IF(False, PRESENT(date_time)) + " s08 = QUARTER(date_time) < 2," # MONTH(date_time) < 4 + " s09 = 3 > QUARTER(date_time)," # MONTH(date_time) < 7 + " s10 = QUARTER(date_time) < 4," # MONTH(date_time) < 10 + " s11 = 5 > QUARTER(date_time)," # KEEP_IF(True, PRESENT(date_time)) + " s12 = QUARTER(date_time) <= 0," # KEEP_IF(False, PRESENT(date_time)) + " s13 = 1 >= QUARTER(date_time)," # MONTH(date_time) <= 3 + " s14 = QUARTER(date_time) <= 2," # MONTH(date_time) <= 6 + " s15 = 3 >= QUARTER(date_time)," # MONTH(date_time) <= 9 + " s16 = QUARTER(date_time) <= 4," # KEEP_IF(True, PRESENT(date_time)) + " s17 = 0 < QUARTER(date_time)," # KEEP_IF(True, PRESENT(date_time)) + " s18 = QUARTER(date_time) > 1," # MONTH(date_time) > 3 + " s19 = 2 < QUARTER(date_time)," # MONTH(date_time) > 6 + " s20 = QUARTER(date_time) > 3," # MONTH(date_time) > 9 + " s21 = 4 < QUARTER(date_time)," # 
KEEP_IF(False, PRESENT(date_time)) + " s22 = 1 <= QUARTER(date_time)," # KEEP_IF(True, PRESENT(date_time)) + " s23 = QUARTER(date_time) >= 2," # MONTH(date_time) >= 4 + " s24 = 3 <= QUARTER(date_time)," # MONTH(date_time) >= 7 + " s25 = QUARTER(date_time) >= 4," # MONTH(date_time) >= 10 + " s26 = 5 <= QUARTER(date_time)," # KEEP_IF(False, PRESENT(date_time)) + " s27 = QUARTER(date_time) != 0," # KEEP_IF(True, PRESENT(date_time)) + " s28 = 1 != QUARTER(date_time)," # NOT(ISIN(MONTH(date_time), [1,2,3])) + " s29 = QUARTER(date_time) != 2," # NOT(ISIN(MONTH(date_time), [4,5,6])) + " s30 = 3 != QUARTER(date_time)," # NOT(ISIN(MONTH(date_time), [7,8,9])) + " s31 = QUARTER(date_time) != 4," # NOT(ISIN(MONTH(date_time), [10,11,12])) + " s32 = 5 != QUARTER(date_time)," # KEEP_IF(True, PRESENT(date_time)) + "))", + "Broker", + lambda: pd.DataFrame( + { + "date_time": ["2023-01-15 10:00:00", "2023-04-03 16:15:00"], + "s00": ["2022-11-20", "2023-02-05"], + "s01": [0, 0], + "s02": [1, 0], + "s03": [0, 1], + "s04": [0, 0], + "s05": [0, 0], + "s06": [0, 0], + "s07": [0, 0], + "s08": [1, 0], + "s09": [1, 1], + "s10": [1, 1], + "s11": [1, 1], + "s12": [0, 0], + "s13": [1, 0], + "s14": [1, 1], + "s15": [1, 1], + "s16": [1, 1], + "s17": [1, 1], + "s18": [0, 1], + "s19": [0, 0], + "s20": [0, 0], + "s21": [0, 0], + "s22": [1, 1], + "s23": [0, 1], + "s24": [0, 0], + "s25": [0, 0], + "s26": [0, 0], + "s27": [1, 1], + "s28": [0, 1], + "s29": [1, 0], + "s30": [1, 1], + "s31": [1, 1], + "s32": [1, 1], + } + ), + "simplification_4", + ), + id="simplification_4", + ), ], ) def defog_custom_pipeline_test_data(request) -> PyDoughPandasTest: diff --git a/tests/test_plan_refsols/simplification_4.txt b/tests/test_plan_refsols/simplification_4.txt new file mode 100644 index 000000000..903c80e55 --- /dev/null +++ b/tests/test_plan_refsols/simplification_4.txt @@ -0,0 +1,4 @@ +ROOT(columns=[('date_time', sbTxDateTime), ('s00', DATETIME(sbTxDateTime, 'start of week':string, '-8 weeks':string)), 
('s01', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s02', ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric])), ('s03', ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric])), ('s04', ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric])), ('s05', ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric])), ('s06', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s07', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s08', MONTH(sbTxDateTime) < 4:numeric), ('s09', MONTH(sbTxDateTime) < 7:numeric), ('s10', MONTH(sbTxDateTime) < 10:numeric), ('s11', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s12', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s13', MONTH(sbTxDateTime) <= 3:numeric), ('s14', MONTH(sbTxDateTime) <= 6:numeric), ('s15', MONTH(sbTxDateTime) <= 9:numeric), ('s16', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s17', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s18', MONTH(sbTxDateTime) > 3:numeric), ('s19', MONTH(sbTxDateTime) > 6:numeric), ('s20', MONTH(sbTxDateTime) > 9:numeric), ('s21', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s22', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s23', MONTH(sbTxDateTime) >= 4:numeric), ('s24', MONTH(sbTxDateTime) >= 7:numeric), ('s25', MONTH(sbTxDateTime) >= 10:numeric), ('s26', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s27', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s28', NOT(ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric]))), ('s29', NOT(ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric]))), ('s30', NOT(ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric]))), ('s31', NOT(ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric]))), ('s32', KEEP_IF(True:bool, PRESENT(sbTxDateTime)))], orderings=[]) + FILTER(condition=RANKING(args=[], partition=[], order=[(sbTxDateTime):asc_last]) == 1:numeric | RANKING(args=[], partition=[], order=[(sbTxDateTime):desc_first]) == 1:numeric, columns={'sbTxDateTime': sbTxDateTime}) + FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime}) + 
SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime}) diff --git a/tests/test_sql_refsols/simplification_4_ansi.sql b/tests/test_sql_refsols/simplification_4_ansi.sql new file mode 100644 index 000000000..f748aa2af --- /dev/null +++ b/tests/test_sql_refsols/simplification_4_ansi.sql @@ -0,0 +1,46 @@ +WITH _t1 AS ( + SELECT + sbtxdatetime + FROM main.sbtransaction + WHERE + EXTRACT(YEAR FROM CAST(sbtxdatetime AS DATETIME)) = 2023 + QUALIFY + ROW_NUMBER() OVER (ORDER BY sbtxdatetime DESC NULLS FIRST) = 1 + OR ROW_NUMBER() OVER (ORDER BY sbtxdatetime NULLS LAST) = 1 +) +SELECT + sbtxdatetime AS date_time, + DATE_ADD(DATE_TRUNC('WEEK', CAST(sbtxdatetime AS TIMESTAMP)), -8, 'WEEK') AS s00, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s01, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (1, 2, 3) AS s02, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (4, 5, 6) AS s03, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (7, 8, 9) AS s04, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (10, 11, 12) AS s05, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s06, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s07, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) < 4 AS s08, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) < 7 AS s09, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) < 10 AS s10, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s11, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s12, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) <= 3 AS s13, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) <= 6 AS s14, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) <= 9 AS s15, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s16, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s17, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) > 3 AS s18, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS 
DATETIME)) > 6 AS s19, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) > 9 AS s20, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s21, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s22, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) >= 4 AS s23, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) >= 7 AS s24, + EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) >= 10 AS s25, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s26, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s27, + NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (1, 2, 3) AS s28, + NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (4, 5, 6) AS s29, + NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (7, 8, 9) AS s30, + NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (10, 11, 12) AS s31, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32 +FROM _t1 diff --git a/tests/test_sql_refsols/simplification_4_sqlite.sql b/tests/test_sql_refsols/simplification_4_sqlite.sql new file mode 100644 index 000000000..bca4e8f95 --- /dev/null +++ b/tests/test_sql_refsols/simplification_4_sqlite.sql @@ -0,0 +1,54 @@ +WITH _t AS ( + SELECT + sbtxdatetime, + ROW_NUMBER() OVER (ORDER BY sbtxdatetime) AS _w, + ROW_NUMBER() OVER (ORDER BY sbtxdatetime DESC) AS _w_2 + FROM main.sbtransaction + WHERE + CAST(STRFTIME('%Y', sbtxdatetime) AS INTEGER) = 2023 +) +SELECT + sbtxdatetime AS date_time, + DATE( + sbtxdatetime, + '-' || CAST(( + CAST(STRFTIME('%w', DATETIME(sbtxdatetime)) AS INTEGER) + 6 + ) % 7 AS TEXT) || ' days', + 'start of day', + '-56 day' + ) AS s00, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s01, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (1, 2, 3) AS s02, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (4, 5, 6) AS s03, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (7, 8, 9) AS s04, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (10, 11, 
12) AS s05, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s06, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s07, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) < 4 AS s08, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) < 7 AS s09, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) < 10 AS s10, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s11, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s12, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 3 AS s13, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 6 AS s14, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 9 AS s15, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s16, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s17, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) > 3 AS s18, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) > 6 AS s19, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) > 9 AS s20, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s21, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s22, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 4 AS s23, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 7 AS s24, + CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 10 AS s25, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s26, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s27, + NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (1, 2, 3) AS s28, + NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (4, 5, 6) AS s29, + NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (7, 8, 9) AS s30, + NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (10, 11, 12) AS s31, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32 +FROM _t +WHERE + _w = 1 OR _w_2 = 1 From fdc92ae7e6e64b3de13f447b7cbb54d4b5dd0cd5 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 15 Aug 2025 12:32:52 -0400 Subject: [PATCH 03/16] Added 
datetime literal extraction simplification and tests --- .../conversion/relational_simplification.py | 62 +++++++++++++++++++ tests/test_pipeline_defog_custom.py | 58 +++++++++++++++++ .../quarter_function_test.txt | 2 +- tests/test_plan_refsols/simplification_4.txt | 2 +- tests/test_plan_refsols/smoke_b.txt | 2 +- .../datetime_functions_ansi.sql | 16 ++--- .../datetime_functions_sqlite.sql | 16 ++--- .../datetime_sampler_ansi.sql | 26 ++++---- .../datetime_sampler_sqlite.sql | 26 ++++---- .../simplification_4_ansi.sql | 30 ++++++++- .../simplification_4_sqlite.sql | 30 ++++++++- tests/test_sql_refsols/smoke_b_ansi.sql | 2 +- tests/test_sql_refsols/smoke_b_sqlite.sql | 2 +- 13 files changed, 225 insertions(+), 49 deletions(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index e9fc5dde8..a3008714c 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -9,8 +9,11 @@ __all__ = ["simplify_expressions"] +import datetime from dataclasses import dataclass +import pandas as pd + import pydough.pydough_operators as pydop from pydough.relational import ( Aggregate, @@ -535,6 +538,49 @@ def simplify_function_literal_comparison( pass return result + def simplify_datetime_literal_part( + self, + expr: RelationalExpression, + op: pydop.PyDoughExpressionOperator, + lit_expr: LiteralExpression, + ) -> RelationalExpression: + """ + TODO + """ + ts: pd.Timestamp | None = None + if isinstance(lit_expr.value, (str, datetime.date)): + try: + ts = pd.Timestamp(lit_expr.value) + except Exception: + return expr + elif isinstance(lit_expr.value, pd.Timestamp): + ts = lit_expr.value + + # Fall back to the original expression by default. + if ts is None: + return expr + + # Otherwise, extract the relevant part from the timestamp and return it + # as a literal. 
+ match op: + case pydop.YEAR: + return LiteralExpression(ts.year, NumericType()) + case pydop.QUARTER: + quarter: int = ((ts.month - 1) // 3) + 1 + return LiteralExpression(quarter, NumericType()) + case pydop.MONTH: + return LiteralExpression(ts.month, NumericType()) + case pydop.DAY: + return LiteralExpression(ts.day, NumericType()) + case pydop.HOUR: + return LiteralExpression(ts.hour, NumericType()) + case pydop.MINUTE: + return LiteralExpression(ts.minute, NumericType()) + case pydop.SECOND: + return LiteralExpression(ts.second, NumericType()) + case _: + return expr + def simplify_function_call( self, expr: CallExpression, @@ -1029,6 +1075,22 @@ def simplify_function_call( expr.inputs[0].inputs + expr.inputs[1:], ) + # YEAR(literal_datetime) -> can infer the year as a literal + # (same for QUARTER, MONTH, DAY, HOUR, MINUTE, SECOND) + case ( + pydop.YEAR + | pydop.QUARTER + | pydop.MONTH + | pydop.DAY + | pydop.HOUR + | pydop.MINUTE + | pydop.SECOND + ): + if isinstance(expr.inputs[0], LiteralExpression): + output_expr = self.simplify_datetime_literal_part( + expr, expr.op, expr.inputs[0] + ) + case _: # All other operators remain non-simplified. pass diff --git a/tests/test_pipeline_defog_custom.py b/tests/test_pipeline_defog_custom.py index 109bf3151..1118bf0ab 100644 --- a/tests/test_pipeline_defog_custom.py +++ b/tests/test_pipeline_defog_custom.py @@ -3,6 +3,7 @@ schemas. 
""" +import datetime import re from collections.abc import Callable @@ -1920,6 +1921,34 @@ def get_day_of_week( " s30 = 3 != QUARTER(date_time)," # NOT(ISIN(MONTH(date_time), [7,8,9])) " s31 = QUARTER(date_time) != 4," # NOT(ISIN(MONTH(date_time), [10,11,12])) " s32 = 5 != QUARTER(date_time)," # KEEP_IF(True, PRESENT(date_time)) + " s33 = YEAR('2024-08-13 12:45:59')," # 2024 + " s34 = QUARTER('2024-08-13 12:45:59')," # 3 + " s35 = MONTH('2024-08-13 12:45:59')," # 8 + " s36 = DAY('2024-08-13 12:45:59')," # 13 + " s37 = HOUR('2024-08-13 12:45:59')," # 12 + " s38 = MINUTE('2024-08-13 12:45:59')," # 45 + " s39 = SECOND('2024-08-13 12:45:59')," # 59 + " s40 = YEAR(datetime.date(2020, 1, 31))," # 2024 + " s41 = QUARTER(datetime.date(2020, 1, 31))," # 1 + " s42 = MONTH(datetime.date(2020, 1, 31))," # 1 + " s43 = DAY(datetime.date(2020, 1, 31))," # 31 + " s44 = HOUR(datetime.date(2020, 1, 31))," # 0 + " s45 = MINUTE(datetime.date(2020, 1, 31))," # 0 + " s46 = SECOND(datetime.date(2020, 1, 31))," # 0 + " s47 = YEAR(datetime.datetime(2023, 7, 4, 6, 55, 0))," # 2023 + " s48 = QUARTER(datetime.datetime(2023, 7, 4, 6, 55, 0))," # 3 + " s49 = MONTH(datetime.datetime(2023, 7, 4, 6, 55, 0))," # 7 + " s50 = DAY(datetime.datetime(2023, 7, 4, 6, 55, 0))," # 4 + " s51 = HOUR(datetime.datetime(2023, 7, 4, 6, 55, 0))," # 6 + " s52 = MINUTE(datetime.datetime(2023, 7, 4, 6, 55, 0))," # 55 + " s53 = SECOND(datetime.datetime(2023, 7, 4, 6, 55, 0))," # 0 + " s54 = YEAR(pd.Timestamp('1999-12-31 23:59:58'))," # 1999 + " s55 = QUARTER(pd.Timestamp('1999-12-31 23:59:58'))," # 4 + " s56 = MONTH(pd.Timestamp('1999-12-31 23:59:58'))," # 12 + " s57 = DAY(pd.Timestamp('1999-12-31 23:59:58'))," # 31 + " s58 = HOUR(pd.Timestamp('1999-12-31 23:59:58'))," # 23 + " s59 = MINUTE(pd.Timestamp('1999-12-31 23:59:58'))," # 59 + " s60 = SECOND(pd.Timestamp('1999-12-31 23:59:58'))," # 58 "))", "Broker", lambda: pd.DataFrame( @@ -1958,9 +1987,38 @@ def get_day_of_week( "s30": [1, 1], "s31": [1, 1], "s32": [1, 1], 
+ "s33": [2024, 2024], + "s34": [3, 3], + "s35": [8, 8], + "s36": [13, 13], + "s37": [12, 12], + "s38": [45, 45], + "s39": [59, 59], + "s40": [2020, 2020], + "s41": [1, 1], + "s42": [1, 1], + "s43": [31, 31], + "s44": [0, 0], + "s45": [0, 0], + "s46": [0, 0], + "s47": [2023, 2023], + "s48": [3, 3], + "s49": [7, 7], + "s50": [4, 4], + "s51": [6, 6], + "s52": [55, 55], + "s53": [0, 0], + "s54": [1999, 1999], + "s55": [4, 4], + "s56": [12, 12], + "s57": [31, 31], + "s58": [23, 23], + "s59": [59, 59], + "s60": [58, 58], } ), "simplification_4", + kwargs={"pd": pd, "datetime": datetime}, ), id="simplification_4", ), diff --git a/tests/test_plan_refsols/quarter_function_test.txt b/tests/test_plan_refsols/quarter_function_test.txt index d323a0a5b..fb423657e 100644 --- a/tests/test_plan_refsols/quarter_function_test.txt +++ b/tests/test_plan_refsols/quarter_function_test.txt @@ -1,2 +1,2 @@ -ROOT(columns=[('_expr0', QUARTER('2023-01-15':string)), ('_expr1', QUARTER('2023-02-28':string)), ('_expr2', QUARTER('2023-03-31':string)), ('_expr3', QUARTER('2023-04-01':string)), ('_expr4', QUARTER('2023-05-15':string)), ('_expr5', QUARTER('2023-06-30':string)), ('_expr6', QUARTER('2023-07-01':string)), ('_expr7', QUARTER('2023-08-15':string)), ('_expr8', QUARTER('2023-09-30':string)), ('_expr9', QUARTER('2023-10-01':string)), ('_expr10', QUARTER('2023-11-15':string)), ('_expr11', QUARTER('2023-12-31':string)), ('_expr12', QUARTER(Timestamp('2024-02-29 12:30:45'):datetime)), ('q1_jan', DATETIME('2023-01-15 12:30:45':string, 'start of quarter':string)), ('q1_feb', DATETIME('2023-02-28 12:30:45':string, 'start of quarter':string)), ('q1_mar', DATETIME('2023-03-31':string, 'start of quarter':string)), ('q2_apr', DATETIME('2023-04-01':string, 'start of quarter':string)), ('q2_may', DATETIME('2023-05-15 12:30:45':string, 'start of quarter':string)), ('q2_jun', DATETIME('2023-06-30 12:30:45':string, 'start of quarter':string)), ('q3_jul', DATETIME('2023-07-01 12:30:45':string, 'start of 
quarter':string)), ('q3_aug', DATETIME('2023-08-15':string, 'start of quarter':string)), ('q3_sep', DATETIME('2023-09-30':string, 'start of quarter':string)), ('q4_oct', DATETIME('2023-10-01':string, 'start of quarter':string)), ('q4_nov', DATETIME('2023-11-15':string, 'start of quarter':string)), ('q4_dec', DATETIME('2023-12-31':string, 'start of quarter':string)), ('ts_q1', DATETIME(Timestamp('2024-02-29 12:30:45'):datetime, 'start of quarter':string)), ('alias1', DATETIME('2023-05-15':string, 'START OF QUARTER':string)), ('alias2', DATETIME('2023-08-15':string, 'Start Of Quarter':string)), ('alias3', DATETIME('2023-11-15':string, '\n Start Of\tQuarter\n\n':string)), ('alias4', DATETIME('2023-02-15':string, '\tSTART\tOF\tquarter\t':string)), ('chain1', DATETIME('2023-05-15':string, 'start of quarter':string, '+1 day':string, '+2 hours':string)), ('chain2', DATETIME('2023-08-15':string, 'start of quarter':string, 'start of day':string)), ('chain3', DATETIME('2023-11-15':string, '-1 month':string, 'start of quarter':string)), ('plus_1q', DATETIME('2023-01-15 12:30:45':string, '+1 quarter':string)), ('plus_2q', DATETIME('2023-01-15 12:30:45':string, '+2 quarters':string)), ('plus_3q', DATETIME('2023-01-15':string, '+3 quarters':string)), ('minus_1q', DATETIME('2023-01-15 12:30:45':string, '-1 quarter':string)), ('minus_2q', DATETIME('2023-01-15 12:30:45':string, '-2 quarters':string)), ('minus_3q', DATETIME('2023-01-15':string, '-3 quarters':string)), ('syntax1', DATETIME('2023-05-15':string, ' +1 QUARTER ':string)), ('syntax2', DATETIME('2023-08-15':string, '+2 Q':string)), ('syntax3', DATETIME('2023-11-15':string, ' \n +\t3 \nQuarters \n\r ':string)), ('syntax4', DATETIME('2023-02-15':string, '\t-\t2\tq\t':string)), ('q_diff1', DATEDIFF('quarter':string, '2023-01-15':string, '2023-04-15':string)), ('q_diff2', DATEDIFF('quarter':string, '2023-01-15':string, '2023-07-15':string)), ('q_diff3', DATEDIFF('quarter':string, '2023-01-15':string, '2023-10-15':string)), 
('q_diff4', DATEDIFF('quarter':string, '2023-01-15':string, '2023-12-31':string)), ('q_diff5', DATEDIFF('quarter':string, '2023-01-15':string, '2024-01-15':string)), ('q_diff6', DATEDIFF('quarter':string, '2023-01-15':string, '2024-04-15':string)), ('q_diff7', DATEDIFF('quarter':string, '2022-10-15':string, '2024-04-15':string)), ('q_diff8', DATEDIFF('quarter':string, '2020-01-01':string, '2025-01-01':string)), ('q_diff9', DATEDIFF('quarter':string, '2023-04-15':string, '2023-01-15':string)), ('q_diff10', DATEDIFF('quarter':string, '2024-01-15':string, '2023-01-15':string)), ('q_diff11', DATEDIFF('quarter':string, '2023-03-31':string, '2023-04-01':string)), ('q_diff12', DATEDIFF('quarter':string, '2023-12-31':string, '2024-01-01':string))], orderings=[]) +ROOT(columns=[('_expr0', 1:numeric), ('_expr1', 1:numeric), ('_expr2', 1:numeric), ('_expr3', 2:numeric), ('_expr4', 2:numeric), ('_expr5', 2:numeric), ('_expr6', 3:numeric), ('_expr7', 3:numeric), ('_expr8', 3:numeric), ('_expr9', 4:numeric), ('_expr10', 4:numeric), ('_expr11', 4:numeric), ('_expr12', 1:numeric), ('q1_jan', DATETIME('2023-01-15 12:30:45':string, 'start of quarter':string)), ('q1_feb', DATETIME('2023-02-28 12:30:45':string, 'start of quarter':string)), ('q1_mar', DATETIME('2023-03-31':string, 'start of quarter':string)), ('q2_apr', DATETIME('2023-04-01':string, 'start of quarter':string)), ('q2_may', DATETIME('2023-05-15 12:30:45':string, 'start of quarter':string)), ('q2_jun', DATETIME('2023-06-30 12:30:45':string, 'start of quarter':string)), ('q3_jul', DATETIME('2023-07-01 12:30:45':string, 'start of quarter':string)), ('q3_aug', DATETIME('2023-08-15':string, 'start of quarter':string)), ('q3_sep', DATETIME('2023-09-30':string, 'start of quarter':string)), ('q4_oct', DATETIME('2023-10-01':string, 'start of quarter':string)), ('q4_nov', DATETIME('2023-11-15':string, 'start of quarter':string)), ('q4_dec', DATETIME('2023-12-31':string, 'start of quarter':string)), ('ts_q1', 
DATETIME(Timestamp('2024-02-29 12:30:45'):datetime, 'start of quarter':string)), ('alias1', DATETIME('2023-05-15':string, 'START OF QUARTER':string)), ('alias2', DATETIME('2023-08-15':string, 'Start Of Quarter':string)), ('alias3', DATETIME('2023-11-15':string, '\n Start Of\tQuarter\n\n':string)), ('alias4', DATETIME('2023-02-15':string, '\tSTART\tOF\tquarter\t':string)), ('chain1', DATETIME('2023-05-15':string, 'start of quarter':string, '+1 day':string, '+2 hours':string)), ('chain2', DATETIME('2023-08-15':string, 'start of quarter':string, 'start of day':string)), ('chain3', DATETIME('2023-11-15':string, '-1 month':string, 'start of quarter':string)), ('plus_1q', DATETIME('2023-01-15 12:30:45':string, '+1 quarter':string)), ('plus_2q', DATETIME('2023-01-15 12:30:45':string, '+2 quarters':string)), ('plus_3q', DATETIME('2023-01-15':string, '+3 quarters':string)), ('minus_1q', DATETIME('2023-01-15 12:30:45':string, '-1 quarter':string)), ('minus_2q', DATETIME('2023-01-15 12:30:45':string, '-2 quarters':string)), ('minus_3q', DATETIME('2023-01-15':string, '-3 quarters':string)), ('syntax1', DATETIME('2023-05-15':string, ' +1 QUARTER ':string)), ('syntax2', DATETIME('2023-08-15':string, '+2 Q':string)), ('syntax3', DATETIME('2023-11-15':string, ' \n +\t3 \nQuarters \n\r ':string)), ('syntax4', DATETIME('2023-02-15':string, '\t-\t2\tq\t':string)), ('q_diff1', DATEDIFF('quarter':string, '2023-01-15':string, '2023-04-15':string)), ('q_diff2', DATEDIFF('quarter':string, '2023-01-15':string, '2023-07-15':string)), ('q_diff3', DATEDIFF('quarter':string, '2023-01-15':string, '2023-10-15':string)), ('q_diff4', DATEDIFF('quarter':string, '2023-01-15':string, '2023-12-31':string)), ('q_diff5', DATEDIFF('quarter':string, '2023-01-15':string, '2024-01-15':string)), ('q_diff6', DATEDIFF('quarter':string, '2023-01-15':string, '2024-04-15':string)), ('q_diff7', DATEDIFF('quarter':string, '2022-10-15':string, '2024-04-15':string)), ('q_diff8', DATEDIFF('quarter':string, 
'2020-01-01':string, '2025-01-01':string)), ('q_diff9', DATEDIFF('quarter':string, '2023-04-15':string, '2023-01-15':string)), ('q_diff10', DATEDIFF('quarter':string, '2024-01-15':string, '2023-01-15':string)), ('q_diff11', DATEDIFF('quarter':string, '2023-03-31':string, '2023-04-01':string)), ('q_diff12', DATEDIFF('quarter':string, '2023-12-31':string, '2024-01-01':string))], orderings=[]) EMPTYSINGLETON() diff --git a/tests/test_plan_refsols/simplification_4.txt b/tests/test_plan_refsols/simplification_4.txt index 903c80e55..9481539ab 100644 --- a/tests/test_plan_refsols/simplification_4.txt +++ b/tests/test_plan_refsols/simplification_4.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('date_time', sbTxDateTime), ('s00', DATETIME(sbTxDateTime, 'start of week':string, '-8 weeks':string)), ('s01', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s02', ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric])), ('s03', ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric])), ('s04', ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric])), ('s05', ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric])), ('s06', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s07', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s08', MONTH(sbTxDateTime) < 4:numeric), ('s09', MONTH(sbTxDateTime) < 7:numeric), ('s10', MONTH(sbTxDateTime) < 10:numeric), ('s11', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s12', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s13', MONTH(sbTxDateTime) <= 3:numeric), ('s14', MONTH(sbTxDateTime) <= 6:numeric), ('s15', MONTH(sbTxDateTime) <= 9:numeric), ('s16', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s17', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s18', MONTH(sbTxDateTime) > 3:numeric), ('s19', MONTH(sbTxDateTime) > 6:numeric), ('s20', MONTH(sbTxDateTime) > 9:numeric), ('s21', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s22', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s23', MONTH(sbTxDateTime) >= 4:numeric), ('s24', MONTH(sbTxDateTime) >= 7:numeric), ('s25', 
MONTH(sbTxDateTime) >= 10:numeric), ('s26', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s27', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s28', NOT(ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric]))), ('s29', NOT(ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric]))), ('s30', NOT(ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric]))), ('s31', NOT(ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric]))), ('s32', KEEP_IF(True:bool, PRESENT(sbTxDateTime)))], orderings=[]) +ROOT(columns=[('date_time', sbTxDateTime), ('s00', DATETIME(sbTxDateTime, 'start of week':string, '-8 weeks':string)), ('s01', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s02', ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric])), ('s03', ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric])), ('s04', ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric])), ('s05', ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric])), ('s06', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s07', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s08', MONTH(sbTxDateTime) < 4:numeric), ('s09', MONTH(sbTxDateTime) < 7:numeric), ('s10', MONTH(sbTxDateTime) < 10:numeric), ('s11', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s12', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s13', MONTH(sbTxDateTime) <= 3:numeric), ('s14', MONTH(sbTxDateTime) <= 6:numeric), ('s15', MONTH(sbTxDateTime) <= 9:numeric), ('s16', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s17', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s18', MONTH(sbTxDateTime) > 3:numeric), ('s19', MONTH(sbTxDateTime) > 6:numeric), ('s20', MONTH(sbTxDateTime) > 9:numeric), ('s21', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s22', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s23', MONTH(sbTxDateTime) >= 4:numeric), ('s24', MONTH(sbTxDateTime) >= 7:numeric), ('s25', MONTH(sbTxDateTime) >= 10:numeric), ('s26', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s27', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s28', NOT(ISIN(MONTH(sbTxDateTime), [1, 2, 
3]:array[numeric]))), ('s29', NOT(ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric]))), ('s30', NOT(ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric]))), ('s31', NOT(ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric]))), ('s32', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s33', 2024:numeric), ('s34', 3:numeric), ('s35', 8:numeric), ('s36', 13:numeric), ('s37', 12:numeric), ('s38', 45:numeric), ('s39', 59:numeric), ('s40', 2020:numeric), ('s41', 1:numeric), ('s42', 1:numeric), ('s43', 31:numeric), ('s44', 0:numeric), ('s45', 0:numeric), ('s46', 0:numeric), ('s47', 2023:numeric), ('s48', 3:numeric), ('s49', 7:numeric), ('s50', 4:numeric), ('s51', 6:numeric), ('s52', 55:numeric), ('s53', 0:numeric), ('s54', 1999:numeric), ('s55', 4:numeric), ('s56', 12:numeric), ('s57', 31:numeric), ('s58', 23:numeric), ('s59', 59:numeric), ('s60', 58:numeric)], orderings=[]) FILTER(condition=RANKING(args=[], partition=[], order=[(sbTxDateTime):asc_last]) == 1:numeric | RANKING(args=[], partition=[], order=[(sbTxDateTime):desc_first]) == 1:numeric, columns={'sbTxDateTime': sbTxDateTime}) FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime}) SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime}) diff --git a/tests/test_plan_refsols/smoke_b.txt b/tests/test_plan_refsols/smoke_b.txt index b2c62233b..10bfda8a7 100644 --- a/tests/test_plan_refsols/smoke_b.txt +++ b/tests/test_plan_refsols/smoke_b.txt @@ -1,3 +1,3 @@ -ROOT(columns=[('key', o_orderkey), ('a', JOIN_STRINGS('_':string, YEAR(o_orderdate), QUARTER(o_orderdate), MONTH(o_orderdate), DAY(o_orderdate))), ('b', JOIN_STRINGS(':':string, DAYNAME(o_orderdate), DAYOFWEEK(o_orderdate))), ('c', DATETIME(o_orderdate, 'start of year':string, '+6 months':string, '-13 days':string)), ('d', DATETIME(o_orderdate, 'start of quarter':string, '+1 year':string, '+25 hours':string)), ('e', DATETIME('2025-01-01 12:35:13':string, 'start of minute':string)), ('f', DATETIME('2025-01-01 
12:35:13':string, 'start of hour':string, '+2 quarters':string, '+3 weeks':string)), ('g', DATETIME('2025-01-01 12:35:13':string, 'start of day':string)), ('h', JOIN_STRINGS(';':string, HOUR('2025-01-01 12:35:13':string), MINUTE(DATETIME('2025-01-01 12:35:13':string, '+45 minutes':string)), SECOND(DATETIME('2025-01-01 12:35:13':string, '-7 seconds':string)))), ('i', DATEDIFF('years':string, '1993-05-25 12:45:36':string, o_orderdate)), ('j', DATEDIFF('quarters':string, '1993-05-25 12:45:36':string, o_orderdate)), ('k', DATEDIFF('months':string, '1993-05-25 12:45:36':string, o_orderdate)), ('l', DATEDIFF('weeks':string, '1993-05-25 12:45:36':string, o_orderdate)), ('m', DATEDIFF('days':string, '1993-05-25 12:45:36':string, o_orderdate)), ('n', DATEDIFF('hours':string, '1993-05-25 12:45:36':string, o_orderdate)), ('o', DATEDIFF('minutes':string, '1993-05-25 12:45:36':string, o_orderdate)), ('p', DATEDIFF('seconds':string, '1993-05-25 12:45:36':string, o_orderdate)), ('q', DATETIME(o_orderdate, 'start of week':string))], orderings=[(o_orderkey):asc_first], limit=5:numeric) +ROOT(columns=[('key', o_orderkey), ('a', JOIN_STRINGS('_':string, YEAR(o_orderdate), QUARTER(o_orderdate), MONTH(o_orderdate), DAY(o_orderdate))), ('b', JOIN_STRINGS(':':string, DAYNAME(o_orderdate), DAYOFWEEK(o_orderdate))), ('c', DATETIME(o_orderdate, 'start of year':string, '+6 months':string, '-13 days':string)), ('d', DATETIME(o_orderdate, 'start of quarter':string, '+1 year':string, '+25 hours':string)), ('e', DATETIME('2025-01-01 12:35:13':string, 'start of minute':string)), ('f', DATETIME('2025-01-01 12:35:13':string, 'start of hour':string, '+2 quarters':string, '+3 weeks':string)), ('g', DATETIME('2025-01-01 12:35:13':string, 'start of day':string)), ('h', JOIN_STRINGS(';':string, 12:numeric, MINUTE(DATETIME('2025-01-01 12:35:13':string, '+45 minutes':string)), SECOND(DATETIME('2025-01-01 12:35:13':string, '-7 seconds':string)))), ('i', DATEDIFF('years':string, '1993-05-25 
12:45:36':string, o_orderdate)), ('j', DATEDIFF('quarters':string, '1993-05-25 12:45:36':string, o_orderdate)), ('k', DATEDIFF('months':string, '1993-05-25 12:45:36':string, o_orderdate)), ('l', DATEDIFF('weeks':string, '1993-05-25 12:45:36':string, o_orderdate)), ('m', DATEDIFF('days':string, '1993-05-25 12:45:36':string, o_orderdate)), ('n', DATEDIFF('hours':string, '1993-05-25 12:45:36':string, o_orderdate)), ('o', DATEDIFF('minutes':string, '1993-05-25 12:45:36':string, o_orderdate)), ('p', DATEDIFF('seconds':string, '1993-05-25 12:45:36':string, o_orderdate)), ('q', DATETIME(o_orderdate, 'start of week':string))], orderings=[(o_orderkey):asc_first], limit=5:numeric) FILTER(condition=CONTAINS(o_comment, 'fo':string) & ENDSWITH(o_clerk, '5':string) & STARTSWITH(o_orderpriority, '3':string), columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) SCAN(table=tpch.ORDERS, columns={'o_clerk': o_clerk, 'o_comment': o_comment, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) diff --git a/tests/test_sql_refsols/datetime_functions_ansi.sql b/tests/test_sql_refsols/datetime_functions_ansi.sql index 021f5b708..906e8ef1f 100644 --- a/tests/test_sql_refsols/datetime_functions_ansi.sql +++ b/tests/test_sql_refsols/datetime_functions_ansi.sql @@ -6,16 +6,16 @@ SELECT CAST('2025-01-01 00:00:00' AS TIMESTAMP) AS ts_now_5, CAST('1995-10-08 00:00:00' AS TIMESTAMP) AS ts_now_6, EXTRACT(YEAR FROM CAST(o_orderdate AS DATETIME)) AS year_col, - EXTRACT(YEAR FROM CAST('2020-05-01 00:00:00' AS TIMESTAMP)) AS year_py, - EXTRACT(YEAR FROM CAST('1995-10-10 00:00:00' AS TIMESTAMP)) AS year_pd, + 2020 AS year_py, + 1995 AS year_pd, EXTRACT(MONTH FROM CAST(o_orderdate AS DATETIME)) AS month_col, - EXTRACT(MONTH FROM CAST('2025-02-25' AS TIMESTAMP)) AS month_str, - EXTRACT(MONTH FROM CAST('1992-01-01 12:30:45' AS TIMESTAMP)) AS month_dt, + 2 AS month_str, + 1 AS month_dt, EXTRACT(DAY FROM CAST(o_orderdate AS DATETIME)) AS day_col, - EXTRACT(DAY 
FROM CAST('1996-11-25 10:45:00' AS TIMESTAMP)) AS day_str, - EXTRACT(HOUR FROM CAST('1995-12-01 23:59:59' AS TIMESTAMP)) AS hour_str, - EXTRACT(MINUTE FROM CAST('1995-12-01 23:59:59' AS TIMESTAMP)) AS minute_str, - EXTRACT(SECOND FROM CAST('1992-01-01 00:00:59' AS TIMESTAMP)) AS second_ts, + 25 AS day_str, + 23 AS hour_str, + 59 AS minute_str, + 59 AS second_ts, DATEDIFF(CAST('1992-01-01' AS TIMESTAMP), CAST(o_orderdate AS DATETIME), DAY) AS dd_col_str, DATEDIFF(CAST(o_orderdate AS DATETIME), CAST('1992-01-01' AS TIMESTAMP), DAY) AS dd_str_col, DATEDIFF(CAST(o_orderdate AS DATETIME), CAST('1995-10-10 00:00:00' AS TIMESTAMP), MONTH) AS dd_pd_col, diff --git a/tests/test_sql_refsols/datetime_functions_sqlite.sql b/tests/test_sql_refsols/datetime_functions_sqlite.sql index 265789e56..bf725c4f6 100644 --- a/tests/test_sql_refsols/datetime_functions_sqlite.sql +++ b/tests/test_sql_refsols/datetime_functions_sqlite.sql @@ -6,16 +6,16 @@ SELECT DATE('2025-01-01 00:00:00', 'start of month') AS ts_now_5, DATETIME('1995-10-10 00:00:00', '-2 day') AS ts_now_6, CAST(STRFTIME('%Y', o_orderdate) AS INTEGER) AS year_col, - CAST(STRFTIME('%Y', '2020-05-01 00:00:00') AS INTEGER) AS year_py, - CAST(STRFTIME('%Y', '1995-10-10 00:00:00') AS INTEGER) AS year_pd, + 2020 AS year_py, + 1995 AS year_pd, CAST(STRFTIME('%m', o_orderdate) AS INTEGER) AS month_col, - CAST(STRFTIME('%m', DATETIME('2025-02-25')) AS INTEGER) AS month_str, - CAST(STRFTIME('%m', '1992-01-01 12:30:45') AS INTEGER) AS month_dt, + 2 AS month_str, + 1 AS month_dt, CAST(STRFTIME('%d', o_orderdate) AS INTEGER) AS day_col, - CAST(STRFTIME('%d', DATETIME('1996-11-25 10:45:00')) AS INTEGER) AS day_str, - CAST(STRFTIME('%H', DATETIME('1995-12-01 23:59:59')) AS INTEGER) AS hour_str, - CAST(STRFTIME('%M', DATETIME('1995-12-01 23:59:59')) AS INTEGER) AS minute_str, - CAST(STRFTIME('%S', '1992-01-01 00:00:59') AS INTEGER) AS second_ts, + 25 AS day_str, + 23 AS hour_str, + 59 AS minute_str, + 59 AS second_ts, CAST(( 
JULIANDAY(DATE(DATETIME('1992-01-01'), 'start of day')) - JULIANDAY(DATE(o_orderdate, 'start of day')) ) AS INTEGER) AS dd_col_str, diff --git a/tests/test_sql_refsols/datetime_sampler_ansi.sql b/tests/test_sql_refsols/datetime_sampler_ansi.sql index 23a502279..c50e22dbd 100644 --- a/tests/test_sql_refsols/datetime_sampler_ansi.sql +++ b/tests/test_sql_refsols/datetime_sampler_ansi.sql @@ -116,23 +116,23 @@ SELECT ) AS _expr57, DATE_ADD(DATE_ADD(CURRENT_TIMESTAMP(), 45, 'MONTH'), -135, 'SECOND') AS _expr58, EXTRACT(YEAR FROM CURRENT_TIMESTAMP()) AS _expr59, - EXTRACT(YEAR FROM CAST('2025-07-04 12:58:45' AS TIMESTAMP)) AS _expr60, - EXTRACT(YEAR FROM CAST('1999-03-14' AS TIMESTAMP)) AS _expr61, + 2025 AS _expr60, + 1999 AS _expr61, EXTRACT(MONTH FROM CURRENT_TIMESTAMP()) AS _expr62, - EXTRACT(MONTH FROM CAST('2001-06-30' AS DATE)) AS _expr63, - EXTRACT(MONTH FROM CAST('1999-03-14' AS TIMESTAMP)) AS _expr64, + 6 AS _expr63, + 3 AS _expr64, EXTRACT(DAY FROM CURRENT_TIMESTAMP()) AS _expr65, - EXTRACT(DAY FROM CAST('2025-07-04 12:58:45' AS TIMESTAMP)) AS _expr66, - EXTRACT(DAY FROM CAST('2025-07-04 12:58:45' AS TIMESTAMP)) AS _expr67, + 4 AS _expr66, + 4 AS _expr67, EXTRACT(HOUR FROM CURRENT_TIMESTAMP()) AS _expr68, - EXTRACT(HOUR FROM CAST('2001-06-30' AS DATE)) AS _expr69, - EXTRACT(HOUR FROM CAST('2024-01-01' AS TIMESTAMP)) AS _expr70, + 0 AS _expr69, + 0 AS _expr70, EXTRACT(MINUTE FROM CURRENT_TIMESTAMP()) AS _expr71, - EXTRACT(MINUTE FROM CAST('2024-12-25 20:30:59' AS TIMESTAMP)) AS _expr72, - EXTRACT(MINUTE FROM CAST('2024-01-01' AS TIMESTAMP)) AS _expr73, - EXTRACT(SECOND FROM CURRENT_TIMESTAMP()) AS _expr74, - EXTRACT(SECOND FROM CAST('2025-07-04 12:58:45' AS TIMESTAMP)) AS _expr75, - EXTRACT(SECOND FROM CAST('1999-03-14' AS TIMESTAMP)) AS _expr76, + 30 AS _expr72, + 0 AS _expr73, + 23 AS _expr74, + 45 AS _expr75, + 0 AS _expr76, DATEDIFF(CURRENT_TIMESTAMP(), CAST('2018-02-14 12:41:06' AS TIMESTAMP), YEAR) AS _expr77, DATEDIFF(CAST('2022-11-24' AS DATE), 
CAST(o_orderdate AS DATETIME), YEAR) AS _expr78, DATEDIFF(CAST('1999-03-14' AS TIMESTAMP), CAST('2005-06-30' AS DATE), MONTH) AS _expr79, diff --git a/tests/test_sql_refsols/datetime_sampler_sqlite.sql b/tests/test_sql_refsols/datetime_sampler_sqlite.sql index c83935412..2dd7a678c 100644 --- a/tests/test_sql_refsols/datetime_sampler_sqlite.sql +++ b/tests/test_sql_refsols/datetime_sampler_sqlite.sql @@ -86,23 +86,23 @@ SELECT DATETIME('now', '136 hour', '104 minute', '-104 month', '312 day') AS _expr57, DATETIME('now', '45 month', '-135 second') AS _expr58, CAST(STRFTIME('%Y', DATETIME('now')) AS INTEGER) AS _expr59, - CAST(STRFTIME('%Y', '2025-07-04 12:58:45') AS INTEGER) AS _expr60, - CAST(STRFTIME('%Y', DATETIME('1999-03-14')) AS INTEGER) AS _expr61, + 2025 AS _expr60, + 1999 AS _expr61, CAST(STRFTIME('%m', DATETIME('now')) AS INTEGER) AS _expr62, - CAST(STRFTIME('%m', '2001-06-30') AS INTEGER) AS _expr63, - CAST(STRFTIME('%m', DATETIME('1999-03-14')) AS INTEGER) AS _expr64, + 6 AS _expr63, + 3 AS _expr64, CAST(STRFTIME('%d', DATETIME('now')) AS INTEGER) AS _expr65, - CAST(STRFTIME('%d', '2025-07-04 12:58:45') AS INTEGER) AS _expr66, - CAST(STRFTIME('%d', DATETIME('2025-07-04 12:58:45')) AS INTEGER) AS _expr67, + 4 AS _expr66, + 4 AS _expr67, CAST(STRFTIME('%H', DATETIME('now')) AS INTEGER) AS _expr68, - CAST(STRFTIME('%H', '2001-06-30') AS INTEGER) AS _expr69, - CAST(STRFTIME('%H', DATETIME('2024-01-01')) AS INTEGER) AS _expr70, + 0 AS _expr69, + 0 AS _expr70, CAST(STRFTIME('%M', DATETIME('now')) AS INTEGER) AS _expr71, - CAST(STRFTIME('%M', '2024-12-25 20:30:59') AS INTEGER) AS _expr72, - CAST(STRFTIME('%M', DATETIME('2024-01-01')) AS INTEGER) AS _expr73, - CAST(STRFTIME('%S', DATETIME('now')) AS INTEGER) AS _expr74, - CAST(STRFTIME('%S', '2025-07-04 12:58:45') AS INTEGER) AS _expr75, - CAST(STRFTIME('%S', DATETIME('1999-03-14')) AS INTEGER) AS _expr76, + 30 AS _expr72, + 0 AS _expr73, + 24 AS _expr74, + 45 AS _expr75, + 0 AS _expr76, CAST(STRFTIME('%Y', 
DATETIME('now')) AS INTEGER) - CAST(STRFTIME('%Y', DATETIME('2018-02-14 12:41:06')) AS INTEGER) AS _expr77, CAST(STRFTIME('%Y', '2022-11-24') AS INTEGER) - CAST(STRFTIME('%Y', o_orderdate) AS INTEGER) AS _expr78, ( diff --git a/tests/test_sql_refsols/simplification_4_ansi.sql b/tests/test_sql_refsols/simplification_4_ansi.sql index f748aa2af..31c4c3f7c 100644 --- a/tests/test_sql_refsols/simplification_4_ansi.sql +++ b/tests/test_sql_refsols/simplification_4_ansi.sql @@ -42,5 +42,33 @@ SELECT NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (4, 5, 6) AS s29, NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (7, 8, 9) AS s30, NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (10, 11, 12) AS s31, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32 + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32, + 2024 AS s33, + 3 AS s34, + 8 AS s35, + 13 AS s36, + 12 AS s37, + 45 AS s38, + 59 AS s39, + 2020 AS s40, + 1 AS s41, + 1 AS s42, + 31 AS s43, + 0 AS s44, + 0 AS s45, + 0 AS s46, + 2023 AS s47, + 3 AS s48, + 7 AS s49, + 4 AS s50, + 6 AS s51, + 55 AS s52, + 0 AS s53, + 1999 AS s54, + 4 AS s55, + 12 AS s56, + 31 AS s57, + 23 AS s58, + 59 AS s59, + 58 AS s60 FROM _t1 diff --git a/tests/test_sql_refsols/simplification_4_sqlite.sql b/tests/test_sql_refsols/simplification_4_sqlite.sql index bca4e8f95..345cff8cf 100644 --- a/tests/test_sql_refsols/simplification_4_sqlite.sql +++ b/tests/test_sql_refsols/simplification_4_sqlite.sql @@ -48,7 +48,35 @@ SELECT NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (4, 5, 6) AS s29, NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (7, 8, 9) AS s30, NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (10, 11, 12) AS s31, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32 + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32, + 2024 AS s33, + 3 AS s34, + 8 AS s35, + 13 AS s36, + 12 AS s37, + 45 AS s38, + 59 AS s39, + 2020 AS s40, + 1 AS s41, + 1 
AS s42, + 31 AS s43, + 0 AS s44, + 0 AS s45, + 0 AS s46, + 2023 AS s47, + 3 AS s48, + 7 AS s49, + 4 AS s50, + 6 AS s51, + 55 AS s52, + 0 AS s53, + 1999 AS s54, + 4 AS s55, + 12 AS s56, + 31 AS s57, + 23 AS s58, + 59 AS s59, + 58 AS s60 FROM _t WHERE _w = 1 OR _w_2 = 1 diff --git a/tests/test_sql_refsols/smoke_b_ansi.sql b/tests/test_sql_refsols/smoke_b_ansi.sql index 9b9680ff6..82e4a6661 100644 --- a/tests/test_sql_refsols/smoke_b_ansi.sql +++ b/tests/test_sql_refsols/smoke_b_ansi.sql @@ -42,7 +42,7 @@ SELECT CAST('2025-01-01 12:35:13' AS TIMESTAMP) AS g, CONCAT_WS( ';', - EXTRACT(HOUR FROM CAST('2025-01-01 12:35:13' AS TIMESTAMP)), + 12, EXTRACT(MINUTE FROM CAST('2025-01-01 13:20:13' AS TIMESTAMP)), EXTRACT(SECOND FROM CAST('2025-01-01 12:35:06' AS TIMESTAMP)) ) AS h, diff --git a/tests/test_sql_refsols/smoke_b_sqlite.sql b/tests/test_sql_refsols/smoke_b_sqlite.sql index 05c96b7ce..e9b20a91b 100644 --- a/tests/test_sql_refsols/smoke_b_sqlite.sql +++ b/tests/test_sql_refsols/smoke_b_sqlite.sql @@ -59,7 +59,7 @@ SELECT DATE('2025-01-01 12:35:13', 'start of day') AS g, CONCAT_WS( ';', - CAST(STRFTIME('%H', DATETIME('2025-01-01 12:35:13')) AS INTEGER), + 12, CAST(STRFTIME('%M', DATETIME('2025-01-01 12:35:13', '45 minute')) AS INTEGER), CAST(STRFTIME('%S', DATETIME('2025-01-01 12:35:13', '-7 second')) AS INTEGER) ) AS h, From c067a00f8c1e6ef50606dd5772d9482411fe5594 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 15 Aug 2025 12:52:53 -0400 Subject: [PATCH 04/16] Added month/day/hour/minute/second edge comparison tests --- .../conversion/relational_simplification.py | 70 +++++++++++++++++-- tests/test_pipeline_defog_custom.py | 48 +++++++++++++ tests/test_plan_refsols/simplification_4.txt | 2 +- .../datetime_sampler_ansi.sql | 2 +- .../datetime_sampler_sqlite.sql | 2 +- .../simplification_4_ansi.sql | 26 ++++++- .../simplification_4_sqlite.sql | 26 ++++++- 7 files changed, 165 insertions(+), 11 deletions(-) diff --git 
a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index a3008714c..6a324fb3b 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -285,7 +285,15 @@ def visit_window_expression( def quarter_month_array(self, quarter: int) -> RelationalExpression: """ - TODO + Returns a LiteralExpression containing an array of the months + corresponding to the given quarter. + + Args: + `quarter`: The quarter (1-4) to get the corresponding months for. + + Returns: + A LiteralExpression containing an array of the months in the + given quarter. """ assert 1 <= quarter <= 4 month_arr: list[int] = [3 * (quarter - 1) + i + 1 for i in range(3)] @@ -295,7 +303,16 @@ def switch_operator( self, expr: CallExpression, op: pydop.PyDoughExpressionOperator ) -> RelationalExpression: """ - TODO + Returns a new CallExpression switching the operator of the given + CallExpression to the given operator, keeping the same inputs and data + type. + + Args: + `expr`: The CallExpression whose operator is to be switched. + `op`: The operator to switch to. + + Returns: + A new CallExpression with the given operator. """ return CallExpression(op, expr.data_type, expr.inputs) @@ -303,7 +320,15 @@ def keep_if_not_null( self, source: RelationalExpression, expr: RelationalExpression ) -> RelationalExpression: """ - TODO + Returns a CallExpression that keeps the given expression only if the + source expression is not null. + + Args: + `source`: The source expression to check for nullness. + `expr`: The expression to keep if the source is not null. + + Returns: + A CallExpression representing KEEP_IF(expr, PRESENT(source)). 
""" source_not_null: RelationalExpression = CallExpression( pydop.PRESENT, source.data_type, [source] @@ -318,7 +343,18 @@ def simplify_function_literal_comparison( lit_expr: LiteralExpression, ) -> RelationalExpression: """ - TODO + Simplifies a comparison between a function call expression and a + literal expression, e.g. `QUARTER(x) == 2` can be simplified to + `ISIN(MONTH(x), [4, 5, 6])`. + + Args: + `expr`: The original expression representing the comparison. This + should be returned if there is no simplification possible. + `op`: The comparison operator (e.g. EQU, NEQ, LET, etc). + `func_expr`: The left argument of the comparison, which is a + function call expression. + `lit_expr`: The right argument of the comparison, which is a + literal expression. """ assert op in (pydop.EQU, pydop.NEQ, pydop.GRT, pydop.GEQ, pydop.LET, pydop.LEQ) result: RelationalExpression = expr @@ -545,10 +581,32 @@ def simplify_datetime_literal_part( lit_expr: LiteralExpression, ) -> RelationalExpression: """ - TODO + Attempts to simplify a datetime part extraction function call with a + literal argument, e.g. `YEAR('2020-05-01')` can be simplified to `2020`. + + Args: + `expr`: The original expression representing the datetime part + extraction. This should be returned if there is no simplification + possible. + `op`: The datetime part extraction operator (e.g. YEAR, MONTH, DAY, + etc). + `lit_expr`: The literal expression argument to the datetime part + extraction function. + + Returns: + The simplified expression if possible, otherwise the original + expression. """ + # Extract a pandas Timestamp from the literal if possible. Allows cases + # where the literal is a native Python datetime/date, a pandas + # Timestamp, or a string without any alphabetic characters (to avoid + # parsing things like 'now' that depend on the current date). 
ts: pd.Timestamp | None = None - if isinstance(lit_expr.value, (str, datetime.date)): + if isinstance(lit_expr.value, datetime.date): + ts = pd.Timestamp(lit_expr.value) + elif isinstance(lit_expr.value, str) and not any( + c.isalpha() for c in lit_expr.value + ): try: ts = pd.Timestamp(lit_expr.value) except Exception: diff --git a/tests/test_pipeline_defog_custom.py b/tests/test_pipeline_defog_custom.py index 1118bf0ab..3e054829e 100644 --- a/tests/test_pipeline_defog_custom.py +++ b/tests/test_pipeline_defog_custom.py @@ -1949,6 +1949,30 @@ def get_day_of_week( " s58 = HOUR(pd.Timestamp('1999-12-31 23:59:58'))," # 23 " s59 = MINUTE(pd.Timestamp('1999-12-31 23:59:58'))," # 59 " s60 = SECOND(pd.Timestamp('1999-12-31 23:59:58'))," # 58 + " s61 = MONTH(date_time) == 0," # KEEP_IF(False, PRESENT(datetime)) + " s62 = MONTH(date_time) < 1," # KEEP_IF(False, PRESENT(datetime)) + " s63 = MONTH(date_time) <= 0," # KEEP_IF(False, PRESENT(datetime)) + " s64 = MONTH(date_time) != 0," # KEEP_IF(True, PRESENT(datetime)) + " s65 = MONTH(date_time) > 0," # KEEP_IF(True, PRESENT(datetime)) + " s66 = MONTH(date_time) >= 1," # KEEP_IF(True, PRESENT(datetime)) + " s67 = 0 == DAY(date_time)," # KEEP_IF(False, PRESENT(datetime)) + " s68 = 1 > DAY(date_time)," # KEEP_IF(False, PRESENT(datetime)) + " s69 = 0 >= DAY(date_time)," # KEEP_IF(False, PRESENT(datetime)) + " s70 = 0 != DAY(date_time)," # KEEP_IF(True, PRESENT(datetime)) + " s71 = 0 < DAY(date_time)," # KEEP_IF(True, PRESENT(datetime)) + " s72 = 0 <= DAY(date_time)," # KEEP_IF(True, PRESENT(datetime)) + " s73 = HOUR(date_time) == -1," # KEEP_IF(False, PRESENT(datetime)) + " s74 = 61 == MINUTE(date_time)," # KEEP_IF(False, PRESENT(datetime)) + " s75 = -2 != SECOND(date_time)," # KEEP_IF(True, PRESENT(datetime)) + " s76 = HOUR(date_time) != 62," # KEEP_IF(True, PRESENT(datetime)) + " s77 = MINUTE(date_time) < 0," # KEEP_IF(False, PRESENT(datetime)) + " s78 = SECOND(date_time) < 61," # KEEP_IF(True, PRESENT(datetime)) + " s79 = 
HOUR(date_time) <= -1," # KEEP_IF(False, PRESENT(datetime)) + " s80 = MINUTE(date_time) <= 60," # KEEP_IF(True, PRESENT(datetime)) + " s81 = SECOND(date_time) > -5," # KEEP_IF(True, PRESENT(datetime)) + " s82 = HOUR(date_time) > 60," # KEEP_IF(False, PRESENT(datetime)) + " s83 = MINUTE(date_time) >= 0," # KEEP_IF(True, PRESENT(datetime)) + " s84 = SECOND(date_time) >= 80," # KEEP_IF(False, PRESENT(datetime)) "))", "Broker", lambda: pd.DataFrame( @@ -2015,6 +2039,30 @@ def get_day_of_week( "s58": [23, 23], "s59": [59, 59], "s60": [58, 58], + "s61": [0, 0], + "s62": [0, 0], + "s63": [0, 0], + "s64": [1, 1], + "s65": [1, 1], + "s66": [1, 1], + "s67": [0, 0], + "s68": [0, 0], + "s69": [0, 0], + "s70": [1, 1], + "s71": [1, 1], + "s72": [1, 1], + "s73": [0, 0], + "s74": [0, 0], + "s75": [1, 1], + "s76": [1, 1], + "s77": [0, 0], + "s78": [1, 1], + "s79": [0, 0], + "s80": [1, 1], + "s81": [1, 1], + "s82": [0, 0], + "s83": [1, 1], + "s84": [0, 0], } ), "simplification_4", diff --git a/tests/test_plan_refsols/simplification_4.txt b/tests/test_plan_refsols/simplification_4.txt index 9481539ab..437ca5ae5 100644 --- a/tests/test_plan_refsols/simplification_4.txt +++ b/tests/test_plan_refsols/simplification_4.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('date_time', sbTxDateTime), ('s00', DATETIME(sbTxDateTime, 'start of week':string, '-8 weeks':string)), ('s01', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s02', ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric])), ('s03', ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric])), ('s04', ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric])), ('s05', ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric])), ('s06', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s07', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s08', MONTH(sbTxDateTime) < 4:numeric), ('s09', MONTH(sbTxDateTime) < 7:numeric), ('s10', MONTH(sbTxDateTime) < 10:numeric), ('s11', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s12', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), 
('s13', MONTH(sbTxDateTime) <= 3:numeric), ('s14', MONTH(sbTxDateTime) <= 6:numeric), ('s15', MONTH(sbTxDateTime) <= 9:numeric), ('s16', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s17', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s18', MONTH(sbTxDateTime) > 3:numeric), ('s19', MONTH(sbTxDateTime) > 6:numeric), ('s20', MONTH(sbTxDateTime) > 9:numeric), ('s21', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s22', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s23', MONTH(sbTxDateTime) >= 4:numeric), ('s24', MONTH(sbTxDateTime) >= 7:numeric), ('s25', MONTH(sbTxDateTime) >= 10:numeric), ('s26', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s27', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s28', NOT(ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric]))), ('s29', NOT(ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric]))), ('s30', NOT(ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric]))), ('s31', NOT(ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric]))), ('s32', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s33', 2024:numeric), ('s34', 3:numeric), ('s35', 8:numeric), ('s36', 13:numeric), ('s37', 12:numeric), ('s38', 45:numeric), ('s39', 59:numeric), ('s40', 2020:numeric), ('s41', 1:numeric), ('s42', 1:numeric), ('s43', 31:numeric), ('s44', 0:numeric), ('s45', 0:numeric), ('s46', 0:numeric), ('s47', 2023:numeric), ('s48', 3:numeric), ('s49', 7:numeric), ('s50', 4:numeric), ('s51', 6:numeric), ('s52', 55:numeric), ('s53', 0:numeric), ('s54', 1999:numeric), ('s55', 4:numeric), ('s56', 12:numeric), ('s57', 31:numeric), ('s58', 23:numeric), ('s59', 59:numeric), ('s60', 58:numeric)], orderings=[]) +ROOT(columns=[('date_time', sbTxDateTime), ('s00', DATETIME(sbTxDateTime, 'start of week':string, '-8 weeks':string)), ('s01', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s02', ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric])), ('s03', ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric])), ('s04', ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric])), ('s05', 
ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric])), ('s06', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s07', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s08', MONTH(sbTxDateTime) < 4:numeric), ('s09', MONTH(sbTxDateTime) < 7:numeric), ('s10', MONTH(sbTxDateTime) < 10:numeric), ('s11', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s12', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s13', MONTH(sbTxDateTime) <= 3:numeric), ('s14', MONTH(sbTxDateTime) <= 6:numeric), ('s15', MONTH(sbTxDateTime) <= 9:numeric), ('s16', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s17', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s18', MONTH(sbTxDateTime) > 3:numeric), ('s19', MONTH(sbTxDateTime) > 6:numeric), ('s20', MONTH(sbTxDateTime) > 9:numeric), ('s21', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s22', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s23', MONTH(sbTxDateTime) >= 4:numeric), ('s24', MONTH(sbTxDateTime) >= 7:numeric), ('s25', MONTH(sbTxDateTime) >= 10:numeric), ('s26', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s27', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s28', NOT(ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric]))), ('s29', NOT(ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric]))), ('s30', NOT(ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric]))), ('s31', NOT(ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric]))), ('s32', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s33', 2024:numeric), ('s34', 3:numeric), ('s35', 8:numeric), ('s36', 13:numeric), ('s37', 12:numeric), ('s38', 45:numeric), ('s39', 59:numeric), ('s40', 2020:numeric), ('s41', 1:numeric), ('s42', 1:numeric), ('s43', 31:numeric), ('s44', 0:numeric), ('s45', 0:numeric), ('s46', 0:numeric), ('s47', 2023:numeric), ('s48', 3:numeric), ('s49', 7:numeric), ('s50', 4:numeric), ('s51', 6:numeric), ('s52', 55:numeric), ('s53', 0:numeric), ('s54', 1999:numeric), ('s55', 4:numeric), ('s56', 12:numeric), ('s57', 31:numeric), ('s58', 23:numeric), ('s59', 59:numeric), ('s60', 58:numeric), 
('s61', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s62', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s63', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s64', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s65', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s66', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s67', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s68', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s69', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s70', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s71', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s72', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s73', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s74', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s75', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s76', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s77', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s78', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s79', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s80', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s81', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s82', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s83', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s84', KEEP_IF(False:bool, PRESENT(sbTxDateTime)))], orderings=[]) FILTER(condition=RANKING(args=[], partition=[], order=[(sbTxDateTime):asc_last]) == 1:numeric | RANKING(args=[], partition=[], order=[(sbTxDateTime):desc_first]) == 1:numeric, columns={'sbTxDateTime': sbTxDateTime}) FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime}) SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime}) diff --git a/tests/test_sql_refsols/datetime_sampler_ansi.sql b/tests/test_sql_refsols/datetime_sampler_ansi.sql index c50e22dbd..91760cd81 100644 --- a/tests/test_sql_refsols/datetime_sampler_ansi.sql +++ b/tests/test_sql_refsols/datetime_sampler_ansi.sql @@ -130,7 +130,7 @@ SELECT EXTRACT(MINUTE FROM CURRENT_TIMESTAMP()) AS _expr71, 30 AS 
_expr72, 0 AS _expr73, - 23 AS _expr74, + EXTRACT(SECOND FROM CURRENT_TIMESTAMP()) AS _expr74, 45 AS _expr75, 0 AS _expr76, DATEDIFF(CURRENT_TIMESTAMP(), CAST('2018-02-14 12:41:06' AS TIMESTAMP), YEAR) AS _expr77, diff --git a/tests/test_sql_refsols/datetime_sampler_sqlite.sql b/tests/test_sql_refsols/datetime_sampler_sqlite.sql index 2dd7a678c..3454cd89a 100644 --- a/tests/test_sql_refsols/datetime_sampler_sqlite.sql +++ b/tests/test_sql_refsols/datetime_sampler_sqlite.sql @@ -100,7 +100,7 @@ SELECT CAST(STRFTIME('%M', DATETIME('now')) AS INTEGER) AS _expr71, 30 AS _expr72, 0 AS _expr73, - 24 AS _expr74, + CAST(STRFTIME('%S', DATETIME('now')) AS INTEGER) AS _expr74, 45 AS _expr75, 0 AS _expr76, CAST(STRFTIME('%Y', DATETIME('now')) AS INTEGER) - CAST(STRFTIME('%Y', DATETIME('2018-02-14 12:41:06')) AS INTEGER) AS _expr77, diff --git a/tests/test_sql_refsols/simplification_4_ansi.sql b/tests/test_sql_refsols/simplification_4_ansi.sql index 31c4c3f7c..0b06f8cb0 100644 --- a/tests/test_sql_refsols/simplification_4_ansi.sql +++ b/tests/test_sql_refsols/simplification_4_ansi.sql @@ -70,5 +70,29 @@ SELECT 31 AS s57, 23 AS s58, 59 AS s59, - 58 AS s60 + 58 AS s60, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s61, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s62, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s63, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s64, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s65, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s66, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s67, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s68, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s69, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s70, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s71, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END 
AS s72, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s73, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s74, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s75, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s76, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s77, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s78, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s79, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s80, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s81, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s82, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s83, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s84 FROM _t1 diff --git a/tests/test_sql_refsols/simplification_4_sqlite.sql b/tests/test_sql_refsols/simplification_4_sqlite.sql index 345cff8cf..b73a8da1b 100644 --- a/tests/test_sql_refsols/simplification_4_sqlite.sql +++ b/tests/test_sql_refsols/simplification_4_sqlite.sql @@ -76,7 +76,31 @@ SELECT 31 AS s57, 23 AS s58, 59 AS s59, - 58 AS s60 + 58 AS s60, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s61, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s62, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s63, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s64, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s65, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s66, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s67, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s68, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s69, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s70, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s71, + CASE WHEN NOT sbtxdatetime IS 
NULL THEN TRUE ELSE NULL END AS s72, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s73, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s74, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s75, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s76, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s77, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s78, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s79, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s80, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s81, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s82, + CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s83, + CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s84 FROM _t WHERE _w = 1 OR _w_2 = 1 From 8671cb6a40cbe0dfa11b52daeff1b645dc3b2e6a Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 15 Aug 2025 13:55:22 -0400 Subject: [PATCH 05/16] Added relational filter/join not-null inferrence and null-literal propagation [RUN CI] --- .../conversion/relational_simplification.py | 62 ++++++++++++++++ tests/test_plan_refsols/bad_child_reuse_1.txt | 2 +- tests/test_plan_refsols/bad_child_reuse_2.txt | 2 +- tests/test_plan_refsols/bad_child_reuse_3.txt | 2 +- tests/test_plan_refsols/bad_child_reuse_4.txt | 2 +- tests/test_plan_refsols/common_prefix_ag.txt | 2 +- tests/test_plan_refsols/common_prefix_ah.txt | 2 +- tests/test_plan_refsols/common_prefix_ai.txt | 2 +- tests/test_plan_refsols/common_prefix_aj.txt | 2 +- tests/test_plan_refsols/common_prefix_ak.txt | 2 +- tests/test_plan_refsols/common_prefix_ao.txt | 2 +- tests/test_plan_refsols/common_prefix_i.txt | 2 +- tests/test_plan_refsols/dumb_aggregation.txt | 2 +- tests/test_plan_refsols/simple_cross_8.txt | 43 ++++++----- .../simple_smallest_or_largest.txt | 2 +- tests/test_plan_refsols/simplification_2.txt | 2 +- 
tests/test_plan_refsols/simplification_4.txt | 2 +- ..._year_cumulative_incident_rate_overall.txt | 2 +- tests/test_plan_refsols/tpch_q18.txt | 2 +- .../defog_broker_gen2_ansi.sql | 2 +- .../defog_broker_gen2_sqlite.sql | 2 +- .../defog_dealership_gen4_ansi.sql | 2 +- .../defog_dealership_gen4_sqlite.sql | 2 +- .../defog_ewallet_adv11_ansi.sql | 4 +- .../defog_ewallet_adv11_sqlite.sql | 4 +- .../simple_smallest_or_largest_sqlite.sql | 8 +-- .../simplification_4_ansi.sql | 72 +++++++++---------- .../simplification_4_sqlite.sql | 72 +++++++++---------- ..._cumulative_incident_rate_overall_ansi.sql | 8 +-- ...umulative_incident_rate_overall_sqlite.sql | 8 +-- tests/test_sql_refsols/tpch_q18_ansi.sql | 2 +- tests/test_sql_refsols/tpch_q18_sqlite.sql | 2 +- 32 files changed, 194 insertions(+), 133 deletions(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 6a324fb3b..f4e6a6110 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -171,6 +171,18 @@ def intersect(predicates: list["PredicateSet"]) -> "PredicateSet": """ +NULL_IF_INPUT_NULL_OPS: set[pydop.PyDoughOperator] = ( + NULL_PROPAGATING_OPS | {pydop.GETPART, pydop.DATETIME} +) - {pydop.BOR, pydop.SLICE} +""" +A set of operators that will always output null if any of their inputs are null. +This includes all operators from NULL_PROPAGATING_OPS unless it is possible for +them to output a non-null value even if some inputs are null (e.g. OR, SLICE), +and also includes some operators that can return NULL even if none of the inputs +are null (e.g. GETPART or DATETIME). +""" + + class SimplificationShuttle(RelationalExpressionShuttle): """ Shuttle implementation for simplifying relational expressions. 
Has three @@ -668,6 +680,16 @@ def simplify_function_call( union_set: PredicateSet = PredicateSet.union(arg_predicates) intersect_set: PredicateSet = PredicateSet.intersect(arg_predicates) + # Return None if any of the inputs are None and the operator is + # guaranteed to return NULL if any of its inptus are NULL. + if expr.op in NULL_IF_INPUT_NULL_OPS: + if any( + isinstance(arg, LiteralExpression) and arg.value is None + for arg in expr.inputs + ): + self.stack.append(output_predicates) + return LiteralExpression(None, expr.data_type) + # If the call has null propagating rules, all of the arguments are # non-null, the output is guaranteed to be non-null. if expr.op in NULL_PROPAGATING_OPS: @@ -1335,6 +1357,34 @@ def visit_project(self, node: Project) -> None: ) self.stack.append(output_predicates) + def infer_null_predicates_from_condition( + self, + output_predicates: dict[RelationalExpression, PredicateSet], + condition: RelationalExpression, + columns: dict[str, RelationalExpression], + ) -> None: + """ + TODO + """ + from .filter_pushdown import NullReplacementShuttle + + self.shuttle.input_predicates = {} + for expr, preds in output_predicates.items(): + if preds.not_null: + continue + if isinstance(expr, ColumnReference) and expr.name in columns: + expr = columns[expr.name] + if isinstance(expr, ColumnReference): + shuttle: NullReplacementShuttle = NullReplacementShuttle( + {expr.name} + ) + new_cond: RelationalExpression = condition.accept_shuttle(shuttle) + new_cond = new_cond.accept_shuttle(self.shuttle) + if isinstance(new_cond, LiteralExpression) and not bool( + new_cond.value + ): + preds.not_null = True + def visit_filter(self, node: Filter) -> None: output_predicates: dict[RelationalExpression, PredicateSet] = ( self.generic_visit(node) @@ -1344,6 +1394,11 @@ def visit_filter(self, node: Filter) -> None: self.shuttle.stack.pop() for shuttle in self.additional_shuttles: node._condition = node.condition.accept_shuttle(shuttle) + 
self.infer_null_predicates_from_condition( + output_predicates, + node.condition, + node.columns, + ) self.stack.append(output_predicates) def visit_join(self, node: Join) -> None: @@ -1364,6 +1419,13 @@ def visit_join(self, node: Join) -> None: and expr.input_name != node.default_input_aliases[0] ): preds.not_null = False + + if node.join_type in (JoinType.INNER, JoinType.SEMI): + self.infer_null_predicates_from_condition( + output_predicates, + node.condition, + node.columns, + ) self.stack.append(output_predicates) def visit_limit(self, node: Limit) -> None: diff --git a/tests/test_plan_refsols/bad_child_reuse_1.txt b/tests/test_plan_refsols/bad_child_reuse_1.txt index 7c5488e03..0feec5b7a 100644 --- a/tests/test_plan_refsols/bad_child_reuse_1.txt +++ b/tests/test_plan_refsols/bad_child_reuse_1.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('cust_key', c_custkey), ('n_orders', DEFAULT_TO(n_rows, 0:numeric))], orderings=[(c_acctbal):desc_last]) +ROOT(columns=[('cust_key', c_custkey), ('n_orders', n_rows)], orderings=[(c_acctbal):desc_last]) FILTER(condition=n_rows > 0:numeric, columns={'c_acctbal': c_acctbal, 'c_custkey': c_custkey, 'n_rows': n_rows}) LIMIT(limit=10:numeric, columns={'c_acctbal': c_acctbal, 'c_custkey': c_custkey, 'n_rows': n_rows}, orderings=[(c_acctbal):desc_last]) JOIN(condition=t0.c_custkey == t1.o_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'n_rows': t1.n_rows}) diff --git a/tests/test_plan_refsols/bad_child_reuse_2.txt b/tests/test_plan_refsols/bad_child_reuse_2.txt index 4432b9290..bf3d817df 100644 --- a/tests/test_plan_refsols/bad_child_reuse_2.txt +++ b/tests/test_plan_refsols/bad_child_reuse_2.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('cust_key', c_custkey), ('n_orders', DEFAULT_TO(n_rows, 0:numeric)), ('n_cust', n_cust)], orderings=[(c_acctbal):desc_last], limit=10:numeric) +ROOT(columns=[('cust_key', c_custkey), ('n_orders', n_rows), ('n_cust', n_cust)], 
orderings=[(c_acctbal):desc_last], limit=10:numeric) FILTER(condition=n_rows > 0:numeric, columns={'c_acctbal': c_acctbal, 'c_custkey': c_custkey, 'n_cust': n_cust, 'n_rows': n_rows}) PROJECT(columns={'c_acctbal': c_acctbal, 'c_custkey': c_custkey, 'n_cust': RELSIZE(args=[], partition=[c_nationkey], order=[]), 'n_rows': n_rows}) JOIN(condition=t0.c_custkey == t1.o_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'c_nationkey': t0.c_nationkey, 'n_rows': t1.n_rows}) diff --git a/tests/test_plan_refsols/bad_child_reuse_3.txt b/tests/test_plan_refsols/bad_child_reuse_3.txt index 4432b9290..bf3d817df 100644 --- a/tests/test_plan_refsols/bad_child_reuse_3.txt +++ b/tests/test_plan_refsols/bad_child_reuse_3.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('cust_key', c_custkey), ('n_orders', DEFAULT_TO(n_rows, 0:numeric)), ('n_cust', n_cust)], orderings=[(c_acctbal):desc_last], limit=10:numeric) +ROOT(columns=[('cust_key', c_custkey), ('n_orders', n_rows), ('n_cust', n_cust)], orderings=[(c_acctbal):desc_last], limit=10:numeric) FILTER(condition=n_rows > 0:numeric, columns={'c_acctbal': c_acctbal, 'c_custkey': c_custkey, 'n_cust': n_cust, 'n_rows': n_rows}) PROJECT(columns={'c_acctbal': c_acctbal, 'c_custkey': c_custkey, 'n_cust': RELSIZE(args=[], partition=[c_nationkey], order=[]), 'n_rows': n_rows}) JOIN(condition=t0.c_custkey == t1.o_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'c_nationkey': t0.c_nationkey, 'n_rows': t1.n_rows}) diff --git a/tests/test_plan_refsols/bad_child_reuse_4.txt b/tests/test_plan_refsols/bad_child_reuse_4.txt index 984b4aad1..532d9b2f3 100644 --- a/tests/test_plan_refsols/bad_child_reuse_4.txt +++ b/tests/test_plan_refsols/bad_child_reuse_4.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('cust_key', c_custkey), ('n_orders', DEFAULT_TO(n_rows, 0:numeric))], orderings=[(c_acctbal):desc_last], limit=10:numeric) +ROOT(columns=[('cust_key', 
c_custkey), ('n_orders', n_rows)], orderings=[(c_acctbal):desc_last], limit=10:numeric) FILTER(condition=DEFAULT_TO(n_rows, 0:numeric) < RELAVG(args=[DEFAULT_TO(n_rows, 0:numeric)], partition=[c_nationkey], order=[]) & n_rows > 0:numeric, columns={'c_acctbal': c_acctbal, 'c_custkey': c_custkey, 'n_rows': n_rows}) JOIN(condition=t0.c_custkey == t1.o_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'c_acctbal': t0.c_acctbal, 'c_custkey': t0.c_custkey, 'c_nationkey': t0.c_nationkey, 'n_rows': t1.n_rows}) JOIN(condition=t0.n_nationkey == t1.c_nationkey, type=INNER, cardinality=PLURAL_ACCESS, columns={'c_acctbal': t1.c_acctbal, 'c_custkey': t1.c_custkey, 'c_nationkey': t1.c_nationkey}) diff --git a/tests/test_plan_refsols/common_prefix_ag.txt b/tests/test_plan_refsols/common_prefix_ag.txt index 13620bfb8..833fc9321 100644 --- a/tests/test_plan_refsols/common_prefix_ag.txt +++ b/tests/test_plan_refsols/common_prefix_ag.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_orders', DEFAULT_TO(sum_n_rows, 0:numeric)), ('n_machine_high_domestic_lines', DEFAULT_TO(sum_sum_n_rows, 0:numeric)), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) +ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_orders', sum_n_rows), ('n_machine_high_domestic_lines', sum_sum_n_rows), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) FILTER(condition=sum_n_rows > 0:numeric & sum_sum_n_rows > 0:numeric, columns={'anything_n_name': anything_n_name, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows, 'sum_sum_n_rows': sum_sum_n_rows, 'sum_sum_sum_revenue': sum_sum_sum_revenue}) AGGREGATE(keys={'n_nationkey': n_nationkey}, aggregations={'anything_n_name': ANYTHING(n_name), 'n_rows': COUNT(), 'sum_n_rows': 
SUM(n_rows), 'sum_sum_n_rows': SUM(sum_n_rows), 'sum_sum_sum_revenue': SUM(sum_sum_revenue)}) JOIN(condition=t0.n_nationkey == t1.n_nationkey & t0.c_custkey == t1.c_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'n_rows': t1.n_rows, 'sum_n_rows': t1.sum_n_rows, 'sum_sum_revenue': t1.sum_sum_revenue}) diff --git a/tests/test_plan_refsols/common_prefix_ah.txt b/tests/test_plan_refsols/common_prefix_ah.txt index bb3062a31..e89b81d16 100644 --- a/tests/test_plan_refsols/common_prefix_ah.txt +++ b/tests/test_plan_refsols/common_prefix_ah.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('nation_name', anything_n_name), ('n_machine_high_orders', n_rows), ('n_machine_high_domestic_lines', DEFAULT_TO(sum_n_rows, 0:numeric)), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) +ROOT(columns=[('nation_name', anything_n_name), ('n_machine_high_orders', n_rows), ('n_machine_high_domestic_lines', sum_n_rows), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) FILTER(condition=sum_n_rows > 0:numeric, columns={'anything_n_name': anything_n_name, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows, 'sum_sum_revenue': sum_sum_revenue}) AGGREGATE(keys={'n_nationkey': n_nationkey}, aggregations={'anything_n_name': ANYTHING(n_name), 'n_rows': COUNT(), 'sum_n_rows': SUM(n_rows), 'sum_sum_revenue': SUM(sum_revenue)}) JOIN(condition=t0.c_custkey == t1.c_custkey & t0.n_nationkey == t1.n_nationkey & t0.o_orderkey == t1.o_orderkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'n_rows': t1.n_rows, 'sum_revenue': t1.sum_revenue}) diff --git a/tests/test_plan_refsols/common_prefix_ai.txt b/tests/test_plan_refsols/common_prefix_ai.txt index 722878bb4..568bce687 100644 --- a/tests/test_plan_refsols/common_prefix_ai.txt +++ 
b/tests/test_plan_refsols/common_prefix_ai.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_domestic_lines', DEFAULT_TO(sum_n_rows, 0:numeric)), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) +ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_domestic_lines', sum_n_rows), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) FILTER(condition=sum_n_rows > 0:numeric, columns={'anything_n_name': anything_n_name, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows, 'sum_sum_revenue': sum_sum_revenue}) AGGREGATE(keys={'n_nationkey': n_nationkey}, aggregations={'anything_n_name': ANYTHING(n_name), 'n_rows': COUNT(), 'sum_n_rows': SUM(n_rows), 'sum_sum_revenue': SUM(sum_revenue)}) JOIN(condition=t0.n_nationkey == t1.n_nationkey & t0.c_custkey == t1.c_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'n_rows': t1.n_rows, 'sum_revenue': t1.sum_revenue}) diff --git a/tests/test_plan_refsols/common_prefix_aj.txt b/tests/test_plan_refsols/common_prefix_aj.txt index 36c32255d..127d53d63 100644 --- a/tests/test_plan_refsols/common_prefix_aj.txt +++ b/tests/test_plan_refsols/common_prefix_aj.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_orders', DEFAULT_TO(sum_n_rows, 0:numeric)), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) +ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_orders', sum_n_rows), ('total_machine_high_domestic_revenue', ROUND(DEFAULT_TO(sum_sum_sum_revenue, 0:numeric), 2:numeric))], orderings=[(anything_n_name):asc_first]) 
FILTER(condition=sum_n_rows > 0:numeric & sum_sum_n_rows > 0:numeric, columns={'anything_n_name': anything_n_name, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows, 'sum_sum_sum_revenue': sum_sum_sum_revenue}) AGGREGATE(keys={'n_nationkey': n_nationkey}, aggregations={'anything_n_name': ANYTHING(n_name), 'n_rows': COUNT(), 'sum_n_rows': SUM(n_rows), 'sum_sum_n_rows': SUM(sum_n_rows), 'sum_sum_sum_revenue': SUM(sum_sum_revenue)}) JOIN(condition=t0.n_nationkey == t1.n_nationkey & t0.c_custkey == t1.c_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'n_rows': t1.n_rows, 'sum_n_rows': t1.sum_n_rows, 'sum_sum_revenue': t1.sum_sum_revenue}) diff --git a/tests/test_plan_refsols/common_prefix_ak.txt b/tests/test_plan_refsols/common_prefix_ak.txt index 534ab16db..27d74ce5a 100644 --- a/tests/test_plan_refsols/common_prefix_ak.txt +++ b/tests/test_plan_refsols/common_prefix_ak.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_orders', DEFAULT_TO(sum_n_rows, 0:numeric)), ('n_machine_high_domestic_lines', DEFAULT_TO(sum_sum_n_rows, 0:numeric))], orderings=[(anything_n_name):asc_first]) +ROOT(columns=[('nation_name', anything_n_name), ('n_machine_cust', n_rows), ('n_machine_high_orders', sum_n_rows), ('n_machine_high_domestic_lines', sum_sum_n_rows)], orderings=[(anything_n_name):asc_first]) FILTER(condition=sum_n_rows > 0:numeric & sum_sum_n_rows > 0:numeric, columns={'anything_n_name': anything_n_name, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows, 'sum_sum_n_rows': sum_sum_n_rows}) AGGREGATE(keys={'n_nationkey': n_nationkey}, aggregations={'anything_n_name': ANYTHING(n_name), 'n_rows': COUNT(), 'sum_n_rows': SUM(n_rows), 'sum_sum_n_rows': SUM(sum_n_rows)}) JOIN(condition=t0.n_nationkey == t1.n_nationkey & t0.c_custkey == t1.c_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'n_name': t0.n_name, 'n_nationkey': t0.n_nationkey, 'n_rows': t1.n_rows, 
'sum_n_rows': t1.sum_n_rows}) diff --git a/tests/test_plan_refsols/common_prefix_ao.txt b/tests/test_plan_refsols/common_prefix_ao.txt index b6b7a4abb..24d8a7b16 100644 --- a/tests/test_plan_refsols/common_prefix_ao.txt +++ b/tests/test_plan_refsols/common_prefix_ao.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('cust_key', c_custkey), ('n_orders', DEFAULT_TO(agg_1, 0:numeric)), ('n_no_tax_discount', DEFAULT_TO(n_rows, 0:numeric)), ('n_part_purchases', sum_n_rows)], orderings=[(c_custkey):asc_first], limit=5:numeric) +ROOT(columns=[('cust_key', c_custkey), ('n_orders', DEFAULT_TO(agg_1, 0:numeric)), ('n_no_tax_discount', n_rows), ('n_part_purchases', sum_n_rows)], orderings=[(c_custkey):asc_first], limit=5:numeric) FILTER(condition=DEFAULT_TO(agg_1, 0:numeric) > RELAVG(args=[DEFAULT_TO(agg_1, 0:numeric)], partition=[], order=[]) & n_rows > 0:numeric, columns={'agg_1': agg_1, 'c_custkey': c_custkey, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows}) JOIN(condition=t0.c_custkey == t1.o_custkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'agg_1': t0.n_rows, 'c_custkey': t0.c_custkey, 'n_rows': t1.n_rows, 'sum_n_rows': t0.sum_n_rows}) LIMIT(limit=20:numeric, columns={'c_custkey': c_custkey, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows}, orderings=[(c_custkey):asc_first]) diff --git a/tests/test_plan_refsols/common_prefix_i.txt b/tests/test_plan_refsols/common_prefix_i.txt index b277cf6c4..c1c0c4e8b 100644 --- a/tests/test_plan_refsols/common_prefix_i.txt +++ b/tests/test_plan_refsols/common_prefix_i.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('name', n_name), ('n_customers', n_rows), ('n_selected_orders', DEFAULT_TO(sum_n_rows, 0:numeric))], orderings=[(n_rows):desc_last, (n_name):asc_first], limit=5:numeric) +ROOT(columns=[('name', n_name), ('n_customers', n_rows), ('n_selected_orders', sum_n_rows)], orderings=[(n_rows):desc_last, (n_name):asc_first], limit=5:numeric) JOIN(condition=t0.n_nationkey == t1.c_nationkey, type=INNER, cardinality=SINGULAR_FILTER, columns={'n_name': t0.n_name, 
'n_rows': t1.n_rows, 'sum_n_rows': t1.sum_n_rows}) SCAN(table=tpch.NATION, columns={'n_name': n_name, 'n_nationkey': n_nationkey}) FILTER(condition=sum_n_rows > 0:numeric, columns={'c_nationkey': c_nationkey, 'n_rows': n_rows, 'sum_n_rows': sum_n_rows}) diff --git a/tests/test_plan_refsols/dumb_aggregation.txt b/tests/test_plan_refsols/dumb_aggregation.txt index 2a9906b01..5b9eefb5e 100644 --- a/tests/test_plan_refsols/dumb_aggregation.txt +++ b/tests/test_plan_refsols/dumb_aggregation.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('nation_name', n_name), ('a1', r_name), ('a2', r_name), ('a3', DEFAULT_TO(r_regionkey, 0:numeric)), ('a4', IFF(PRESENT(KEEP_IF(r_regionkey, r_name != 'AMERICA':string)), 1:numeric, 0:numeric)), ('a5', 1:numeric), ('a6', r_regionkey), ('a7', r_name), ('a8', r_regionkey)], orderings=[(n_name):asc_first]) +ROOT(columns=[('nation_name', n_name), ('a1', r_name), ('a2', r_name), ('a3', r_regionkey), ('a4', 1:numeric), ('a5', 1:numeric), ('a6', r_regionkey), ('a7', r_name), ('a8', r_regionkey)], orderings=[(n_name):asc_first]) JOIN(condition=t0.n_regionkey == t1.r_regionkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'n_name': t0.n_name, 'r_name': t1.r_name, 'r_regionkey': t1.r_regionkey}) LIMIT(limit=2:numeric, columns={'n_name': n_name, 'n_regionkey': n_regionkey}, orderings=[(n_name):asc_first]) SCAN(table=tpch.NATION, columns={'n_name': n_name, 'n_regionkey': n_regionkey}) diff --git a/tests/test_plan_refsols/simple_cross_8.txt b/tests/test_plan_refsols/simple_cross_8.txt index aee7a052f..bf202bed8 100644 --- a/tests/test_plan_refsols/simple_cross_8.txt +++ b/tests/test_plan_refsols/simple_cross_8.txt @@ -1,24 +1,23 @@ ROOT(columns=[('supplier_region', anything_supplier_region), ('customer_region', customer_region), ('region_combinations', region_combinations)], orderings=[]) AGGREGATE(keys={'key_2': key_2, 'r_regionkey': r_regionkey}, aggregations={'anything_supplier_region': ANYTHING(supplier_region), 'customer_region': ANYTHING(r_name), 
'region_combinations': COUNT()}) - FILTER(condition=name_18 == supplier_region, columns={'key_2': key_2, 'r_name': r_name, 'r_regionkey': r_regionkey, 'supplier_region': supplier_region}) - JOIN(condition=t0.l_suppkey == t1.s_suppkey, type=LEFT, cardinality=SINGULAR_FILTER, columns={'key_2': t0.key_2, 'name_18': t1.r_name, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) - JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=INNER, cardinality=PLURAL_FILTER, columns={'key_2': t0.key_2, 'l_suppkey': t1.l_suppkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) - JOIN(condition=t0.c_custkey == t1.o_custkey, type=INNER, cardinality=PLURAL_FILTER, columns={'key_2': t0.key_2, 'o_orderkey': t1.o_orderkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) - JOIN(condition=t0.n_nationkey == t1.c_nationkey, type=INNER, cardinality=PLURAL_FILTER, columns={'c_custkey': t1.c_custkey, 'key_2': t0.key_2, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) - JOIN(condition=t0.key_2 == t1.n_regionkey, type=INNER, cardinality=PLURAL_FILTER, columns={'key_2': t0.key_2, 'n_nationkey': t1.n_nationkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) - JOIN(condition=True:bool, type=INNER, cardinality=PLURAL_ACCESS, columns={'key_2': t1.r_regionkey, 'r_name': t1.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.r_name}) - SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey}) - SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey}) - SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) - FILTER(condition=c_mktsegment == 'AUTOMOBILE':string, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) - SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_mktsegment': 
c_mktsegment, 'c_nationkey': c_nationkey}) - FILTER(condition=o_clerk == 'Clerk#000000007':string, columns={'o_custkey': o_custkey, 'o_orderkey': o_orderkey}) - SCAN(table=tpch.ORDERS, columns={'o_clerk': o_clerk, 'o_custkey': o_custkey, 'o_orderkey': o_orderkey}) - FILTER(condition=MONTH(l_shipdate) == 3:numeric & YEAR(l_shipdate) == 1998:numeric, columns={'l_orderkey': l_orderkey, 'l_suppkey': l_suppkey}) - SCAN(table=tpch.LINEITEM, columns={'l_orderkey': l_orderkey, 'l_shipdate': l_shipdate, 'l_suppkey': l_suppkey}) - JOIN(condition=t0.n_regionkey == t1.r_regionkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'r_name': t1.r_name, 's_suppkey': t0.s_suppkey}) - JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) - FILTER(condition=s_acctbal < 0:numeric, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'s_acctbal': s_acctbal, 's_nationkey': s_nationkey, 's_suppkey': s_suppkey}) - SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) - SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey}) + JOIN(condition=t0.l_suppkey == t1.s_suppkey & t1.r_name == t0.supplier_region, type=INNER, cardinality=SINGULAR_FILTER, columns={'key_2': t0.key_2, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) + JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=INNER, cardinality=PLURAL_FILTER, columns={'key_2': t0.key_2, 'l_suppkey': t1.l_suppkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) + JOIN(condition=t0.c_custkey == t1.o_custkey, type=INNER, cardinality=PLURAL_FILTER, columns={'key_2': t0.key_2, 'o_orderkey': t1.o_orderkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) + JOIN(condition=t0.n_nationkey == t1.c_nationkey, 
type=INNER, cardinality=PLURAL_FILTER, columns={'c_custkey': t1.c_custkey, 'key_2': t0.key_2, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) + JOIN(condition=t0.key_2 == t1.n_regionkey, type=INNER, cardinality=PLURAL_FILTER, columns={'key_2': t0.key_2, 'n_nationkey': t1.n_nationkey, 'r_name': t0.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.supplier_region}) + JOIN(condition=True:bool, type=INNER, cardinality=PLURAL_ACCESS, columns={'key_2': t1.r_regionkey, 'r_name': t1.r_name, 'r_regionkey': t0.r_regionkey, 'supplier_region': t0.r_name}) + SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey}) + SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey}) + SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + FILTER(condition=c_mktsegment == 'AUTOMOBILE':string, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) + SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_mktsegment': c_mktsegment, 'c_nationkey': c_nationkey}) + FILTER(condition=o_clerk == 'Clerk#000000007':string, columns={'o_custkey': o_custkey, 'o_orderkey': o_orderkey}) + SCAN(table=tpch.ORDERS, columns={'o_clerk': o_clerk, 'o_custkey': o_custkey, 'o_orderkey': o_orderkey}) + FILTER(condition=MONTH(l_shipdate) == 3:numeric & YEAR(l_shipdate) == 1998:numeric, columns={'l_orderkey': l_orderkey, 'l_suppkey': l_suppkey}) + SCAN(table=tpch.LINEITEM, columns={'l_orderkey': l_orderkey, 'l_shipdate': l_shipdate, 'l_suppkey': l_suppkey}) + JOIN(condition=t0.n_regionkey == t1.r_regionkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'r_name': t1.r_name, 's_suppkey': t0.s_suppkey}) + JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) + FILTER(condition=s_acctbal < 0:numeric, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) + 
SCAN(table=tpch.SUPPLIER, columns={'s_acctbal': s_acctbal, 's_nationkey': s_nationkey, 's_suppkey': s_suppkey}) + SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + SCAN(table=tpch.REGION, columns={'r_name': r_name, 'r_regionkey': r_regionkey}) diff --git a/tests/test_plan_refsols/simple_smallest_or_largest.txt b/tests/test_plan_refsols/simple_smallest_or_largest.txt index 3f9661a60..fe20b1489 100644 --- a/tests/test_plan_refsols/simple_smallest_or_largest.txt +++ b/tests/test_plan_refsols/simple_smallest_or_largest.txt @@ -1,2 +1,2 @@ -ROOT(columns=[('s1', SMALLEST(20:numeric, 10:numeric)), ('s2', SMALLEST(20:numeric, 20:numeric)), ('s3', SMALLEST(20:numeric, 10:numeric, 0:numeric)), ('s4', SMALLEST(20:numeric, 10:numeric, 10:numeric, -1:numeric, -2:numeric, 100:numeric, -200:numeric)), ('s5', SMALLEST(20:numeric, 10:numeric, None:unknown, 100:numeric, 200:numeric)), ('s6', SMALLEST(20.22:numeric, 10.22:numeric, -0.34:numeric)), ('s7', SMALLEST(datetime.datetime(2025, 1, 1, 0, 0):datetime, datetime.datetime(2024, 1, 1, 0, 0):datetime, datetime.datetime(2023, 1, 1, 0, 0):datetime)), ('s8', SMALLEST('':string, 'alphabet soup':string, 'Hello World':string)), ('s9', SMALLEST(None:unknown, 'alphabet soup':string, 'Hello World':string)), ('l1', LARGEST(20:numeric, 10:numeric)), ('l2', LARGEST(20:numeric, 20:numeric)), ('l3', LARGEST(20:numeric, 10:numeric, 0:numeric)), ('l4', LARGEST(20:numeric, 10:numeric, 10:numeric, -1:numeric, -2:numeric, 100:numeric, -200:numeric, 300:numeric)), ('l5', LARGEST(20:numeric, 10:numeric, None:unknown, 100:numeric, 200:numeric)), ('l6', LARGEST(20.22:numeric, 100.22:numeric, -0.34:numeric)), ('l7', LARGEST(datetime.datetime(2025, 1, 1, 0, 0):datetime, datetime.datetime(2024, 1, 1, 0, 0):datetime, datetime.datetime(2023, 1, 1, 0, 0):datetime)), ('l8', LARGEST('':string, 'alphabet soup':string, 'Hello World':string)), ('l9', LARGEST(None:unknown, 'alphabet soup':string, 'Hello World':string))], 
orderings=[]) +ROOT(columns=[('s1', SMALLEST(20:numeric, 10:numeric)), ('s2', SMALLEST(20:numeric, 20:numeric)), ('s3', SMALLEST(20:numeric, 10:numeric, 0:numeric)), ('s4', SMALLEST(20:numeric, 10:numeric, 10:numeric, -1:numeric, -2:numeric, 100:numeric, -200:numeric)), ('s5', None:numeric), ('s6', SMALLEST(20.22:numeric, 10.22:numeric, -0.34:numeric)), ('s7', SMALLEST(datetime.datetime(2025, 1, 1, 0, 0):datetime, datetime.datetime(2024, 1, 1, 0, 0):datetime, datetime.datetime(2023, 1, 1, 0, 0):datetime)), ('s8', SMALLEST('':string, 'alphabet soup':string, 'Hello World':string)), ('s9', None:unknown), ('l1', LARGEST(20:numeric, 10:numeric)), ('l2', LARGEST(20:numeric, 20:numeric)), ('l3', LARGEST(20:numeric, 10:numeric, 0:numeric)), ('l4', LARGEST(20:numeric, 10:numeric, 10:numeric, -1:numeric, -2:numeric, 100:numeric, -200:numeric, 300:numeric)), ('l5', None:numeric), ('l6', LARGEST(20.22:numeric, 100.22:numeric, -0.34:numeric)), ('l7', LARGEST(datetime.datetime(2025, 1, 1, 0, 0):datetime, datetime.datetime(2024, 1, 1, 0, 0):datetime, datetime.datetime(2023, 1, 1, 0, 0):datetime)), ('l8', LARGEST('':string, 'alphabet soup':string, 'Hello World':string)), ('l9', None:unknown)], orderings=[]) EMPTYSINGLETON() diff --git a/tests/test_plan_refsols/simplification_2.txt b/tests/test_plan_refsols/simplification_2.txt index f15a3a4fc..00b7f7212 100644 --- a/tests/test_plan_refsols/simplification_2.txt +++ b/tests/test_plan_refsols/simplification_2.txt @@ -1,3 +1,3 @@ -ROOT(columns=[('s00', True:bool), ('s01', False:bool), ('s02', True:bool), ('s03', False:bool), ('s04', True:bool), ('s05', False:bool), ('s06', None:bool), ('s07', None:bool), ('s08', None:bool), ('s09', None:bool), ('s10', None:bool), ('s11', None:bool), ('s12', False:bool), ('s13', True:bool), ('s14', False:bool), ('s15', False:bool), ('s16', True:bool), ('s17', True:bool), ('s18', True:bool), ('s19', False:bool), ('s20', True:bool), ('s21', False:bool), ('s22', True:bool), ('s23', False:bool), ('s24', 
False:bool), ('s25', True:bool), ('s26', True:bool), ('s27', False:bool), ('s28', True:bool), ('s29', False:bool), ('s30', 8:numeric), ('s31', 'alphabet':string), ('s32', 'SOUP':string), ('s33', True:bool), ('s34', False:bool), ('s35', False:bool), ('s36', True:bool), ('s37', 3.0:numeric), ('s38', n_rows == None:unknown), ('s39', n_rows <= None:unknown), ('s40', n_rows > None:unknown), ('s41', n_rows > None:unknown), ('s42', n_rows >= None:unknown), ('s43', None:unknown + n_rows), ('s44', n_rows - None:unknown), ('s45', None:unknown * n_rows), ('s46', n_rows / None:unknown), ('s47', LIKE(DEFAULT_TO(max_sbCustName, '':string), '%r%':string))], orderings=[]) +ROOT(columns=[('s00', True:bool), ('s01', False:bool), ('s02', True:bool), ('s03', False:bool), ('s04', True:bool), ('s05', False:bool), ('s06', None:bool), ('s07', None:bool), ('s08', None:bool), ('s09', None:bool), ('s10', None:bool), ('s11', None:bool), ('s12', False:bool), ('s13', True:bool), ('s14', False:bool), ('s15', False:bool), ('s16', True:bool), ('s17', True:bool), ('s18', True:bool), ('s19', False:bool), ('s20', True:bool), ('s21', False:bool), ('s22', True:bool), ('s23', False:bool), ('s24', False:bool), ('s25', True:bool), ('s26', True:bool), ('s27', False:bool), ('s28', True:bool), ('s29', False:bool), ('s30', 8:numeric), ('s31', 'alphabet':string), ('s32', 'SOUP':string), ('s33', True:bool), ('s34', False:bool), ('s35', False:bool), ('s36', True:bool), ('s37', 3.0:numeric), ('s38', None:bool), ('s39', None:bool), ('s40', None:bool), ('s41', None:bool), ('s42', None:bool), ('s43', None:unknown), ('s44', None:numeric), ('s45', None:unknown), ('s46', n_rows / None:unknown), ('s47', LIKE(DEFAULT_TO(max_sbCustName, '':string), '%r%':string))], orderings=[]) AGGREGATE(keys={}, aggregations={'max_sbCustName': MAX(sbCustName), 'n_rows': COUNT()}) SCAN(table=main.sbCustomer, columns={'sbCustName': sbCustName}) diff --git a/tests/test_plan_refsols/simplification_4.txt 
b/tests/test_plan_refsols/simplification_4.txt index 437ca5ae5..8f5218e84 100644 --- a/tests/test_plan_refsols/simplification_4.txt +++ b/tests/test_plan_refsols/simplification_4.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('date_time', sbTxDateTime), ('s00', DATETIME(sbTxDateTime, 'start of week':string, '-8 weeks':string)), ('s01', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s02', ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric])), ('s03', ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric])), ('s04', ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric])), ('s05', ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric])), ('s06', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s07', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s08', MONTH(sbTxDateTime) < 4:numeric), ('s09', MONTH(sbTxDateTime) < 7:numeric), ('s10', MONTH(sbTxDateTime) < 10:numeric), ('s11', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s12', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s13', MONTH(sbTxDateTime) <= 3:numeric), ('s14', MONTH(sbTxDateTime) <= 6:numeric), ('s15', MONTH(sbTxDateTime) <= 9:numeric), ('s16', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s17', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s18', MONTH(sbTxDateTime) > 3:numeric), ('s19', MONTH(sbTxDateTime) > 6:numeric), ('s20', MONTH(sbTxDateTime) > 9:numeric), ('s21', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s22', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s23', MONTH(sbTxDateTime) >= 4:numeric), ('s24', MONTH(sbTxDateTime) >= 7:numeric), ('s25', MONTH(sbTxDateTime) >= 10:numeric), ('s26', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s27', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s28', NOT(ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric]))), ('s29', NOT(ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric]))), ('s30', NOT(ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric]))), ('s31', NOT(ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric]))), ('s32', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s33', 
2024:numeric), ('s34', 3:numeric), ('s35', 8:numeric), ('s36', 13:numeric), ('s37', 12:numeric), ('s38', 45:numeric), ('s39', 59:numeric), ('s40', 2020:numeric), ('s41', 1:numeric), ('s42', 1:numeric), ('s43', 31:numeric), ('s44', 0:numeric), ('s45', 0:numeric), ('s46', 0:numeric), ('s47', 2023:numeric), ('s48', 3:numeric), ('s49', 7:numeric), ('s50', 4:numeric), ('s51', 6:numeric), ('s52', 55:numeric), ('s53', 0:numeric), ('s54', 1999:numeric), ('s55', 4:numeric), ('s56', 12:numeric), ('s57', 31:numeric), ('s58', 23:numeric), ('s59', 59:numeric), ('s60', 58:numeric), ('s61', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s62', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s63', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s64', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s65', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s66', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s67', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s68', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s69', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s70', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s71', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s72', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s73', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s74', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s75', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s76', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s77', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s78', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s79', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s80', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s81', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s82', KEEP_IF(False:bool, PRESENT(sbTxDateTime))), ('s83', KEEP_IF(True:bool, PRESENT(sbTxDateTime))), ('s84', KEEP_IF(False:bool, PRESENT(sbTxDateTime)))], orderings=[]) +ROOT(columns=[('date_time', sbTxDateTime), ('s00', DATETIME(sbTxDateTime, 'start of week':string, '-8 weeks':string)), ('s01', False:bool), ('s02', 
ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric])), ('s03', ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric])), ('s04', ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric])), ('s05', ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric])), ('s06', False:bool), ('s07', False:bool), ('s08', MONTH(sbTxDateTime) < 4:numeric), ('s09', MONTH(sbTxDateTime) < 7:numeric), ('s10', MONTH(sbTxDateTime) < 10:numeric), ('s11', True:bool), ('s12', False:bool), ('s13', MONTH(sbTxDateTime) <= 3:numeric), ('s14', MONTH(sbTxDateTime) <= 6:numeric), ('s15', MONTH(sbTxDateTime) <= 9:numeric), ('s16', True:bool), ('s17', True:bool), ('s18', MONTH(sbTxDateTime) > 3:numeric), ('s19', MONTH(sbTxDateTime) > 6:numeric), ('s20', MONTH(sbTxDateTime) > 9:numeric), ('s21', False:bool), ('s22', True:bool), ('s23', MONTH(sbTxDateTime) >= 4:numeric), ('s24', MONTH(sbTxDateTime) >= 7:numeric), ('s25', MONTH(sbTxDateTime) >= 10:numeric), ('s26', False:bool), ('s27', True:bool), ('s28', NOT(ISIN(MONTH(sbTxDateTime), [1, 2, 3]:array[numeric]))), ('s29', NOT(ISIN(MONTH(sbTxDateTime), [4, 5, 6]:array[numeric]))), ('s30', NOT(ISIN(MONTH(sbTxDateTime), [7, 8, 9]:array[numeric]))), ('s31', NOT(ISIN(MONTH(sbTxDateTime), [10, 11, 12]:array[numeric]))), ('s32', True:bool), ('s33', 2024:numeric), ('s34', 3:numeric), ('s35', 8:numeric), ('s36', 13:numeric), ('s37', 12:numeric), ('s38', 45:numeric), ('s39', 59:numeric), ('s40', 2020:numeric), ('s41', 1:numeric), ('s42', 1:numeric), ('s43', 31:numeric), ('s44', 0:numeric), ('s45', 0:numeric), ('s46', 0:numeric), ('s47', 2023:numeric), ('s48', 3:numeric), ('s49', 7:numeric), ('s50', 4:numeric), ('s51', 6:numeric), ('s52', 55:numeric), ('s53', 0:numeric), ('s54', 1999:numeric), ('s55', 4:numeric), ('s56', 12:numeric), ('s57', 31:numeric), ('s58', 23:numeric), ('s59', 59:numeric), ('s60', 58:numeric), ('s61', False:bool), ('s62', False:bool), ('s63', False:bool), ('s64', True:bool), ('s65', True:bool), ('s66', True:bool), ('s67', False:bool), ('s68', False:bool), 
('s69', False:bool), ('s70', True:bool), ('s71', True:bool), ('s72', True:bool), ('s73', False:bool), ('s74', False:bool), ('s75', True:bool), ('s76', True:bool), ('s77', False:bool), ('s78', True:bool), ('s79', False:bool), ('s80', True:bool), ('s81', True:bool), ('s82', False:bool), ('s83', True:bool), ('s84', False:bool)], orderings=[]) FILTER(condition=RANKING(args=[], partition=[], order=[(sbTxDateTime):asc_last]) == 1:numeric | RANKING(args=[], partition=[], order=[(sbTxDateTime):desc_first]) == 1:numeric, columns={'sbTxDateTime': sbTxDateTime}) FILTER(condition=YEAR(sbTxDateTime) == 2023:numeric, columns={'sbTxDateTime': sbTxDateTime}) SCAN(table=main.sbTransaction, columns={'sbTxDateTime': sbTxDateTime}) diff --git a/tests/test_plan_refsols/technograph_year_cumulative_incident_rate_overall.txt b/tests/test_plan_refsols/technograph_year_cumulative_incident_rate_overall.txt index 2f06f93d8..982989ffb 100644 --- a/tests/test_plan_refsols/technograph_year_cumulative_incident_rate_overall.txt +++ b/tests/test_plan_refsols/technograph_year_cumulative_incident_rate_overall.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('yr', year), ('cum_ir', ROUND(RELSUM(args=[DEFAULT_TO(sum_n_rows, 0:numeric)], partition=[], order=[(year):asc_last], cumulative=True) / RELSUM(args=[DEFAULT_TO(sum_expr_3, 0:numeric)], partition=[], order=[(year):asc_last], cumulative=True), 2:numeric)), ('pct_bought_change', ROUND(100.0:numeric * DEFAULT_TO(sum_expr_3, 0:numeric) - PREV(args=[DEFAULT_TO(sum_expr_3, 0:numeric)], partition=[], order=[(year):asc_last]) / PREV(args=[DEFAULT_TO(sum_expr_3, 0:numeric)], partition=[], order=[(year):asc_last]), 2:numeric)), ('pct_incident_change', ROUND(100.0:numeric * DEFAULT_TO(sum_n_rows, 0:numeric) - PREV(args=[DEFAULT_TO(sum_n_rows, 0:numeric)], partition=[], order=[(year):asc_last]) / PREV(args=[DEFAULT_TO(sum_n_rows, 0:numeric)], partition=[], order=[(year):asc_last]), 2:numeric)), ('bought', DEFAULT_TO(sum_expr_3, 0:numeric)), ('incidents', 
DEFAULT_TO(sum_n_rows, 0:numeric))], orderings=[(year):asc_first]) +ROOT(columns=[('yr', year), ('cum_ir', ROUND(RELSUM(args=[DEFAULT_TO(sum_n_rows, 0:numeric)], partition=[], order=[(year):asc_last], cumulative=True) / RELSUM(args=[sum_expr_3], partition=[], order=[(year):asc_last], cumulative=True), 2:numeric)), ('pct_bought_change', ROUND(100.0:numeric * sum_expr_3 - PREV(args=[sum_expr_3], partition=[], order=[(year):asc_last]) / PREV(args=[sum_expr_3], partition=[], order=[(year):asc_last]), 2:numeric)), ('pct_incident_change', ROUND(100.0:numeric * DEFAULT_TO(sum_n_rows, 0:numeric) - PREV(args=[DEFAULT_TO(sum_n_rows, 0:numeric)], partition=[], order=[(year):asc_last]) / PREV(args=[DEFAULT_TO(sum_n_rows, 0:numeric)], partition=[], order=[(year):asc_last]), 2:numeric)), ('bought', sum_expr_3), ('incidents', DEFAULT_TO(sum_n_rows, 0:numeric))], orderings=[(year):asc_first]) FILTER(condition=DEFAULT_TO(sum_expr_3, 0:numeric) > 0:numeric, columns={'sum_expr_3': sum_expr_3, 'sum_n_rows': sum_n_rows, 'year': year}) AGGREGATE(keys={'year': YEAR(ca_dt)}, aggregations={'sum_expr_3': SUM(expr_3), 'sum_n_rows': SUM(n_rows)}) JOIN(condition=t0.ca_dt == t1.ca_dt, type=LEFT, cardinality=SINGULAR_FILTER, columns={'ca_dt': t0.ca_dt, 'expr_3': t0.n_rows, 'n_rows': t1.n_rows}) diff --git a/tests/test_plan_refsols/tpch_q18.txt b/tests/test_plan_refsols/tpch_q18.txt index 8acc50868..6964b7e62 100644 --- a/tests/test_plan_refsols/tpch_q18.txt +++ b/tests/test_plan_refsols/tpch_q18.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('C_NAME', c_name), ('C_CUSTKEY', c_custkey), ('O_ORDERKEY', o_orderkey), ('O_ORDERDATE', o_orderdate), ('O_TOTALPRICE', o_totalprice), ('TOTAL_QUANTITY', DEFAULT_TO(sum_l_quantity, 0:numeric))], orderings=[(o_totalprice):desc_last, (o_orderdate):asc_first], limit=10:numeric) +ROOT(columns=[('C_NAME', c_name), ('C_CUSTKEY', c_custkey), ('O_ORDERKEY', o_orderkey), ('O_ORDERDATE', o_orderdate), ('O_TOTALPRICE', o_totalprice), ('TOTAL_QUANTITY', sum_l_quantity)], 
orderings=[(o_totalprice):desc_last, (o_orderdate):asc_first], limit=10:numeric) JOIN(condition=t0.o_orderkey == t1.l_orderkey, type=INNER, cardinality=SINGULAR_FILTER, columns={'c_custkey': t0.c_custkey, 'c_name': t0.c_name, 'o_orderdate': t0.o_orderdate, 'o_orderkey': t0.o_orderkey, 'o_totalprice': t0.o_totalprice, 'sum_l_quantity': t1.sum_l_quantity}) JOIN(condition=t0.o_custkey == t1.c_custkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'c_custkey': t1.c_custkey, 'c_name': t1.c_name, 'o_orderdate': t0.o_orderdate, 'o_orderkey': t0.o_orderkey, 'o_totalprice': t0.o_totalprice}) SCAN(table=tpch.ORDERS, columns={'o_custkey': o_custkey, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 'o_totalprice': o_totalprice}) diff --git a/tests/test_sql_refsols/defog_broker_gen2_ansi.sql b/tests/test_sql_refsols/defog_broker_gen2_ansi.sql index f66b2462d..c584e8ca5 100644 --- a/tests/test_sql_refsols/defog_broker_gen2_ansi.sql +++ b/tests/test_sql_refsols/defog_broker_gen2_ansi.sql @@ -1,5 +1,5 @@ SELECT - COUNT(sbtransaction.sbtxcustid) AS transaction_count + COUNT(*) AS transaction_count FROM main.sbtransaction AS sbtransaction JOIN main.sbcustomer AS sbcustomer ON sbcustomer.sbcustid = sbtransaction.sbtxcustid diff --git a/tests/test_sql_refsols/defog_broker_gen2_sqlite.sql b/tests/test_sql_refsols/defog_broker_gen2_sqlite.sql index 1300381dd..d50d0b734 100644 --- a/tests/test_sql_refsols/defog_broker_gen2_sqlite.sql +++ b/tests/test_sql_refsols/defog_broker_gen2_sqlite.sql @@ -1,5 +1,5 @@ SELECT - COUNT(sbtransaction.sbtxcustid) AS transaction_count + COUNT(*) AS transaction_count FROM main.sbtransaction AS sbtransaction JOIN main.sbcustomer AS sbcustomer ON sbcustomer.sbcustid = sbtransaction.sbtxcustid diff --git a/tests/test_sql_refsols/defog_dealership_gen4_ansi.sql b/tests/test_sql_refsols/defog_dealership_gen4_ansi.sql index 6ac06680b..ec5326977 100644 --- a/tests/test_sql_refsols/defog_dealership_gen4_ansi.sql +++ 
b/tests/test_sql_refsols/defog_dealership_gen4_ansi.sql @@ -24,7 +24,7 @@ WITH _s0 AS ( SELECT quarter, state AS customer_state, - COALESCE(sum_sum_sale_price, 0) AS total_sales + sum_sum_sale_price AS total_sales FROM _t1 WHERE NOT sum_sum_sale_price IS NULL AND sum_sum_sale_price > 0 diff --git a/tests/test_sql_refsols/defog_dealership_gen4_sqlite.sql b/tests/test_sql_refsols/defog_dealership_gen4_sqlite.sql index 96ad10d92..f69abfc36 100644 --- a/tests/test_sql_refsols/defog_dealership_gen4_sqlite.sql +++ b/tests/test_sql_refsols/defog_dealership_gen4_sqlite.sql @@ -40,7 +40,7 @@ WITH _s0 AS ( SELECT quarter, state AS customer_state, - COALESCE(sum_sum_sale_price, 0) AS total_sales + sum_sum_sale_price AS total_sales FROM _t1 WHERE NOT sum_sum_sale_price IS NULL AND sum_sum_sale_price > 0 diff --git a/tests/test_sql_refsols/defog_ewallet_adv11_ansi.sql b/tests/test_sql_refsols/defog_ewallet_adv11_ansi.sql index d9bb1d546..f56490262 100644 --- a/tests/test_sql_refsols/defog_ewallet_adv11_ansi.sql +++ b/tests/test_sql_refsols/defog_ewallet_adv11_ansi.sql @@ -12,9 +12,9 @@ WITH _s1 AS ( ) SELECT users.uid, - COALESCE(_s1.sum_duration, 0) AS total_duration + _s1.sum_duration AS total_duration FROM main.users AS users JOIN _s1 AS _s1 ON _s1.user_id = users.uid ORDER BY - COALESCE(_s1.sum_duration, 0) DESC + _s1.sum_duration DESC diff --git a/tests/test_sql_refsols/defog_ewallet_adv11_sqlite.sql b/tests/test_sql_refsols/defog_ewallet_adv11_sqlite.sql index c8a248291..407137c3b 100644 --- a/tests/test_sql_refsols/defog_ewallet_adv11_sqlite.sql +++ b/tests/test_sql_refsols/defog_ewallet_adv11_sqlite.sql @@ -18,9 +18,9 @@ WITH _s1 AS ( ) SELECT users.uid, - COALESCE(_s1.sum_duration, 0) AS total_duration + _s1.sum_duration AS total_duration FROM main.users AS users JOIN _s1 AS _s1 ON _s1.user_id = users.uid ORDER BY - COALESCE(_s1.sum_duration, 0) DESC + _s1.sum_duration DESC diff --git a/tests/test_sql_refsols/simple_smallest_or_largest_sqlite.sql 
b/tests/test_sql_refsols/simple_smallest_or_largest_sqlite.sql index 9d7de5d70..ca3a12291 100644 --- a/tests/test_sql_refsols/simple_smallest_or_largest_sqlite.sql +++ b/tests/test_sql_refsols/simple_smallest_or_largest_sqlite.sql @@ -3,19 +3,19 @@ SELECT MIN(20, 20) AS s2, MIN(20, 10, 0) AS s3, MIN(20, 10, 10, -1, -2, 100, -200) AS s4, - MIN(20, 10, NULL, 100, 200) AS s5, + NULL AS s5, MIN(20.22, 10.22, -0.34) AS s6, MIN('2025-01-01 00:00:00', '2024-01-01 00:00:00', '2023-01-01 00:00:00') AS s7, MIN('', 'alphabet soup', 'Hello World') AS s8, - MIN(NULL, 'alphabet soup', 'Hello World') AS s9, + NULL AS s9, MAX(20, 10) AS l1, MAX(20, 20) AS l2, MAX(20, 10, 0) AS l3, MAX(20, 10, 10, -1, -2, 100, -200, 300) AS l4, - MAX(20, 10, NULL, 100, 200) AS l5, + NULL AS l5, MAX(20.22, 100.22, -0.34) AS l6, MAX('2025-01-01 00:00:00', '2024-01-01 00:00:00', '2023-01-01 00:00:00') AS l7, MAX('', 'alphabet soup', 'Hello World') AS l8, - MAX(NULL, 'alphabet soup', 'Hello World') AS l9 + NULL AS l9 FROM (VALUES (NULL)) AS _q_0 diff --git a/tests/test_sql_refsols/simplification_4_ansi.sql b/tests/test_sql_refsols/simplification_4_ansi.sql index 0b06f8cb0..7489179ad 100644 --- a/tests/test_sql_refsols/simplification_4_ansi.sql +++ b/tests/test_sql_refsols/simplification_4_ansi.sql @@ -11,38 +11,38 @@ WITH _t1 AS ( SELECT sbtxdatetime AS date_time, DATE_ADD(DATE_TRUNC('WEEK', CAST(sbtxdatetime AS TIMESTAMP)), -8, 'WEEK') AS s00, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s01, + FALSE AS s01, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (1, 2, 3) AS s02, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (4, 5, 6) AS s03, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (7, 8, 9) AS s04, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (10, 11, 12) AS s05, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s06, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s07, + FALSE AS s06, + FALSE AS s07, EXTRACT(MONTH FROM 
CAST(sbtxdatetime AS DATETIME)) < 4 AS s08, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) < 7 AS s09, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) < 10 AS s10, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s11, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s12, + TRUE AS s11, + FALSE AS s12, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) <= 3 AS s13, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) <= 6 AS s14, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) <= 9 AS s15, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s16, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s17, + TRUE AS s16, + TRUE AS s17, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) > 3 AS s18, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) > 6 AS s19, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) > 9 AS s20, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s21, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s22, + FALSE AS s21, + TRUE AS s22, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) >= 4 AS s23, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) >= 7 AS s24, EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) >= 10 AS s25, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s26, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s27, + FALSE AS s26, + TRUE AS s27, NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (1, 2, 3) AS s28, NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (4, 5, 6) AS s29, NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (7, 8, 9) AS s30, NOT EXTRACT(MONTH FROM CAST(sbtxdatetime AS DATETIME)) IN (10, 11, 12) AS s31, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32, + TRUE AS s32, 2024 AS s33, 3 AS s34, 8 AS s35, @@ -71,28 +71,28 @@ SELECT 23 AS s58, 59 AS s59, 58 AS s60, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s61, - CASE WHEN NOT 
sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s62, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s63, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s64, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s65, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s66, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s67, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s68, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s69, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s70, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s71, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s72, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s73, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s74, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s75, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s76, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s77, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s78, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s79, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s80, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s81, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s82, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s83, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s84 + FALSE AS s61, + FALSE AS s62, + FALSE AS s63, + TRUE AS s64, + TRUE AS s65, + TRUE AS s66, + FALSE AS s67, + FALSE AS s68, + FALSE AS s69, + TRUE AS s70, + TRUE AS s71, + TRUE AS s72, + FALSE AS s73, + FALSE AS s74, + TRUE AS s75, + TRUE AS s76, + FALSE AS s77, + TRUE AS s78, + FALSE AS s79, + TRUE AS s80, + TRUE AS s81, + FALSE AS s82, + TRUE AS s83, + FALSE AS s84 FROM _t1 diff --git 
a/tests/test_sql_refsols/simplification_4_sqlite.sql b/tests/test_sql_refsols/simplification_4_sqlite.sql index b73a8da1b..913af5338 100644 --- a/tests/test_sql_refsols/simplification_4_sqlite.sql +++ b/tests/test_sql_refsols/simplification_4_sqlite.sql @@ -17,38 +17,38 @@ SELECT 'start of day', '-56 day' ) AS s00, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s01, + FALSE AS s01, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (1, 2, 3) AS s02, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (4, 5, 6) AS s03, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (7, 8, 9) AS s04, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (10, 11, 12) AS s05, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s06, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s07, + FALSE AS s06, + FALSE AS s07, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) < 4 AS s08, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) < 7 AS s09, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) < 10 AS s10, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s11, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s12, + TRUE AS s11, + FALSE AS s12, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 3 AS s13, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 6 AS s14, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) <= 9 AS s15, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s16, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s17, + TRUE AS s16, + TRUE AS s17, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) > 3 AS s18, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) > 6 AS s19, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) > 9 AS s20, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s21, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s22, + FALSE AS s21, + TRUE AS s22, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 4 AS s23, CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 7 AS s24, 
CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) >= 10 AS s25, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s26, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s27, + FALSE AS s26, + TRUE AS s27, NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (1, 2, 3) AS s28, NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (4, 5, 6) AS s29, NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (7, 8, 9) AS s30, NOT CAST(STRFTIME('%m', sbtxdatetime) AS INTEGER) IN (10, 11, 12) AS s31, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s32, + TRUE AS s32, 2024 AS s33, 3 AS s34, 8 AS s35, @@ -77,30 +77,30 @@ SELECT 23 AS s58, 59 AS s59, 58 AS s60, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s61, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s62, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s63, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s64, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s65, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s66, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s67, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s68, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s69, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s70, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s71, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s72, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s73, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s74, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s75, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s76, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s77, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s78, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE 
NULL END AS s79, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s80, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s81, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s82, - CASE WHEN NOT sbtxdatetime IS NULL THEN TRUE ELSE NULL END AS s83, - CASE WHEN NOT sbtxdatetime IS NULL THEN FALSE ELSE NULL END AS s84 + FALSE AS s61, + FALSE AS s62, + FALSE AS s63, + TRUE AS s64, + TRUE AS s65, + TRUE AS s66, + FALSE AS s67, + FALSE AS s68, + FALSE AS s69, + TRUE AS s70, + TRUE AS s71, + TRUE AS s72, + FALSE AS s73, + FALSE AS s74, + TRUE AS s75, + TRUE AS s76, + FALSE AS s77, + TRUE AS s78, + FALSE AS s79, + TRUE AS s80, + TRUE AS s81, + FALSE AS s82, + TRUE AS s83, + FALSE AS s84 FROM _t WHERE _w = 1 OR _w_2 = 1 diff --git a/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_ansi.sql b/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_ansi.sql index 461f9fb2c..9da85ba7b 100644 --- a/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_ansi.sql +++ b/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_ansi.sql @@ -36,15 +36,15 @@ WITH _s2 AS ( SELECT year AS yr, ROUND( - SUM(COALESCE(sum_n_rows, 0)) OVER (ORDER BY year NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) / SUM(COALESCE(sum_expr_3, 0)) OVER (ORDER BY year NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), + SUM(COALESCE(sum_n_rows, 0)) OVER (ORDER BY year NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) / SUM(sum_expr_3) OVER (ORDER BY year NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 2 ) AS cum_ir, ROUND( ( 100.0 * ( - COALESCE(sum_expr_3, 0) - LAG(COALESCE(sum_expr_3, 0), 1) OVER (ORDER BY year NULLS LAST) + sum_expr_3 - LAG(sum_expr_3, 1) OVER (ORDER BY year NULLS LAST) ) - ) / LAG(COALESCE(sum_expr_3, 0), 1) OVER (ORDER BY year NULLS LAST), + ) / LAG(sum_expr_3, 1) OVER (ORDER BY year NULLS LAST), 2 ) AS 
pct_bought_change, ROUND( @@ -55,7 +55,7 @@ SELECT ) / LAG(COALESCE(sum_n_rows, 0), 1) OVER (ORDER BY year NULLS LAST), 2 ) AS pct_incident_change, - COALESCE(sum_expr_3, 0) AS bought, + sum_expr_3 AS bought, COALESCE(sum_n_rows, 0) AS incidents FROM _t1 WHERE diff --git a/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_sqlite.sql b/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_sqlite.sql index d0661f3f0..36c66a2fe 100644 --- a/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_sqlite.sql +++ b/tests/test_sql_refsols/technograph_year_cumulative_incident_rate_overall_sqlite.sql @@ -36,15 +36,15 @@ WITH _s2 AS ( SELECT year AS yr, ROUND( - CAST(SUM(COALESCE(sum_n_rows, 0)) OVER (ORDER BY year ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS REAL) / SUM(COALESCE(sum_expr_3, 0)) OVER (ORDER BY year ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), + CAST(SUM(COALESCE(sum_n_rows, 0)) OVER (ORDER BY year ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS REAL) / SUM(sum_expr_3) OVER (ORDER BY year ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 2 ) AS cum_ir, ROUND( CAST(( 100.0 * ( - COALESCE(sum_expr_3, 0) - LAG(COALESCE(sum_expr_3, 0), 1) OVER (ORDER BY year) + sum_expr_3 - LAG(sum_expr_3, 1) OVER (ORDER BY year) ) - ) AS REAL) / LAG(COALESCE(sum_expr_3, 0), 1) OVER (ORDER BY year), + ) AS REAL) / LAG(sum_expr_3, 1) OVER (ORDER BY year), 2 ) AS pct_bought_change, ROUND( @@ -55,7 +55,7 @@ SELECT ) AS REAL) / LAG(COALESCE(sum_n_rows, 0), 1) OVER (ORDER BY year), 2 ) AS pct_incident_change, - COALESCE(sum_expr_3, 0) AS bought, + sum_expr_3 AS bought, COALESCE(sum_n_rows, 0) AS incidents FROM _t1 WHERE diff --git a/tests/test_sql_refsols/tpch_q18_ansi.sql b/tests/test_sql_refsols/tpch_q18_ansi.sql index aa9134752..236ee09f0 100644 --- a/tests/test_sql_refsols/tpch_q18_ansi.sql +++ b/tests/test_sql_refsols/tpch_q18_ansi.sql @@ -12,7 +12,7 @@ SELECT orders.o_orderkey AS O_ORDERKEY, 
orders.o_orderdate AS O_ORDERDATE, orders.o_totalprice AS O_TOTALPRICE, - COALESCE(_t1.sum_l_quantity, 0) AS TOTAL_QUANTITY + _t1.sum_l_quantity AS TOTAL_QUANTITY FROM tpch.orders AS orders JOIN tpch.customer AS customer ON customer.c_custkey = orders.o_custkey diff --git a/tests/test_sql_refsols/tpch_q18_sqlite.sql b/tests/test_sql_refsols/tpch_q18_sqlite.sql index aa9134752..236ee09f0 100644 --- a/tests/test_sql_refsols/tpch_q18_sqlite.sql +++ b/tests/test_sql_refsols/tpch_q18_sqlite.sql @@ -12,7 +12,7 @@ SELECT orders.o_orderkey AS O_ORDERKEY, orders.o_orderdate AS O_ORDERDATE, orders.o_totalprice AS O_TOTALPRICE, - COALESCE(_t1.sum_l_quantity, 0) AS TOTAL_QUANTITY + _t1.sum_l_quantity AS TOTAL_QUANTITY FROM tpch.orders AS orders JOIN tpch.customer AS customer ON customer.c_custkey = orders.o_custkey From 0a5402fc0ede6b4c5d35ba93ab7af7880f40eef4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 15 Aug 2025 14:17:30 -0400 Subject: [PATCH 06/16] Fixing KEEP_IF bug [RUN CI] --- .../conversion/relational_simplification.py | 13 ++++++++--- tests/test_plan_refsols/dumb_aggregation.txt | 2 +- .../agg_simplification_1_sqlite.sql | 22 ++----------------- 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index f4e6a6110..f2bb93b19 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -1127,6 +1127,7 @@ def simplify_function_call( # KEEP_IF(x, True) -> x # KEEP_IF(x, False) -> None + # KEEP_IF(None, y) -> None case pydop.KEEP_IF: if isinstance(expr.inputs[1], LiteralExpression): if bool(expr.inputs[1].value): @@ -1135,13 +1136,19 @@ def simplify_function_call( else: output_expr = LiteralExpression(None, expr.data_type) output_predicates.not_negative = True + elif ( + isinstance(expr.inputs[0], LiteralExpression) + and expr.inputs[0].value is None + ): + output_expr = LiteralExpression(None, 
expr.data_type) elif arg_predicates[1].not_null and arg_predicates[1].positive: output_expr = expr.inputs[0] output_predicates = arg_predicates[0] else: - output_predicates |= arg_predicates[0] & PredicateSet( - not_null=True, not_negative=True - ) + # Otherwise the predicates are the same as the first + # argument, except it can be null. + output_predicates |= arg_predicates[0] + output_predicates.not_null = False # DATETIME(DATETIME(u, v, w), x, y, z) -> DATETIME(u, v, w, x, y, z) case pydop.DATETIME: diff --git a/tests/test_plan_refsols/dumb_aggregation.txt b/tests/test_plan_refsols/dumb_aggregation.txt index 5b9eefb5e..3aea0cf43 100644 --- a/tests/test_plan_refsols/dumb_aggregation.txt +++ b/tests/test_plan_refsols/dumb_aggregation.txt @@ -1,4 +1,4 @@ -ROOT(columns=[('nation_name', n_name), ('a1', r_name), ('a2', r_name), ('a3', r_regionkey), ('a4', 1:numeric), ('a5', 1:numeric), ('a6', r_regionkey), ('a7', r_name), ('a8', r_regionkey)], orderings=[(n_name):asc_first]) +ROOT(columns=[('nation_name', n_name), ('a1', r_name), ('a2', r_name), ('a3', r_regionkey), ('a4', IFF(PRESENT(KEEP_IF(r_regionkey, r_name != 'AMERICA':string)), 1:numeric, 0:numeric)), ('a5', 1:numeric), ('a6', r_regionkey), ('a7', r_name), ('a8', r_regionkey)], orderings=[(n_name):asc_first]) JOIN(condition=t0.n_regionkey == t1.r_regionkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'n_name': t0.n_name, 'r_name': t1.r_name, 'r_regionkey': t1.r_regionkey}) LIMIT(limit=2:numeric, columns={'n_name': n_name, 'n_regionkey': n_regionkey}, orderings=[(n_name):asc_first]) SCAN(table=tpch.NATION, columns={'n_name': n_name, 'n_regionkey': n_regionkey}) diff --git a/tests/test_sql_refsols/agg_simplification_1_sqlite.sql b/tests/test_sql_refsols/agg_simplification_1_sqlite.sql index 2de50de10..5f460e32f 100644 --- a/tests/test_sql_refsols/agg_simplification_1_sqlite.sql +++ b/tests/test_sql_refsols/agg_simplification_1_sqlite.sql @@ -78,19 +78,6 @@ WITH _t1 AS ( THEN 0.5 ELSE NULL END AS expr_77, - 
CASE - WHEN ABS( - ( - ROW_NUMBER() OVER (PARTITION BY LENGTH(CASE WHEN sbtickerexchange <> 'NYSE Arca' THEN sbtickerexchange ELSE NULL END) ORDER BY NULL DESC) - 1.0 - ) - ( - CAST(( - COUNT(NULL) OVER (PARTITION BY LENGTH(CASE WHEN sbtickerexchange <> 'NYSE Arca' THEN sbtickerexchange ELSE NULL END)) - 1.0 - ) AS REAL) / 2.0 - ) - ) < 1.0 - THEN NULL - ELSE NULL - END AS expr_78, CASE WHEN ABS( ( @@ -136,11 +123,6 @@ WITH _t1 AS ( THEN 0.5 ELSE NULL END AS expr_85, - CASE - WHEN CAST(0.30000000000000004 * COUNT(NULL) OVER (PARTITION BY LENGTH(CASE WHEN sbtickerexchange <> 'NYSE Arca' THEN sbtickerexchange ELSE NULL END)) AS INTEGER) < ROW_NUMBER() OVER (PARTITION BY LENGTH(CASE WHEN sbtickerexchange <> 'NYSE Arca' THEN sbtickerexchange ELSE NULL END) ORDER BY NULL DESC) - THEN NULL - ELSE NULL - END AS expr_86, CASE WHEN CAST(0.19999999999999996 * COUNT( LENGTH(CASE WHEN sbtickerexchange <> 'NYSE Arca' THEN sbtickerexchange ELSE NULL END) @@ -226,7 +208,7 @@ SELECT AVG(expr_75) AS me4, AVG(expr_76) AS me5, AVG(expr_77) AS me6, - AVG(expr_78) AS me7, + NULL AS me7, AVG(expr_79) AS me8, MAX(expr_80) AS qu1, MAX(expr_81) AS qu2, @@ -234,7 +216,7 @@ SELECT MAX(expr_83) AS qu4, MAX(expr_84) AS qu5, MAX(expr_85) AS qu6, - MAX(expr_86) AS qu7, + NULL AS qu7, MAX(expr_87) AS qu8 FROM _t1 GROUP BY From 41373f197c193af56de27f753e303cc52856b77f Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Fri, 15 Aug 2025 14:32:58 -0400 Subject: [PATCH 07/16] Adding documentation --- .../conversion/relational_simplification.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index f2bb93b19..82ce7ceac 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -1371,22 +1371,44 @@ def infer_null_predicates_from_condition( columns: dict[str, RelationalExpression], ) -> None: """ - TODO + Infers 
whether an output column can be marked as not-null based on the + given condition expression. If the condition implies that a column is + not null, the corresponding PredicateSet in output_predicates is updated + in-place. + + Args: + `output_predicates`: A dictionary mapping each output column + reference from the current node to the set of its inferred + predicates. + `condition`: The condition expression from the current node (e.g. a + filter or an inner/semi join) which, if false when a certain column + is null, means that column can be marked as not-null in the output. + `columns`: A dictionary mapping column names to their corresponding + relational expressions in the current node. """ from .filter_pushdown import NullReplacementShuttle self.shuttle.input_predicates = {} + # Iterate across all of the output columns that are not already marked + # as not-null and identify the ones that correspond to a column + # reference passed through from the input node. for expr, preds in output_predicates.items(): if preds.not_null: continue if isinstance(expr, ColumnReference) and expr.name in columns: expr = columns[expr.name] if isinstance(expr, ColumnReference): + # Transform the condition by creating a version where the + # input column is replaced with a NULL literal, and then run + # the simplifier on the new expression. shuttle: NullReplacementShuttle = NullReplacementShuttle( {expr.name} ) new_cond: RelationalExpression = condition.accept_shuttle(shuttle) new_cond = new_cond.accept_shuttle(self.shuttle) + # If the new condition simplifies to a False-y literal, then + # the column must be not-null since it means that if the + # column were, the row would be filtered out. 
if isinstance(new_cond, LiteralExpression) and not bool( new_cond.value ): From 19ac3269f4a6ab555a424092537d2267cbbfc8ab Mon Sep 17 00:00:00 2001 From: knassre-bodo <105652923+knassre-bodo@users.noreply.github.com> Date: Fri, 22 Aug 2025 18:14:12 -0400 Subject: [PATCH 08/16] Update pydough/conversion/relational_simplification.py Co-authored-by: john-sanchez31 --- pydough/conversion/relational_simplification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 82ce7ceac..bc75c739b 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -396,7 +396,7 @@ def simplify_function_literal_comparison( case (pydop.NEQ, pydop.QUARTER, NumericType()) if isinstance( lit_expr.value, int ): - # QUARTER(x) == 4 <=> NOT(ISIN(MONTH(x), [10, 11, 12])) +# QUARTER(x) != 4 <=> NOT(ISIN(MONTH(x), [10, 11, 12])) if lit_expr.value in (1, 2, 3, 4): result = CallExpression( pydop.NOT, From 3b98ae24725337e6bb3a983acd424a9aa6ee3003 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 22:14:24 +0000 Subject: [PATCH 09/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pydough/conversion/relational_simplification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index bc75c739b..c792b77a7 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -396,7 +396,7 @@ def simplify_function_literal_comparison( case (pydop.NEQ, pydop.QUARTER, NumericType()) if isinstance( lit_expr.value, int ): -# QUARTER(x) != 4 <=> NOT(ISIN(MONTH(x), [10, 11, 12])) + # QUARTER(x) != 4 <=> NOT(ISIN(MONTH(x), [10, 11, 12])) if 
lit_expr.value in (1, 2, 3, 4): result = CallExpression( pydop.NOT, From 7a6ed03cdb083c4ac22b93693935dc7a168919a4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 25 Aug 2025 11:50:23 -0400 Subject: [PATCH 10/16] Adding DOW/DAYNAME simplification [RUN CI] [RUN MYSQL] --- pydough/conversion/filter_pushdown.py | 10 ++- pydough/conversion/relational_converter.py | 6 +- .../conversion/relational_simplification.py | 74 ++++++++++++++----- .../all_pydough_functions_dialects.py | 15 +++- .../datetime_functions_ansi.sql | 51 ++++--------- .../datetime_functions_mysql.sql | 27 ++++--- .../datetime_functions_sqlite.sql | 51 ++++--------- 7 files changed, 126 insertions(+), 108 deletions(-) diff --git a/pydough/conversion/filter_pushdown.py b/pydough/conversion/filter_pushdown.py index c08be33d2..868fa9d59 100644 --- a/pydough/conversion/filter_pushdown.py +++ b/pydough/conversion/filter_pushdown.py @@ -6,6 +6,7 @@ import pydough.pydough_operators as pydop +from pydough.configs import PyDoughConfigs from pydough.relational import ( Aggregate, CallExpression, @@ -65,7 +66,7 @@ class FilterPushdownShuttle(RelationalShuttle): cannot be pushed further. """ - def __init__(self): + def __init__(self, configs: PyDoughConfigs): # The set of filters that are currently being pushed down. When # visit_xxx is called, it is presumed that the set of conditions in # self.filters are the conditions that can be pushed down as far as the @@ -75,7 +76,7 @@ def __init__(self): # simplification logic to aid in advanced filter predicate inference, # such as determining that a left join is redundant because if the RHS # column is null then the filter will always be false. 
- self.simplifier: SimplificationShuttle = SimplificationShuttle() + self.simplifier: SimplificationShuttle = SimplificationShuttle(configs) def reset(self): self.filters = set() @@ -299,12 +300,13 @@ def visit_empty_singleton(self, empty_singleton: EmptySingleton) -> RelationalNo return self.flush_remaining_filters(empty_singleton, self.filters, set()) -def push_filters(node: RelationalNode) -> RelationalNode: +def push_filters(node: RelationalNode, configs: PyDoughConfigs) -> RelationalNode: """ Transpose filter conditions down as far as possible. Args: `node`: The current node of the relational tree. + `configs`: The PyDough configuration settings. Returns: The transformed version of `node` and all of its descendants with @@ -312,5 +314,5 @@ def push_filters(node: RelationalNode) -> RelationalNode: the node or into one of its inputs, or possibly both if there are multiple filters. """ - pusher: FilterPushdownShuttle = FilterPushdownShuttle() + pusher: FilterPushdownShuttle = FilterPushdownShuttle(configs) return node.accept_shuttle(pusher) diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index df8a5284a..5bf0d082d 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -1451,7 +1451,7 @@ def optimize_relational_tree( root = ColumnPruner().prune_unused_columns(root) # Step 1: push filters down as far as possible - root = confirm_root(push_filters(root)) + root = confirm_root(push_filters(root, configs)) # Step 2: merge adjacent projections, unless it would result in excessive # duplicate subexpression computations. @@ -1492,8 +1492,8 @@ def optimize_relational_tree( # pullup and pushdown and so on. 
for _ in range(2): root = confirm_root(pullup_projections(root)) - simplify_expressions(root, additional_shuttles) - root = confirm_root(push_filters(root)) + simplify_expressions(root, configs, additional_shuttles) + root = confirm_root(push_filters(root, configs)) root = ColumnPruner().prune_unused_columns(root) # Step 9: re-run projection merging, without pushing into joins. This diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 82ce7ceac..f742fbe91 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -15,6 +15,7 @@ import pandas as pd import pydough.pydough_operators as pydop +from pydough.configs import PyDoughConfigs from pydough.relational import ( Aggregate, CallExpression, @@ -38,7 +39,7 @@ from pydough.relational.rel_util import ( add_input_name, ) -from pydough.types import ArrayType, NumericType +from pydough.types import ArrayType, NumericType, StringType @dataclass @@ -205,10 +206,11 @@ class SimplificationShuttle(RelationalExpressionShuttle): simplifying their inputs and placing their predicate sets on the stack. """ - def __init__(self): + def __init__(self, configs: PyDoughConfigs): self.stack: list[PredicateSet] = [] self._input_predicates: dict[RelationalExpression, PredicateSet] = {} self._no_group_aggregate: bool = False + self._configs: PyDoughConfigs = configs @property def input_predicates(self) -> dict[RelationalExpression, PredicateSet]: @@ -238,6 +240,20 @@ def no_group_aggregate(self, value: bool) -> None: """ self._no_group_aggregate = value + @property + def configs(self) -> PyDoughConfigs: + """ + Returns the PyDough configuration settings. + """ + return self._configs + + @configs.setter + def configs(self, value: PyDoughConfigs) -> None: + """ + Sets the PyDough configuration settings. 
+ """ + self._configs = value + def reset(self) -> None: self.stack = [] @@ -613,41 +629,52 @@ def simplify_datetime_literal_part( # where the literal is a native Python datetime/date, a pandas # Timestamp, or a string without any alphabetic characters (to avoid # parsing things like 'now' that depend on the current date). - ts: pd.Timestamp | None = None + timestamp_value: pd.Timestamp | None = None if isinstance(lit_expr.value, datetime.date): - ts = pd.Timestamp(lit_expr.value) + timestamp_value = pd.Timestamp(lit_expr.value) elif isinstance(lit_expr.value, str) and not any( c.isalpha() for c in lit_expr.value ): try: - ts = pd.Timestamp(lit_expr.value) + timestamp_value = pd.Timestamp(lit_expr.value) except Exception: return expr elif isinstance(lit_expr.value, pd.Timestamp): - ts = lit_expr.value + timestamp_value = lit_expr.value # Fall back to the original expression by default. - if ts is None: + if timestamp_value is None: return expr # Otherwise, extract the relevant part from the timestamp and return it # as a literal. 
match op: case pydop.YEAR: - return LiteralExpression(ts.year, NumericType()) + return LiteralExpression(timestamp_value.year, NumericType()) case pydop.QUARTER: - quarter: int = ((ts.month - 1) // 3) + 1 + quarter: int = ((timestamp_value.month - 1) // 3) + 1 return LiteralExpression(quarter, NumericType()) case pydop.MONTH: - return LiteralExpression(ts.month, NumericType()) + return LiteralExpression(timestamp_value.month, NumericType()) case pydop.DAY: - return LiteralExpression(ts.day, NumericType()) + return LiteralExpression(timestamp_value.day, NumericType()) case pydop.HOUR: - return LiteralExpression(ts.hour, NumericType()) + return LiteralExpression(timestamp_value.hour, NumericType()) case pydop.MINUTE: - return LiteralExpression(ts.minute, NumericType()) + return LiteralExpression(timestamp_value.minute, NumericType()) case pydop.SECOND: - return LiteralExpression(ts.second, NumericType()) + return LiteralExpression(timestamp_value.second, NumericType()) + case pydop.DAYNAME: + return LiteralExpression(timestamp_value.day_name(), StringType()) + case pydop.DAYOFWEEK: + # Derive the day of week as an integer, adjusting based on the + # configured start of the week. 
+ dow: int = timestamp_value.weekday() + dow -= self.configs.start_of_week.pandas_dow + dow %= 7 + if not self.configs.start_week_as_zero: + dow += 1 + return LiteralExpression(dow, NumericType()) case _: return expr @@ -1163,7 +1190,8 @@ def simplify_function_call( ) # YEAR(literal_datetime) -> can infer the year as a literal - # (same for QUARTER, MONTH, DAY, HOUR, MINUTE, SECOND) + # (same for QUARTER, MONTH, DAY, HOUR, MINUTE, SECOND, DAYOFWEEK, + # and DAYNAME) case ( pydop.YEAR | pydop.QUARTER @@ -1172,6 +1200,8 @@ def simplify_function_call( | pydop.HOUR | pydop.MINUTE | pydop.SECOND + | pydop.DAYOFWEEK + | pydop.DAYNAME ): if isinstance(expr.inputs[0], LiteralExpression): output_expr = self.simplify_datetime_literal_part( @@ -1265,9 +1295,13 @@ class SimplificationVisitor(RelationalVisitor): the current node are placed on the stack. """ - def __init__(self, additional_shuttles: list[RelationalExpressionShuttle]): + def __init__( + self, + configs: PyDoughConfigs, + additional_shuttles: list[RelationalExpressionShuttle], + ): self.stack: list[dict[RelationalExpression, PredicateSet]] = [] - self.shuttle: SimplificationShuttle = SimplificationShuttle() + self.shuttle: SimplificationShuttle = SimplificationShuttle(configs) self.additional_shuttles: list[RelationalExpressionShuttle] = ( additional_shuttles ) @@ -1500,6 +1534,7 @@ def visit_aggregate(self, node: Aggregate) -> None: def simplify_expressions( node: RelationalNode, + configs: PyDoughConfigs, additional_shuttles: list[RelationalExpressionShuttle], ) -> None: """ @@ -1508,10 +1543,13 @@ def simplify_expressions( Args: `node`: The relational node to perform simplification on. + `configs`: The PyDough configuration settings. `additional_shuttles`: A list of additional shuttles to apply to the expressions of the node and its descendants. These shuttles are applied after the simplification shuttle, and can be used to perform additional transformations on the expressions. 
""" - simplifier: SimplificationVisitor = SimplificationVisitor(additional_shuttles) + simplifier: SimplificationVisitor = SimplificationVisitor( + configs, additional_shuttles + ) node.accept(simplifier) diff --git a/tests/test_pydough_functions/all_pydough_functions_dialects.py b/tests/test_pydough_functions/all_pydough_functions_dialects.py index b5fd4dec7..73b3d9bab 100644 --- a/tests/test_pydough_functions/all_pydough_functions_dialects.py +++ b/tests/test_pydough_functions/all_pydough_functions_dialects.py @@ -129,11 +129,22 @@ def datetime_functions(): dd_dt_str=DATEDIFF("weeks", "1992-01-01", specific_dt), # DAYOFWEEK / DAYNAME: all types dow_col=DAYOFWEEK(order_date), - dow_str=DAYOFWEEK("1992-07-01"), + dow_str1=DAYOFWEEK("1992-07-01"), + dow_str2=DAYOFWEEK("1992-07-02"), + dow_str3=DAYOFWEEK("1992-07-03"), + dow_str4=DAYOFWEEK("1992-07-04"), + dow_str5=DAYOFWEEK("1992-07-05"), + dow_str6=DAYOFWEEK("1992-07-06"), + dow_str7=DAYOFWEEK("1992-07-07"), dow_dt=DAYOFWEEK(specific_dt), dow_pd=DAYOFWEEK(today), dayname_col=DAYNAME(order_date), - dayname_str=DAYNAME("1995-06-30"), + dayname_str1=DAYNAME("1995-06-26"), + dayname_str2=DAYNAME("1995-06-27"), + dayname_str3=DAYNAME("1995-06-28"), + dayname_str4=DAYNAME("1995-06-29"), + dayname_str5=DAYNAME("1995-06-30"), + dayname_str6=DAYNAME("1995-07-01"), dayname_dt=DAYNAME(datetime.datetime(1993, 8, 15)), ) diff --git a/tests/test_sql_refsols/datetime_functions_ansi.sql b/tests/test_sql_refsols/datetime_functions_ansi.sql index 906e8ef1f..bd7889853 100644 --- a/tests/test_sql_refsols/datetime_functions_ansi.sql +++ b/tests/test_sql_refsols/datetime_functions_ansi.sql @@ -22,9 +22,15 @@ SELECT DATEDIFF(CAST('1992-01-01 12:30:45' AS TIMESTAMP), CAST(o_orderdate AS DATETIME), YEAR) AS dd_col_dt, DATEDIFF(CAST('1992-01-01 12:30:45' AS TIMESTAMP), CAST('1992-01-01' AS TIMESTAMP), WEEK) AS dd_dt_str, DAY_OF_WEEK(o_orderdate) AS dow_col, - DAY_OF_WEEK('1992-07-01') AS dow_str, - DAY_OF_WEEK(CAST('1992-01-01 12:30:45' AS 
TIMESTAMP)) AS dow_dt, - DAY_OF_WEEK(CAST('1995-10-10 00:00:00' AS TIMESTAMP)) AS dow_pd, + 3 AS dow_str1, + 4 AS dow_str2, + 5 AS dow_str3, + 6 AS dow_str4, + 0 AS dow_str5, + 1 AS dow_str6, + 2 AS dow_str7, + 3 AS dow_dt, + 2 AS dow_pd, CASE WHEN DAY_OF_WEEK(o_orderdate) = 0 THEN 'Sunday' @@ -41,36 +47,11 @@ SELECT WHEN DAY_OF_WEEK(o_orderdate) = 6 THEN 'Saturday' END AS dayname_col, - CASE - WHEN DAY_OF_WEEK('1995-06-30') = 0 - THEN 'Sunday' - WHEN DAY_OF_WEEK('1995-06-30') = 1 - THEN 'Monday' - WHEN DAY_OF_WEEK('1995-06-30') = 2 - THEN 'Tuesday' - WHEN DAY_OF_WEEK('1995-06-30') = 3 - THEN 'Wednesday' - WHEN DAY_OF_WEEK('1995-06-30') = 4 - THEN 'Thursday' - WHEN DAY_OF_WEEK('1995-06-30') = 5 - THEN 'Friday' - WHEN DAY_OF_WEEK('1995-06-30') = 6 - THEN 'Saturday' - END AS dayname_str, - CASE - WHEN DAY_OF_WEEK(CAST('1993-08-15 00:00:00' AS TIMESTAMP)) = 0 - THEN 'Sunday' - WHEN DAY_OF_WEEK(CAST('1993-08-15 00:00:00' AS TIMESTAMP)) = 1 - THEN 'Monday' - WHEN DAY_OF_WEEK(CAST('1993-08-15 00:00:00' AS TIMESTAMP)) = 2 - THEN 'Tuesday' - WHEN DAY_OF_WEEK(CAST('1993-08-15 00:00:00' AS TIMESTAMP)) = 3 - THEN 'Wednesday' - WHEN DAY_OF_WEEK(CAST('1993-08-15 00:00:00' AS TIMESTAMP)) = 4 - THEN 'Thursday' - WHEN DAY_OF_WEEK(CAST('1993-08-15 00:00:00' AS TIMESTAMP)) = 5 - THEN 'Friday' - WHEN DAY_OF_WEEK(CAST('1993-08-15 00:00:00' AS TIMESTAMP)) = 6 - THEN 'Saturday' - END AS dayname_dt + 'Monday' AS dayname_str1, + 'Tuesday' AS dayname_str2, + 'Wednesday' AS dayname_str3, + 'Thursday' AS dayname_str4, + 'Friday' AS dayname_str5, + 'Saturday' AS dayname_str6, + 'Sunday' AS dayname_dt FROM tpch.orders diff --git a/tests/test_sql_refsols/datetime_functions_mysql.sql b/tests/test_sql_refsols/datetime_functions_mysql.sql index 5a4f7035d..c1e82108c 100644 --- a/tests/test_sql_refsols/datetime_functions_mysql.sql +++ b/tests/test_sql_refsols/datetime_functions_mysql.sql @@ -49,16 +49,21 @@ SELECT ( DAYOFWEEK(o_orderdate) + -1 ) % 7 AS dow_col, - ( - DAYOFWEEK(CAST('1992-07-01' AS 
DATE)) + -1 - ) % 7 AS dow_str, - ( - DAYOFWEEK(CAST('1992-01-01 12:30:45' AS DATETIME)) + -1 - ) % 7 AS dow_dt, - ( - DAYOFWEEK(CAST('1995-10-10 00:00:00' AS DATETIME)) + -1 - ) % 7 AS dow_pd, + 3 AS dow_str1, + 4 AS dow_str2, + 5 AS dow_str3, + 6 AS dow_str4, + 0 AS dow_str5, + 1 AS dow_str6, + 2 AS dow_str7, + 3 AS dow_dt, + 2 AS dow_pd, DAYNAME(o_orderdate) AS dayname_col, - DAYNAME('1995-06-30') AS dayname_str, - DAYNAME(CAST('1993-08-15 00:00:00' AS DATETIME)) AS dayname_dt + 'Monday' AS dayname_str1, + 'Tuesday' AS dayname_str2, + 'Wednesday' AS dayname_str3, + 'Thursday' AS dayname_str4, + 'Friday' AS dayname_str5, + 'Saturday' AS dayname_str6, + 'Sunday' AS dayname_dt FROM tpch.ORDERS diff --git a/tests/test_sql_refsols/datetime_functions_sqlite.sql b/tests/test_sql_refsols/datetime_functions_sqlite.sql index bf725c4f6..cbb6648fc 100644 --- a/tests/test_sql_refsols/datetime_functions_sqlite.sql +++ b/tests/test_sql_refsols/datetime_functions_sqlite.sql @@ -42,9 +42,15 @@ SELECT ) ) AS INTEGER) AS REAL) / 7 AS INTEGER) AS dd_dt_str, CAST(STRFTIME('%w', o_orderdate) AS INTEGER) AS dow_col, - CAST(STRFTIME('%w', DATETIME('1992-07-01')) AS INTEGER) AS dow_str, - CAST(STRFTIME('%w', '1992-01-01 12:30:45') AS INTEGER) AS dow_dt, - CAST(STRFTIME('%w', '1995-10-10 00:00:00') AS INTEGER) AS dow_pd, + 3 AS dow_str1, + 4 AS dow_str2, + 5 AS dow_str3, + 6 AS dow_str4, + 0 AS dow_str5, + 1 AS dow_str6, + 2 AS dow_str7, + 3 AS dow_dt, + 2 AS dow_pd, CASE WHEN CAST(STRFTIME('%w', o_orderdate) AS INTEGER) = 0 THEN 'Sunday' @@ -61,36 +67,11 @@ SELECT WHEN CAST(STRFTIME('%w', o_orderdate) AS INTEGER) = 6 THEN 'Saturday' END AS dayname_col, - CASE - WHEN CAST(STRFTIME('%w', DATETIME('1995-06-30')) AS INTEGER) = 0 - THEN 'Sunday' - WHEN CAST(STRFTIME('%w', DATETIME('1995-06-30')) AS INTEGER) = 1 - THEN 'Monday' - WHEN CAST(STRFTIME('%w', DATETIME('1995-06-30')) AS INTEGER) = 2 - THEN 'Tuesday' - WHEN CAST(STRFTIME('%w', DATETIME('1995-06-30')) AS INTEGER) = 3 - THEN 
'Wednesday' - WHEN CAST(STRFTIME('%w', DATETIME('1995-06-30')) AS INTEGER) = 4 - THEN 'Thursday' - WHEN CAST(STRFTIME('%w', DATETIME('1995-06-30')) AS INTEGER) = 5 - THEN 'Friday' - WHEN CAST(STRFTIME('%w', DATETIME('1995-06-30')) AS INTEGER) = 6 - THEN 'Saturday' - END AS dayname_str, - CASE - WHEN CAST(STRFTIME('%w', '1993-08-15 00:00:00') AS INTEGER) = 0 - THEN 'Sunday' - WHEN CAST(STRFTIME('%w', '1993-08-15 00:00:00') AS INTEGER) = 1 - THEN 'Monday' - WHEN CAST(STRFTIME('%w', '1993-08-15 00:00:00') AS INTEGER) = 2 - THEN 'Tuesday' - WHEN CAST(STRFTIME('%w', '1993-08-15 00:00:00') AS INTEGER) = 3 - THEN 'Wednesday' - WHEN CAST(STRFTIME('%w', '1993-08-15 00:00:00') AS INTEGER) = 4 - THEN 'Thursday' - WHEN CAST(STRFTIME('%w', '1993-08-15 00:00:00') AS INTEGER) = 5 - THEN 'Friday' - WHEN CAST(STRFTIME('%w', '1993-08-15 00:00:00') AS INTEGER) = 6 - THEN 'Saturday' - END AS dayname_dt + 'Monday' AS dayname_str1, + 'Tuesday' AS dayname_str2, + 'Wednesday' AS dayname_str3, + 'Thursday' AS dayname_str4, + 'Friday' AS dayname_str5, + 'Saturday' AS dayname_str6, + 'Sunday' AS dayname_dt FROM tpch.orders From 4bf584c62440c4a5d8a6aa34845a48d541391c00 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 25 Aug 2025 11:50:45 -0400 Subject: [PATCH 11/16] Adding DOW/DAYNAME simplification [RUN CI] [RUN MYSQL] From 66a19e911fb9268af24651102da5cacbf0c9c67a Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 26 Aug 2025 16:18:36 -0400 Subject: [PATCH 12/16] Add logic for simplifying literals that are masked --- pydough/conversion/masking_shuttles.py | 73 ++++++++++++++----- pydough/sqlglot/__init__.py | 2 + .../cryptbank_filter_count_01.txt | 2 +- .../cryptbank_filter_count_01_sqlite.sql | 2 +- 4 files changed, 58 insertions(+), 21 deletions(-) diff --git a/pydough/conversion/masking_shuttles.py b/pydough/conversion/masking_shuttles.py index 8cdfbb261..3dd2f403d 100644 --- a/pydough/conversion/masking_shuttles.py +++ b/pydough/conversion/masking_shuttles.py @@ -4,6 +4,9 @@ 
__all__ = ["MaskLiteralComparisonShuttle"] +from sqlglot import expressions as sqlglot_expressions +from sqlglot import parse_one + import pydough.pydough_operators as pydop from pydough.relational import ( CallExpression, @@ -11,6 +14,9 @@ RelationalExpression, RelationalExpressionShuttle, ) +from pydough.sqlglot import convert_sqlglot_to_relational + +from .relational_simplification import SimplificationShuttle class MaskLiteralComparisonShuttle(RelationalExpressionShuttle): @@ -18,15 +24,38 @@ class MaskLiteralComparisonShuttle(RelationalExpressionShuttle): TODO """ - def is_unprotect_call(self, expr: RelationalExpression) -> bool: + def __init__(self): + self.simplifier: SimplificationShuttle = SimplificationShuttle() + + def simplify_masked_literal( + self, value: RelationalExpression + ) -> RelationalExpression: """ TODO """ - return ( - isinstance(expr, CallExpression) - and isinstance(expr.op, pydop.MaskedExpressionFunctionOperator) - and expr.op.is_unprotect - ) + if ( + not isinstance(value, CallExpression) + or not isinstance(value.op, pydop.MaskedExpressionFunctionOperator) + or value.op.is_unprotect + or len(value.inputs) != 1 + or not isinstance(value.inputs[0], LiteralExpression) + ): + return value + try: + arg_sql_str: str = sqlglot_expressions.convert(value.inputs[0].value).sql() + total_sql_str: str = value.op.format_string.format(arg_sql_str) + glot_expr: sqlglot_expressions.Expression = parse_one(total_sql_str) + new_expr: RelationalExpression | None = convert_sqlglot_to_relational( + glot_expr + ) + if new_expr is not None: + return new_expr + self.simplifier.reset() + return value.accept_shuttle(self.simplifier) + except Exception: + return value + + return value def protect_literal_comparison( self, @@ -53,22 +82,23 @@ def protect_literal_comparison( call_arg.data_type, [literal_arg], ) + masked_literal = self.simplify_masked_literal(masked_literal) elif original_call.op == pydop.ISIN and isinstance( literal_arg.value, (list, tuple) ): - 
masked_literal = LiteralExpression( - [ - CallExpression( - pydop.MaskedExpressionFunctionOperator( - call_arg.op.masking_metadata, False - ), - call_arg.data_type, - [LiteralExpression(v, literal_arg.data_type)], - ) - for v in literal_arg.value - ], - original_call.data_type, - ) + new_elems: list[RelationalExpression] = [] + for val in literal_arg.value: + new_val: RelationalExpression = CallExpression( + pydop.MaskedExpressionFunctionOperator( + call_arg.op.masking_metadata, False + ), + call_arg.data_type, + [LiteralExpression(val, literal_arg.data_type)], + ) + new_val = self.simplify_masked_literal(new_val) + new_elems.append(new_val) + + masked_literal = LiteralExpression(new_elems, original_call.data_type) else: return original_call @@ -82,6 +112,8 @@ def visit_call_expression( self, call_expression: CallExpression ) -> RelationalExpression: if call_expression.op in (pydop.EQU, pydop.NEQ): + # UNMASK(expr) = literal --> expr = MASK(literal) + # UNMASK(expr) != literal --> expr != MASK(literal) if isinstance(call_expression.inputs[0], CallExpression) and isinstance( call_expression.inputs[1], LiteralExpression ): @@ -90,6 +122,8 @@ def visit_call_expression( call_expression.inputs[0], call_expression.inputs[1], ) + # literal = UNMASK(expr) --> MASK(literal) = expr + # literal != UNMASK(expr) --> MASK(literal) != expr if isinstance(call_expression.inputs[1], CallExpression) and isinstance( call_expression.inputs[0], LiteralExpression ): @@ -98,6 +132,7 @@ def visit_call_expression( call_expression.inputs[1], call_expression.inputs[0], ) + # UNMASK(expr) IN (x, y, z) --> expr IN (MASK(x), MASK(y), MASK(z)) if ( call_expression.op == pydop.ISIN and isinstance(call_expression.inputs[0], CallExpression) diff --git a/pydough/sqlglot/__init__.py b/pydough/sqlglot/__init__.py index 4bdf4f9d1..66542bd6b 100644 --- a/pydough/sqlglot/__init__.py +++ b/pydough/sqlglot/__init__.py @@ -3,6 +3,7 @@ "SQLGlotRelationalVisitor", "convert_dialect_to_sqlglot", 
"convert_relation_to_sql", + "convert_sqlglot_to_relational", "execute_df", "find_identifiers", "find_identifiers_in_list", @@ -19,3 +20,4 @@ from .sqlglot_identifier_finder import find_identifiers, find_identifiers_in_list from .sqlglot_relational_expression_visitor import SQLGlotRelationalExpressionVisitor from .sqlglot_relational_visitor import SQLGlotRelationalVisitor +from .sqlglot_to_relational import convert_sqlglot_to_relational diff --git a/tests/test_plan_refsols/cryptbank_filter_count_01.txt b/tests/test_plan_refsols/cryptbank_filter_count_01.txt index 4e430683e..ecd4e9844 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_01.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_01.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=c_lname == MASK::(UPPER(['lee':string])), columns={}) + FILTER(condition=c_lname == 'LEE':string, columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) diff --git a/tests/test_sql_refsols/cryptbank_filter_count_01_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_01_sqlite.sql index bffd9c7c0..298dbab4e 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_01_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_01_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_lname = UPPER('lee') + c_lname = 'LEE' From df8e74197a745709936a01022ea169d09cd24deb Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 26 Aug 2025 18:02:18 -0400 Subject: [PATCH 13/16] Initial implementation of mask simplification by converting sqlglot back to relational --- .../conversion/relational_simplification.py | 12 +++ pydough/sqlglot/sqlglot_to_relational.py | 91 +++++++++++++++++++ .../cryptbank_filter_count_02.txt | 2 +- .../cryptbank_filter_count_03.txt | 2 +- .../cryptbank_filter_count_04.txt | 2 +- .../cryptbank_filter_count_11.txt | 2 +- .../cryptbank_filter_count_02_sqlite.sql | 2 +- 
.../cryptbank_filter_count_03_sqlite.sql | 2 +- .../cryptbank_filter_count_04_sqlite.sql | 2 +- .../cryptbank_filter_count_11_sqlite.sql | 3 +- 10 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 pydough/sqlglot/sqlglot_to_relational.py diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 10cecd60e..5a23d1943 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -236,6 +236,18 @@ def visit_literal_expression( output_predicates.not_negative = True if literal_expression.value > 0: output_predicates.positive = True + if isinstance(literal_expression.value, (list, tuple)): + new_elems: list = [] + for val in literal_expression.value: + if isinstance(val, RelationalExpression): + new_elems.append(val.accept_shuttle(self)) + self.stack.pop() + else: + new_elems.append(val) + if new_elems != literal_expression.value: + literal_expression = LiteralExpression( + new_elems, literal_expression.data_type + ) self.stack.append(output_predicates) return literal_expression diff --git a/pydough/sqlglot/sqlglot_to_relational.py b/pydough/sqlglot/sqlglot_to_relational.py new file mode 100644 index 000000000..69a40a7c0 --- /dev/null +++ b/pydough/sqlglot/sqlglot_to_relational.py @@ -0,0 +1,91 @@ +""" +TODO +""" + +__all__ = ["convert_sqlglot_to_relational"] + + +from sqlglot import expressions as sqlglot_expressions + +import pydough.pydough_operators as pydop +from pydough.relational import ( + CallExpression, + LiteralExpression, + RelationalExpression, +) +from pydough.types import ( + BooleanType, + DatetimeType, + NumericType, + StringType, + UnknownType, +) + + +class GlotRelFail(Exception): + """ + TODO + """ + + +def glot_to_rel(glot_expr: sqlglot_expressions.Expression) -> RelationalExpression: + """ + TODO + """ + + # Convert all of the sub-expressions to a relational expression. 
This + # step is stored in a "thunk" so it only happens under certain conditions. + def sub_rels() -> list[RelationalExpression]: + return [glot_to_rel(e) for e in glot_expr.iter_expressions()] + + match glot_expr: + case sqlglot_expressions.TimeStrToTime(): + return CallExpression( + pydop.DATETIME, + DatetimeType(), + sub_rels(), + ) + case sqlglot_expressions.Literal(): + literal_expr = glot_expr.to_py() + if isinstance(literal_expr, str): + return LiteralExpression(literal_expr, StringType()) + elif isinstance(literal_expr, bool): + return LiteralExpression(literal_expr, BooleanType()) + elif isinstance(literal_expr, int) or isinstance(literal_expr, float): + return LiteralExpression(literal_expr, NumericType()) + else: + raise GlotRelFail() + return LiteralExpression( + glot_expr.this, + StringType(), + ) + case sqlglot_expressions.Null(): + return LiteralExpression(None, UnknownType()) + case sqlglot_expressions.Lower(): + return CallExpression(pydop.LOWER, StringType(), sub_rels()) + case sqlglot_expressions.Upper(): + return CallExpression(pydop.UPPER, StringType(), sub_rels()) + case sqlglot_expressions.Length(): + return CallExpression(pydop.LENGTH, NumericType(), sub_rels()) + case sqlglot_expressions.Abs(): + return CallExpression(pydop.ABS, StringType(), sub_rels()) + case _: + raise GlotRelFail() + + +def convert_sqlglot_to_relational( + glot_expr: sqlglot_expressions.Expression, +) -> RelationalExpression | None: + """ + Attempt to convert a sqlglot expression to a relational expression. + + Args: + `glot_expr`: The sqlglot expression to convert. + + Returns: + The converted relational expression, or None if the attempt failed. 
+ """ + try: + return glot_to_rel(glot_expr) + except GlotRelFail: + return None diff --git a/tests/test_plan_refsols/cryptbank_filter_count_02.txt b/tests/test_plan_refsols/cryptbank_filter_count_02.txt index b86695d95..14e5d7a40 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_02.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_02.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=c_lname != MASK::(UPPER(['lee':string])), columns={}) + FILTER(condition=c_lname != 'LEE':string, columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_03.txt b/tests/test_plan_refsols/cryptbank_filter_count_03.txt index 43d30e4a1..d39b01aaa 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_03.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_03.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=ISIN(c_lname, [Call(op=Function[MASK], inputs=[Literal(value='lee', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='smith', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='rodriguez', type=ArrayType(UnknownType()))], return_type=StringType())]:bool), columns={}) + FILTER(condition=ISIN(c_lname, [Literal(value='LEE', type=StringType()), Literal(value='SMITH', type=StringType()), Literal(value='RODRIGUEZ', type=StringType())]:bool), columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_04.txt b/tests/test_plan_refsols/cryptbank_filter_count_04.txt index c9f9fdf11..294c04d74 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_04.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_04.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], 
orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=NOT(ISIN(c_lname, [Call(op=Function[MASK], inputs=[Literal(value='lee', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='smith', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='rodriguez', type=ArrayType(UnknownType()))], return_type=StringType())]:bool)), columns={}) + FILTER(condition=NOT(ISIN(c_lname, [Literal(value='LEE', type=StringType()), Literal(value='SMITH', type=StringType()), Literal(value='RODRIGUEZ', type=StringType())]:bool)), columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_11.txt b/tests/test_plan_refsols/cryptbank_filter_count_11.txt index 07d0a9df7..206cc74c1 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_11.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_11.txt @@ -4,5 +4,5 @@ ROOT(columns=[('n', n)], orderings=[]) SCAN(table=CRBNK.TRANSACTIONS, columns={'t_sourceaccount': t_sourceaccount}) JOIN(condition=t0.a_custkey == UNMASK::((42 - ([t1.c_key]))), type=INNER, cardinality=SINGULAR_FILTER, columns={'a_key': t0.a_key}) SCAN(table=CRBNK.ACCOUNTS, columns={'a_custkey': a_custkey, 'a_key': a_key}) - FILTER(condition=c_fname == MASK::(UPPER(['alice':string])), columns={'c_key': c_key}) + FILTER(condition=c_fname == 'ALICE':string, columns={'c_key': c_key}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_fname': c_fname, 'c_key': c_key}) diff --git a/tests/test_sql_refsols/cryptbank_filter_count_02_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_02_sqlite.sql index f1f7b1c78..ec3a44be4 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_02_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_02_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_lname <> UPPER('lee') + c_lname <> 'LEE' diff --git 
a/tests/test_sql_refsols/cryptbank_filter_count_03_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_03_sqlite.sql index aa7550e49..a590ad01a 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_03_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_03_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_lname IN (UPPER('lee'), UPPER('smith'), UPPER('rodriguez')) + c_lname IN ('LEE', 'SMITH', 'RODRIGUEZ') diff --git a/tests/test_sql_refsols/cryptbank_filter_count_04_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_04_sqlite.sql index 6b329065c..5dc20fbac 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_04_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_04_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - NOT c_lname IN (UPPER('lee'), UPPER('smith'), UPPER('rodriguez')) + NOT c_lname IN ('LEE', 'SMITH', 'RODRIGUEZ') diff --git a/tests/test_sql_refsols/cryptbank_filter_count_11_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_11_sqlite.sql index 48071e73a..02cc1b528 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_11_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_11_sqlite.sql @@ -14,5 +14,4 @@ JOIN crbnk.accounts AS accounts JOIN crbnk.customers AS customers ON accounts.a_custkey = ( 42 - customers.c_key - ) - AND customers.c_fname = UPPER('alice') + ) AND customers.c_fname = 'ALICE' From 71236f10e3fc04f823fa99f7a315e10a03d6212b Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 26 Aug 2025 19:31:06 -0400 Subject: [PATCH 14/16] Datetime manipulation WIP --- pydough/sqlglot/sqlglot_to_relational.py | 7 +++++++ tests/test_plan_refsols/cryptbank_filter_count_08.txt | 2 +- tests/test_plan_refsols/cryptbank_filter_count_24.txt | 2 +- tests/test_plan_refsols/cryptbank_filter_count_25.txt | 2 +- .../test_sql_refsols/cryptbank_filter_count_08_sqlite.sql | 2 +- .../test_sql_refsols/cryptbank_filter_count_24_sqlite.sql | 2 
+- .../test_sql_refsols/cryptbank_filter_count_25_sqlite.sql | 2 +- 7 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pydough/sqlglot/sqlglot_to_relational.py b/pydough/sqlglot/sqlglot_to_relational.py index 69a40a7c0..42c94a3c3 100644 --- a/pydough/sqlglot/sqlglot_to_relational.py +++ b/pydough/sqlglot/sqlglot_to_relational.py @@ -38,6 +38,7 @@ def glot_to_rel(glot_expr: sqlglot_expressions.Expression) -> RelationalExpressi def sub_rels() -> list[RelationalExpression]: return [glot_to_rel(e) for e in glot_expr.iter_expressions()] + flushed_args: list[RelationalExpression] match glot_expr: case sqlglot_expressions.TimeStrToTime(): return CallExpression( @@ -69,6 +70,12 @@ def sub_rels() -> list[RelationalExpression]: return CallExpression(pydop.LENGTH, NumericType(), sub_rels()) case sqlglot_expressions.Abs(): return CallExpression(pydop.ABS, StringType(), sub_rels()) + case sqlglot_expressions.Datetime(): + return CallExpression(pydop.DATETIME, DatetimeType(), sub_rels()) + case sqlglot_expressions.Date(): + flushed_args = sub_rels() + flushed_args.append(LiteralExpression("start of day", StringType())) + return CallExpression(pydop.DATETIME, DatetimeType(), flushed_args) case _: raise GlotRelFail() diff --git a/tests/test_plan_refsols/cryptbank_filter_count_08.txt b/tests/test_plan_refsols/cryptbank_filter_count_08.txt index b873a9a59..845253916 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_08.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_08.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=c_birthday == MASK::(DATE(['1985-04-12':string], '-472 days')), columns={}) + FILTER(condition=c_birthday == DATETIME('1985-04-12':string, '-472 days':string, 'start of day':string), columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_24.txt 
b/tests/test_plan_refsols/cryptbank_filter_count_24.txt index de787f3b0..ff82a8cc8 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_24.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_24.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=c_birthday == MASK::(DATE(['1991-11-15':string], '-472 days')), columns={}) + FILTER(condition=c_birthday == DATETIME('1991-11-15':string, '-472 days':string, 'start of day':string), columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_25.txt b/tests/test_plan_refsols/cryptbank_filter_count_25.txt index 0c89de89c..1c00e189d 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_25.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_25.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=c_birthday != MASK::(DATE(['1991-11-15':string], '-472 days')), columns={}) + FILTER(condition=c_birthday != DATETIME('1991-11-15':string, '-472 days':string, 'start of day':string), columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) diff --git a/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql index fc4234022..e45688058 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_birthday = DATE('1985-04-12', '-472 days') + c_birthday = DATE(DATETIME('1985-04-12', '-472 day'), 'start of day') diff --git a/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql index 16e183baa..3a3e438d7 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql +++ 
b/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_birthday = DATE('1991-11-15', '-472 days') + c_birthday = DATE(DATETIME('1991-11-15', '-472 day'), 'start of day') diff --git a/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql index 1ce9daa75..ad9993371 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_birthday <> DATE('1991-11-15', '-472 days') + c_birthday <> DATE(DATETIME('1991-11-15', '-472 day'), 'start of day') From e6251e2c3ff10588e79f170bbbc943598e4956fc Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 26 Aug 2025 20:10:25 -0400 Subject: [PATCH 15/16] Added DATETIME chain simplification --- .../conversion/relational_simplification.py | 173 ++++++++++++++++-- tests/test_pipeline_tpch_custom.py | 12 +- tests/test_plan_refsols/datetime_relative.txt | 2 +- .../quarter_function_test.txt | 2 +- tests/test_plan_refsols/smoke_b.txt | 2 +- .../datetime_functions_ansi.sql | 4 +- .../datetime_functions_mysql.sql | 12 +- .../datetime_functions_sqlite.sql | 4 +- .../datetime_sampler_ansi.sql | 18 +- .../datetime_sampler_mysql.sql | 39 +--- .../datetime_sampler_sqlite.sql | 24 +-- tests/test_sql_refsols/smoke_b_ansi.sql | 17 +- tests/test_sql_refsols/smoke_b_mysql.sql | 13 +- tests/test_sql_refsols/smoke_b_sqlite.sql | 13 +- 14 files changed, 210 insertions(+), 125 deletions(-) diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 02363eded..d88a84f9b 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -10,6 +10,7 @@ import datetime +import re from dataclasses import dataclass import pandas as pd @@ -39,6 +40,11 @@ from 
pydough.relational.rel_util import ( add_input_name, ) +from pydough.sqlglot.transform_bindings.sqlglot_transform_utils import ( + DateTimeUnit, + offset_pattern, + trunc_pattern, +) from pydough.types import ArrayType, NumericType, StringType @@ -602,6 +608,33 @@ def simplify_function_literal_comparison( pass return result + def get_timestamp_literal(self, expr: RelationalExpression) -> pd.Timestamp | None: + """ + Attempts to extract a pandas Timestamp from a literal expression. Does + not try to parse strings with alphabetic characters to avoid parsing + things like 'now' that depend on the current date. + + Args: + `expr`: The expression to extract the timestamp from. + + Returns: + A pandas Timestamp if the expression is a literal that can be + converted to a timestamp, otherwise None. + """ + if not isinstance(expr, LiteralExpression): + return None + if isinstance(expr.value, pd.Timestamp): + return expr.value + elif isinstance(expr.value, datetime.date): + return pd.Timestamp(expr.value) + elif isinstance(expr.value, str) and not any(c.isalpha() for c in expr.value): + try: + return pd.Timestamp(expr.value) + except Exception: + return None + else: + return None + def simplify_datetime_literal_part( self, expr: RelationalExpression, @@ -629,18 +662,7 @@ def simplify_datetime_literal_part( # where the literal is a native Python datetime/date, a pandas # Timestamp, or a string without any alphabetic characters (to avoid # parsing things like 'now' that depend on the current date). 
- timestamp_value: pd.Timestamp | None = None - if isinstance(lit_expr.value, datetime.date): - timestamp_value = pd.Timestamp(lit_expr.value) - elif isinstance(lit_expr.value, str) and not any( - c.isalpha() for c in lit_expr.value - ): - try: - timestamp_value = pd.Timestamp(lit_expr.value) - except Exception: - return expr - elif isinstance(lit_expr.value, pd.Timestamp): - timestamp_value = lit_expr.value + timestamp_value: pd.Timestamp | None = self.get_timestamp_literal(lit_expr) # Fall back to the original expression by default. if timestamp_value is None: @@ -678,6 +700,128 @@ def simplify_datetime_literal_part( case _: return expr + def compress_datetime_literal_chain( + self, expr: CallExpression + ) -> RelationalExpression: + """ + Attempts to compress a DATETIME(arg0, arg1, arg2, ...) function call + where arg0 is a timestamp literal and all other arguments are string + literals representing datetime modifiers (e.g. 'start of month', + '+3 days', etc). If successful, returns a LiteralExpression with the + resulting timestamp or date. If not successful, returns the original + expression. + + Args: + `expr`: The CallExpression representing the DATETIME function call. + Assumes all the arguments are literals. + + Returns: + A LiteralExpression with the resulting timestamp or date if + successful, otherwise the original expression. + """ + assert expr.op == pydop.DATETIME and len(expr.inputs) > 0 + + # Extract a pandas Timestamp from the first argument if possible. If + # not possible, return the original expression. + timestamp_value: pd.Timestamp | None = self.get_timestamp_literal( + expr.inputs[0] + ) + if timestamp_value is None: + return expr + + # Extract the raw string values from the remaining arguments. If any + # of them are not string literals, return the original expression. 
+ raw_args: list[str] = [] + for arg in expr.inputs[1:]: + if isinstance(arg, LiteralExpression) and isinstance(arg.value, str): + raw_args.append(arg.value) + else: + return expr + + # Keep track of whether the final result should be returned as a date + # (i.e. without a time component) or as a timestamp. + return_as_date: bool = timestamp_value == timestamp_value.normalize() + + # Process each argument in order, applying truncations and offsets to + # the timestamp value as needed. If any argument is not recognized, + # return the original expression. + for raw_arg in raw_args: + amt: int + unit: DateTimeUnit | None + trunc_match: re.Match | None = trunc_pattern.fullmatch(raw_arg) + offset_match: re.Match | None = offset_pattern.fullmatch(raw_arg) + if trunc_match is not None: + # If the string is in the form `start of `, apply + # truncation. + unit = DateTimeUnit.from_string(str(trunc_match.group(1))) + if unit is None: + raise ValueError( + f"Unsupported DATETIME modifier string: {raw_arg!r}" + ) + match unit: + case DateTimeUnit.YEAR: + timestamp_value = timestamp_value.to_period("Y").to_timestamp() + return_as_date = True + case DateTimeUnit.QUARTER: + timestamp_value = timestamp_value.to_period("Q").to_timestamp() + return_as_date = True + case DateTimeUnit.MONTH: + timestamp_value = timestamp_value.to_period("M").to_timestamp() + return_as_date = True + case DateTimeUnit.DAY: + timestamp_value = timestamp_value.floor("d") + return_as_date = True + case DateTimeUnit.HOUR: + timestamp_value = timestamp_value.floor("h") + case DateTimeUnit.MINUTE: + timestamp_value = timestamp_value.floor("min") + case _: + # Doesn't support truncating to WEEK or SECOND in this + # simplification. + return expr + elif offset_match is not None: + # If the string is in the form `± `, apply an + # offset. 
+ amt = int(offset_match.group(2)) + if str(offset_match.group(1)) == "-": + amt *= -1 + unit = DateTimeUnit.from_string(str(offset_match.group(3))) + if unit is None: + raise ValueError( + f"Unsupported DATETIME modifier string: {raw_arg!r}" + ) + match unit: + case DateTimeUnit.YEAR: + timestamp_value = timestamp_value + pd.DateOffset(years=amt) + case DateTimeUnit.QUARTER: + timestamp_value = timestamp_value + pd.DateOffset( + months=amt * 3 + ) + case DateTimeUnit.MONTH: + timestamp_value = timestamp_value + pd.DateOffset(months=amt) + case DateTimeUnit.WEEK: + timestamp_value = timestamp_value + pd.DateOffset(days=amt * 7) + case DateTimeUnit.DAY: + timestamp_value = timestamp_value + pd.DateOffset(days=amt) + case DateTimeUnit.HOUR: + timestamp_value = timestamp_value + pd.Timedelta(hours=amt) + return_as_date = False + case DateTimeUnit.MINUTE: + timestamp_value = timestamp_value + pd.Timedelta(minutes=amt) + return_as_date = False + case DateTimeUnit.SECOND: + timestamp_value = timestamp_value + pd.Timedelta(seconds=amt) + return_as_date = False + else: + return expr + + # Return the final timestamp as a literal expression, converting to a + # date if needed. 
+ if return_as_date: + return LiteralExpression(timestamp_value.date(), expr.data_type) + else: + return LiteralExpression(timestamp_value, expr.data_type) + def simplify_function_call( self, expr: CallExpression, @@ -1188,6 +1332,11 @@ def simplify_function_call( expr.data_type, expr.inputs[0].inputs + expr.inputs[1:], ) + assert isinstance(output_expr, CallExpression) + if all( + isinstance(arg, LiteralExpression) for arg in output_expr.inputs + ): + output_expr = self.compress_datetime_literal_chain(output_expr) # YEAR(literal_datetime) -> can infer the year as a literal # (same for QUARTER, MONTH, DAY, HOUR, MINUTE, SECOND, DAYOFWEEK, diff --git a/tests/test_pipeline_tpch_custom.py b/tests/test_pipeline_tpch_custom.py index a2d33efb7..6b3df76fe 100644 --- a/tests/test_pipeline_tpch_custom.py +++ b/tests/test_pipeline_tpch_custom.py @@ -2319,14 +2319,14 @@ "chain3": ["2023-10-01"], "plus_1q": ["2023-04-15 12:30:45"], "plus_2q": ["2023-07-15 12:30:45"], - "plus_3q": ["2023-10-15 00:00:00"], + "plus_3q": ["2023-10-15"], "minus_1q": ["2022-10-15 12:30:45"], "minus_2q": ["2022-07-15 12:30:45"], - "minus_3q": ["2022-04-15 00:00:00"], - "syntax1": ["2023-08-15 00:00:00"], - "syntax2": ["2024-02-15 00:00:00"], - "syntax3": ["2024-08-15 00:00:00"], - "syntax4": ["2022-08-15 00:00:00"], + "minus_3q": ["2022-04-15"], + "syntax1": ["2023-08-15"], + "syntax2": ["2024-02-15"], + "syntax3": ["2024-08-15"], + "syntax4": ["2022-08-15"], "q_diff1": [1], "q_diff2": [2], "q_diff3": [3], diff --git a/tests/test_plan_refsols/datetime_relative.txt b/tests/test_plan_refsols/datetime_relative.txt index be99d1ef3..b4b24891f 100644 --- a/tests/test_plan_refsols/datetime_relative.txt +++ b/tests/test_plan_refsols/datetime_relative.txt @@ -1,3 +1,3 @@ -ROOT(columns=[('d1', DATETIME(o_orderdate, 'Start of Year':string)), ('d2', DATETIME(o_orderdate, 'START OF MONTHS':string)), ('d3', DATETIME(o_orderdate, '-11 years':string, '+9 months':string, ' - 7 DaYs ':string, '+5 h':string, '-3 
minutes':string, '+1 second':string)), ('d4', DATETIME(Timestamp('2025-07-04 12:58:45'):datetime, 'start of hour':string)), ('d5', DATETIME(Timestamp('2025-07-04 12:58:45'):datetime, 'start of minute':string)), ('d6', DATETIME(Timestamp('2025-07-14 12:58:45'):datetime, '+ 1000000 seconds':string))], orderings=[(o_orderdate):asc_first]) +ROOT(columns=[('d1', DATETIME(o_orderdate, 'Start of Year':string)), ('d2', DATETIME(o_orderdate, 'START OF MONTHS':string)), ('d3', DATETIME(o_orderdate, '-11 years':string, '+9 months':string, ' - 7 DaYs ':string, '+5 h':string, '-3 minutes':string, '+1 second':string)), ('d4', Timestamp('2025-07-04 12:00:00'):datetime), ('d5', Timestamp('2025-07-04 12:58:00'):datetime), ('d6', Timestamp('2025-07-26 02:45:25'):datetime)], orderings=[(o_orderdate):asc_first]) LIMIT(limit=10:numeric, columns={'o_orderdate': o_orderdate}, orderings=[(o_custkey):asc_first, (o_orderdate):asc_first]) SCAN(table=tpch.ORDERS, columns={'o_custkey': o_custkey, 'o_orderdate': o_orderdate}) diff --git a/tests/test_plan_refsols/quarter_function_test.txt b/tests/test_plan_refsols/quarter_function_test.txt index fb423657e..36b26e4a3 100644 --- a/tests/test_plan_refsols/quarter_function_test.txt +++ b/tests/test_plan_refsols/quarter_function_test.txt @@ -1,2 +1,2 @@ -ROOT(columns=[('_expr0', 1:numeric), ('_expr1', 1:numeric), ('_expr2', 1:numeric), ('_expr3', 2:numeric), ('_expr4', 2:numeric), ('_expr5', 2:numeric), ('_expr6', 3:numeric), ('_expr7', 3:numeric), ('_expr8', 3:numeric), ('_expr9', 4:numeric), ('_expr10', 4:numeric), ('_expr11', 4:numeric), ('_expr12', 1:numeric), ('q1_jan', DATETIME('2023-01-15 12:30:45':string, 'start of quarter':string)), ('q1_feb', DATETIME('2023-02-28 12:30:45':string, 'start of quarter':string)), ('q1_mar', DATETIME('2023-03-31':string, 'start of quarter':string)), ('q2_apr', DATETIME('2023-04-01':string, 'start of quarter':string)), ('q2_may', DATETIME('2023-05-15 12:30:45':string, 'start of quarter':string)), ('q2_jun', 
DATETIME('2023-06-30 12:30:45':string, 'start of quarter':string)), ('q3_jul', DATETIME('2023-07-01 12:30:45':string, 'start of quarter':string)), ('q3_aug', DATETIME('2023-08-15':string, 'start of quarter':string)), ('q3_sep', DATETIME('2023-09-30':string, 'start of quarter':string)), ('q4_oct', DATETIME('2023-10-01':string, 'start of quarter':string)), ('q4_nov', DATETIME('2023-11-15':string, 'start of quarter':string)), ('q4_dec', DATETIME('2023-12-31':string, 'start of quarter':string)), ('ts_q1', DATETIME(Timestamp('2024-02-29 12:30:45'):datetime, 'start of quarter':string)), ('alias1', DATETIME('2023-05-15':string, 'START OF QUARTER':string)), ('alias2', DATETIME('2023-08-15':string, 'Start Of Quarter':string)), ('alias3', DATETIME('2023-11-15':string, '\n Start Of\tQuarter\n\n':string)), ('alias4', DATETIME('2023-02-15':string, '\tSTART\tOF\tquarter\t':string)), ('chain1', DATETIME('2023-05-15':string, 'start of quarter':string, '+1 day':string, '+2 hours':string)), ('chain2', DATETIME('2023-08-15':string, 'start of quarter':string, 'start of day':string)), ('chain3', DATETIME('2023-11-15':string, '-1 month':string, 'start of quarter':string)), ('plus_1q', DATETIME('2023-01-15 12:30:45':string, '+1 quarter':string)), ('plus_2q', DATETIME('2023-01-15 12:30:45':string, '+2 quarters':string)), ('plus_3q', DATETIME('2023-01-15':string, '+3 quarters':string)), ('minus_1q', DATETIME('2023-01-15 12:30:45':string, '-1 quarter':string)), ('minus_2q', DATETIME('2023-01-15 12:30:45':string, '-2 quarters':string)), ('minus_3q', DATETIME('2023-01-15':string, '-3 quarters':string)), ('syntax1', DATETIME('2023-05-15':string, ' +1 QUARTER ':string)), ('syntax2', DATETIME('2023-08-15':string, '+2 Q':string)), ('syntax3', DATETIME('2023-11-15':string, ' \n +\t3 \nQuarters \n\r ':string)), ('syntax4', DATETIME('2023-02-15':string, '\t-\t2\tq\t':string)), ('q_diff1', DATEDIFF('quarter':string, '2023-01-15':string, '2023-04-15':string)), ('q_diff2', DATEDIFF('quarter':string, 
'2023-01-15':string, '2023-07-15':string)), ('q_diff3', DATEDIFF('quarter':string, '2023-01-15':string, '2023-10-15':string)), ('q_diff4', DATEDIFF('quarter':string, '2023-01-15':string, '2023-12-31':string)), ('q_diff5', DATEDIFF('quarter':string, '2023-01-15':string, '2024-01-15':string)), ('q_diff6', DATEDIFF('quarter':string, '2023-01-15':string, '2024-04-15':string)), ('q_diff7', DATEDIFF('quarter':string, '2022-10-15':string, '2024-04-15':string)), ('q_diff8', DATEDIFF('quarter':string, '2020-01-01':string, '2025-01-01':string)), ('q_diff9', DATEDIFF('quarter':string, '2023-04-15':string, '2023-01-15':string)), ('q_diff10', DATEDIFF('quarter':string, '2024-01-15':string, '2023-01-15':string)), ('q_diff11', DATEDIFF('quarter':string, '2023-03-31':string, '2023-04-01':string)), ('q_diff12', DATEDIFF('quarter':string, '2023-12-31':string, '2024-01-01':string))], orderings=[]) +ROOT(columns=[('_expr0', 1:numeric), ('_expr1', 1:numeric), ('_expr2', 1:numeric), ('_expr3', 2:numeric), ('_expr4', 2:numeric), ('_expr5', 2:numeric), ('_expr6', 3:numeric), ('_expr7', 3:numeric), ('_expr8', 3:numeric), ('_expr9', 4:numeric), ('_expr10', 4:numeric), ('_expr11', 4:numeric), ('_expr12', 1:numeric), ('q1_jan', datetime.date(2023, 1, 1):datetime), ('q1_feb', datetime.date(2023, 1, 1):datetime), ('q1_mar', datetime.date(2023, 1, 1):datetime), ('q2_apr', datetime.date(2023, 4, 1):datetime), ('q2_may', datetime.date(2023, 4, 1):datetime), ('q2_jun', datetime.date(2023, 4, 1):datetime), ('q3_jul', datetime.date(2023, 7, 1):datetime), ('q3_aug', datetime.date(2023, 7, 1):datetime), ('q3_sep', datetime.date(2023, 7, 1):datetime), ('q4_oct', datetime.date(2023, 10, 1):datetime), ('q4_nov', datetime.date(2023, 10, 1):datetime), ('q4_dec', datetime.date(2023, 10, 1):datetime), ('ts_q1', datetime.date(2024, 1, 1):datetime), ('alias1', datetime.date(2023, 4, 1):datetime), ('alias2', datetime.date(2023, 7, 1):datetime), ('alias3', datetime.date(2023, 10, 1):datetime), ('alias4', 
datetime.date(2023, 1, 1):datetime), ('chain1', Timestamp('2023-04-02 02:00:00'):datetime), ('chain2', datetime.date(2023, 7, 1):datetime), ('chain3', datetime.date(2023, 10, 1):datetime), ('plus_1q', Timestamp('2023-04-15 12:30:45'):datetime), ('plus_2q', Timestamp('2023-07-15 12:30:45'):datetime), ('plus_3q', datetime.date(2023, 10, 15):datetime), ('minus_1q', Timestamp('2022-10-15 12:30:45'):datetime), ('minus_2q', Timestamp('2022-07-15 12:30:45'):datetime), ('minus_3q', datetime.date(2022, 4, 15):datetime), ('syntax1', datetime.date(2023, 8, 15):datetime), ('syntax2', datetime.date(2024, 2, 15):datetime), ('syntax3', datetime.date(2024, 8, 15):datetime), ('syntax4', datetime.date(2022, 8, 15):datetime), ('q_diff1', DATEDIFF('quarter':string, '2023-01-15':string, '2023-04-15':string)), ('q_diff2', DATEDIFF('quarter':string, '2023-01-15':string, '2023-07-15':string)), ('q_diff3', DATEDIFF('quarter':string, '2023-01-15':string, '2023-10-15':string)), ('q_diff4', DATEDIFF('quarter':string, '2023-01-15':string, '2023-12-31':string)), ('q_diff5', DATEDIFF('quarter':string, '2023-01-15':string, '2024-01-15':string)), ('q_diff6', DATEDIFF('quarter':string, '2023-01-15':string, '2024-04-15':string)), ('q_diff7', DATEDIFF('quarter':string, '2022-10-15':string, '2024-04-15':string)), ('q_diff8', DATEDIFF('quarter':string, '2020-01-01':string, '2025-01-01':string)), ('q_diff9', DATEDIFF('quarter':string, '2023-04-15':string, '2023-01-15':string)), ('q_diff10', DATEDIFF('quarter':string, '2024-01-15':string, '2023-01-15':string)), ('q_diff11', DATEDIFF('quarter':string, '2023-03-31':string, '2023-04-01':string)), ('q_diff12', DATEDIFF('quarter':string, '2023-12-31':string, '2024-01-01':string))], orderings=[]) EMPTYSINGLETON() diff --git a/tests/test_plan_refsols/smoke_b.txt b/tests/test_plan_refsols/smoke_b.txt index 10bfda8a7..5d56b5cc0 100644 --- a/tests/test_plan_refsols/smoke_b.txt +++ b/tests/test_plan_refsols/smoke_b.txt @@ -1,3 +1,3 @@ -ROOT(columns=[('key', 
o_orderkey), ('a', JOIN_STRINGS('_':string, YEAR(o_orderdate), QUARTER(o_orderdate), MONTH(o_orderdate), DAY(o_orderdate))), ('b', JOIN_STRINGS(':':string, DAYNAME(o_orderdate), DAYOFWEEK(o_orderdate))), ('c', DATETIME(o_orderdate, 'start of year':string, '+6 months':string, '-13 days':string)), ('d', DATETIME(o_orderdate, 'start of quarter':string, '+1 year':string, '+25 hours':string)), ('e', DATETIME('2025-01-01 12:35:13':string, 'start of minute':string)), ('f', DATETIME('2025-01-01 12:35:13':string, 'start of hour':string, '+2 quarters':string, '+3 weeks':string)), ('g', DATETIME('2025-01-01 12:35:13':string, 'start of day':string)), ('h', JOIN_STRINGS(';':string, 12:numeric, MINUTE(DATETIME('2025-01-01 12:35:13':string, '+45 minutes':string)), SECOND(DATETIME('2025-01-01 12:35:13':string, '-7 seconds':string)))), ('i', DATEDIFF('years':string, '1993-05-25 12:45:36':string, o_orderdate)), ('j', DATEDIFF('quarters':string, '1993-05-25 12:45:36':string, o_orderdate)), ('k', DATEDIFF('months':string, '1993-05-25 12:45:36':string, o_orderdate)), ('l', DATEDIFF('weeks':string, '1993-05-25 12:45:36':string, o_orderdate)), ('m', DATEDIFF('days':string, '1993-05-25 12:45:36':string, o_orderdate)), ('n', DATEDIFF('hours':string, '1993-05-25 12:45:36':string, o_orderdate)), ('o', DATEDIFF('minutes':string, '1993-05-25 12:45:36':string, o_orderdate)), ('p', DATEDIFF('seconds':string, '1993-05-25 12:45:36':string, o_orderdate)), ('q', DATETIME(o_orderdate, 'start of week':string))], orderings=[(o_orderkey):asc_first], limit=5:numeric) +ROOT(columns=[('key', o_orderkey), ('a', JOIN_STRINGS('_':string, YEAR(o_orderdate), QUARTER(o_orderdate), MONTH(o_orderdate), DAY(o_orderdate))), ('b', JOIN_STRINGS(':':string, DAYNAME(o_orderdate), DAYOFWEEK(o_orderdate))), ('c', DATETIME(o_orderdate, 'start of year':string, '+6 months':string, '-13 days':string)), ('d', DATETIME(o_orderdate, 'start of quarter':string, '+1 year':string, '+25 hours':string)), ('e', Timestamp('2025-01-01 
12:35:00'):datetime), ('f', Timestamp('2025-07-22 12:00:00'):datetime), ('g', datetime.date(2025, 1, 1):datetime), ('h', JOIN_STRINGS(';':string, 12:numeric, 20:numeric, 6:numeric)), ('i', DATEDIFF('years':string, '1993-05-25 12:45:36':string, o_orderdate)), ('j', DATEDIFF('quarters':string, '1993-05-25 12:45:36':string, o_orderdate)), ('k', DATEDIFF('months':string, '1993-05-25 12:45:36':string, o_orderdate)), ('l', DATEDIFF('weeks':string, '1993-05-25 12:45:36':string, o_orderdate)), ('m', DATEDIFF('days':string, '1993-05-25 12:45:36':string, o_orderdate)), ('n', DATEDIFF('hours':string, '1993-05-25 12:45:36':string, o_orderdate)), ('o', DATEDIFF('minutes':string, '1993-05-25 12:45:36':string, o_orderdate)), ('p', DATEDIFF('seconds':string, '1993-05-25 12:45:36':string, o_orderdate)), ('q', DATETIME(o_orderdate, 'start of week':string))], orderings=[(o_orderkey):asc_first], limit=5:numeric) FILTER(condition=CONTAINS(o_comment, 'fo':string) & ENDSWITH(o_clerk, '5':string) & STARTSWITH(o_orderpriority, '3':string), columns={'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey}) SCAN(table=tpch.ORDERS, columns={'o_clerk': o_clerk, 'o_comment': o_comment, 'o_orderdate': o_orderdate, 'o_orderkey': o_orderkey, 'o_orderpriority': o_orderpriority}) diff --git a/tests/test_sql_refsols/datetime_functions_ansi.sql b/tests/test_sql_refsols/datetime_functions_ansi.sql index bd7889853..6143373b3 100644 --- a/tests/test_sql_refsols/datetime_functions_ansi.sql +++ b/tests/test_sql_refsols/datetime_functions_ansi.sql @@ -3,8 +3,8 @@ SELECT DATE_TRUNC('DAY', CURRENT_TIMESTAMP()) AS ts_now_2, DATE_TRUNC('MONTH', CURRENT_TIMESTAMP()) AS ts_now_3, DATE_ADD(CURRENT_TIMESTAMP(), 1, 'HOUR') AS ts_now_4, - CAST('2025-01-01 00:00:00' AS TIMESTAMP) AS ts_now_5, - CAST('1995-10-08 00:00:00' AS TIMESTAMP) AS ts_now_6, + CAST('2025-01-01' AS DATE) AS ts_now_5, + CAST('1995-10-08' AS DATE) AS ts_now_6, EXTRACT(YEAR FROM CAST(o_orderdate AS DATETIME)) AS year_col, 2020 AS year_py, 1995 AS 
year_pd, diff --git a/tests/test_sql_refsols/datetime_functions_mysql.sql b/tests/test_sql_refsols/datetime_functions_mysql.sql index c1e82108c..715edb748 100644 --- a/tests/test_sql_refsols/datetime_functions_mysql.sql +++ b/tests/test_sql_refsols/datetime_functions_mysql.sql @@ -6,16 +6,8 @@ SELECT '%Y %c %e' ) AS ts_now_3, DATE_ADD(CURRENT_TIMESTAMP(), INTERVAL '1' HOUR) AS ts_now_4, - STR_TO_DATE( - CONCAT( - YEAR(CAST('2025-01-01 00:00:00' AS DATETIME)), - ' ', - MONTH(CAST('2025-01-01 00:00:00' AS DATETIME)), - ' 1' - ), - '%Y %c %e' - ) AS ts_now_5, - CAST('1995-10-08 00:00:00' AS DATETIME) AS ts_now_6, + CAST('2025-01-01' AS DATE) AS ts_now_5, + CAST('1995-10-08' AS DATE) AS ts_now_6, EXTRACT(YEAR FROM CAST(o_orderdate AS DATETIME)) AS year_col, 2020 AS year_py, 1995 AS year_pd, diff --git a/tests/test_sql_refsols/datetime_functions_sqlite.sql b/tests/test_sql_refsols/datetime_functions_sqlite.sql index cbb6648fc..54081fdc9 100644 --- a/tests/test_sql_refsols/datetime_functions_sqlite.sql +++ b/tests/test_sql_refsols/datetime_functions_sqlite.sql @@ -3,8 +3,8 @@ SELECT DATE('now', 'start of day') AS ts_now_2, DATE('now', 'start of month') AS ts_now_3, DATETIME('now', '1 hour') AS ts_now_4, - DATE('2025-01-01 00:00:00', 'start of month') AS ts_now_5, - DATETIME('1995-10-10 00:00:00', '-2 day') AS ts_now_6, + '2025-01-01' AS ts_now_5, + '1995-10-08' AS ts_now_6, CAST(STRFTIME('%Y', o_orderdate) AS INTEGER) AS year_col, 2020 AS year_py, 1995 AS year_pd, diff --git a/tests/test_sql_refsols/datetime_sampler_ansi.sql b/tests/test_sql_refsols/datetime_sampler_ansi.sql index 91760cd81..7c6431d45 100644 --- a/tests/test_sql_refsols/datetime_sampler_ansi.sql +++ b/tests/test_sql_refsols/datetime_sampler_ansi.sql @@ -1,8 +1,8 @@ SELECT CAST('2025-07-04 12:58:45' AS TIMESTAMP) AS _expr0, CAST('2024-12-31 11:59:00' AS TIMESTAMP) AS _expr1, - CAST('2025-01-01' AS TIMESTAMP) AS _expr2, - CAST('1999-03-14' AS TIMESTAMP) AS _expr3, + CAST('2025-01-01' AS DATE) AS _expr2, + 
CAST('1999-03-14' AS DATE) AS _expr3, CURRENT_TIMESTAMP() AS _expr4, CURRENT_TIMESTAMP() AS _expr5, CURRENT_TIMESTAMP() AS _expr6, @@ -50,14 +50,10 @@ SELECT 'SECOND' ) AS _expr34, DATE_TRUNC('DAY', CURRENT_TIMESTAMP()) AS _expr35, - DATE_ADD( - DATE_ADD(DATE_TRUNC('HOUR', CAST('2025-01-01' AS TIMESTAMP)), 49, 'MINUTE'), - 91, - 'YEAR' - ) AS _expr36, + CAST('2116-01-01 00:49:00' AS TIMESTAMP) AS _expr36, DATE_TRUNC('DAY', DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())) AS _expr37, DATE_TRUNC('YEAR', DATE_TRUNC('DAY', CURRENT_TIMESTAMP())) AS _expr38, - CAST('2025-07-01 13:20:45' AS TIMESTAMP) AS _expr39, + CAST('2025-07-01 00:22:00' AS TIMESTAMP) AS _expr39, DATE_TRUNC('YEAR', CURRENT_TIMESTAMP()) AS _expr40, DATE_TRUNC( 'YEAR', @@ -79,15 +75,15 @@ SELECT DATE_ADD(DATE_ADD(DATE_ADD(CURRENT_TIMESTAMP(), 297, 'DAY'), 72, 'MONTH'), -92, 'MONTH') ) AS _expr45, DATE_TRUNC('DAY', DATE_ADD(CURRENT_TIMESTAMP(), 285, 'SECOND')) AS _expr46, - CAST('1999-05-15 00:00:00' AS TIMESTAMP) AS _expr47, + CAST('1999-05-15' AS DATE) AS _expr47, DATE_ADD( DATE_TRUNC('MONTH', DATE_ADD(DATE_TRUNC('MONTH', CURRENT_TIMESTAMP()), 1, 'HOUR')), -21, 'DAY' ) AS _expr48, DATE_ADD(DATE_ADD(CURRENT_TIMESTAMP(), 212, 'MINUTE'), 368, 'YEAR') AS _expr49, - DATE_TRUNC('MINUTE', DATE_TRUNC('MINUTE', CAST('2024-01-01 11:59:00' AS TIMESTAMP))) AS _expr50, - DATE_TRUNC('DAY', DATE_TRUNC('HOUR', CAST('1999-03-14' AS TIMESTAMP))) AS _expr51, + CAST('2024-01-01' AS DATE) AS _expr50, + CAST('1999-03-14' AS DATE) AS _expr51, DATE_ADD( DATE_TRUNC('MINUTE', DATE_TRUNC('DAY', DATE_ADD(CURRENT_TIMESTAMP(), -60, 'HOUR'))), 196, diff --git a/tests/test_sql_refsols/datetime_sampler_mysql.sql b/tests/test_sql_refsols/datetime_sampler_mysql.sql index e3ffcde46..d7fd9523c 100644 --- a/tests/test_sql_refsols/datetime_sampler_mysql.sql +++ b/tests/test_sql_refsols/datetime_sampler_mysql.sql @@ -1,8 +1,8 @@ SELECT CAST('2025-07-04 12:58:45' AS DATETIME) AS _expr0, CAST('2024-12-31 11:59:00' AS DATETIME) AS _expr1, - 
CAST('2025-01-01' AS DATETIME) AS _expr2, - CAST('1999-03-14' AS DATETIME) AS _expr3, + CAST('2025-01-01' AS DATE) AS _expr2, + CAST('1999-03-14' AS DATE) AS _expr3, CURRENT_TIMESTAMP() AS _expr4, CURRENT_TIMESTAMP() AS _expr5, CURRENT_TIMESTAMP() AS _expr6, @@ -92,18 +92,7 @@ SELECT CAST('2116-01-01 00:49:00' AS DATETIME) AS _expr36, STR_TO_DATE(CONCAT(YEAR(CURRENT_TIMESTAMP()), ' 1 1'), '%Y %c %e') AS _expr37, STR_TO_DATE(CONCAT(YEAR(CAST(CURRENT_TIMESTAMP() AS DATE)), ' 1 1'), '%Y %c %e') AS _expr38, - DATE_ADD( - CAST(STR_TO_DATE( - CONCAT( - YEAR(CAST('2025-07-04 12:58:45' AS DATETIME)), - ' ', - MONTH(CAST('2025-07-04 12:58:45' AS DATETIME)), - ' 1' - ), - '%Y %c %e' - ) AS DATETIME), - INTERVAL '22' MINUTE - ) AS _expr39, + CAST('2025-07-01 00:22:00' AS DATETIME) AS _expr39, STR_TO_DATE(CONCAT(YEAR(CURRENT_TIMESTAMP()), ' 1 1'), '%Y %c %e') AS _expr40, STR_TO_DATE( CONCAT( @@ -130,7 +119,7 @@ SELECT ) ) AS _expr45, CAST(DATE_ADD(CURRENT_TIMESTAMP(), INTERVAL '285' SECOND) AS DATE) AS _expr46, - CAST('1999-05-15 00:00:00' AS DATETIME) AS _expr47, + CAST('1999-05-15' AS DATE) AS _expr47, DATE_ADD( STR_TO_DATE( CONCAT( @@ -160,24 +149,8 @@ SELECT INTERVAL '-21' DAY ) AS _expr48, DATE_ADD(DATE_ADD(CURRENT_TIMESTAMP(), INTERVAL '212' MINUTE), INTERVAL '368' YEAR) AS _expr49, - STR_TO_DATE( - CONCAT( - YEAR( - STR_TO_DATE( - CONCAT( - YEAR(CAST('2024-12-31 11:59:00' AS DATETIME)), - ' ', - MONTH(CAST('2024-12-31 11:59:00' AS DATETIME)), - ' 1' - ), - '%Y %c %e' - ) - ), - ' 1 1' - ), - '%Y %c %e' - ) AS _expr50, - DATE(CAST('1999-03-14' AS DATETIME)) AS _expr51, + CAST('2024-01-01' AS DATE) AS _expr50, + CAST('1999-03-14' AS DATE) AS _expr51, DATE_ADD( CAST(DATE_ADD(CURRENT_TIMESTAMP(), INTERVAL '-60' HOUR) AS DATE), INTERVAL '196' YEAR diff --git a/tests/test_sql_refsols/datetime_sampler_sqlite.sql b/tests/test_sql_refsols/datetime_sampler_sqlite.sql index 3454cd89a..fa1ef0b9f 100644 --- a/tests/test_sql_refsols/datetime_sampler_sqlite.sql +++ 
b/tests/test_sql_refsols/datetime_sampler_sqlite.sql @@ -1,8 +1,8 @@ SELECT - DATETIME('2025-07-04 12:58:45') AS _expr0, - DATETIME('2024-12-31 11:59:00') AS _expr1, - DATETIME('2025-01-01') AS _expr2, - DATETIME('1999-03-14') AS _expr3, + '2025-07-04 12:58:45' AS _expr0, + '2024-12-31 11:59:00' AS _expr1, + '2025-01-01' AS _expr2, + '1999-03-14' AS _expr3, DATETIME('now') AS _expr4, DATETIME('now') AS _expr5, DATETIME('now') AS _expr6, @@ -41,10 +41,10 @@ SELECT DATETIME(DATE('now', 'start of month'), '213 second') AS _expr33, DATETIME(DATE('now', 'start of month'), '13 minute', '28 year', '344 second') AS _expr34, DATE('now', 'start of day') AS _expr35, - DATETIME(STRFTIME('%Y-%m-%d %H:00:00', DATETIME('2025-01-01')), '49 minute', '91 year') AS _expr36, + '2116-01-01 00:49:00' AS _expr36, DATE('now', 'start of year', 'start of day') AS _expr37, DATE('now', 'start of day', 'start of year') AS _expr38, - DATETIME(DATE('2025-07-04 12:58:45', 'start of month'), '22 minute') AS _expr39, + '2025-07-01 00:22:00' AS _expr39, DATE('now', 'start of year') AS _expr40, DATE(DATETIME(o_orderdate, '82 second', '415 second', '-160 second'), 'start of year') AS _expr41, DATETIME('now', '192 month') AS _expr42, @@ -58,17 +58,11 @@ SELECT DATETIME(STRFTIME('%Y-%m-%d %H:%M:%S', DATETIME('now')), '-50 hour') AS _expr44, STRFTIME('%Y-%m-%d %H:00:00', DATETIME('now', '297 day', '72 month', '-92 month')) AS _expr45, DATE(DATETIME('now', '285 second'), 'start of day') AS _expr46, - DATETIME('1999-03-14', '62 day') AS _expr47, + '1999-05-15' AS _expr47, DATE(DATETIME(DATE('now', 'start of month'), '1 hour'), 'start of month', '-21 day') AS _expr48, DATETIME('now', '212 minute', '368 year') AS _expr49, - STRFTIME( - '%Y-%m-%d %H:%M:00', - STRFTIME( - '%Y-%m-%d %H:%M:00', - DATE('2024-12-31 11:59:00', 'start of month', 'start of year') - ) - ) AS _expr50, - DATE(STRFTIME('%Y-%m-%d %H:00:00', DATETIME('1999-03-14')), 'start of day') AS _expr51, + '2024-01-01' AS _expr50, + '1999-03-14' AS 
_expr51, DATETIME( STRFTIME('%Y-%m-%d %H:%M:00', DATE(DATETIME('now', '-60 hour'), 'start of day')), '196 year' diff --git a/tests/test_sql_refsols/smoke_b_ansi.sql b/tests/test_sql_refsols/smoke_b_ansi.sql index b4a57e048..ac499ec9f 100644 --- a/tests/test_sql_refsols/smoke_b_ansi.sql +++ b/tests/test_sql_refsols/smoke_b_ansi.sql @@ -33,19 +33,10 @@ SELECT 25, 'HOUR' ) AS d, - DATE_TRUNC('MINUTE', CAST('2025-01-01 12:35:13' AS TIMESTAMP)) AS e, - DATE_ADD( - DATE_ADD(DATE_TRUNC('HOUR', CAST('2025-01-01 12:35:13' AS TIMESTAMP)), 2, 'QUARTER'), - 3, - 'WEEK' - ) AS f, - CAST('2025-01-01 12:35:13' AS TIMESTAMP) AS g, - CONCAT_WS( - ';', - 12, - EXTRACT(MINUTE FROM CAST('2025-01-01 13:20:13' AS TIMESTAMP)), - EXTRACT(SECOND FROM CAST('2025-01-01 12:35:06' AS TIMESTAMP)) - ) AS h, + CAST('2025-01-01 12:35:00' AS TIMESTAMP) AS e, + CAST('2025-07-22 12:00:00' AS TIMESTAMP) AS f, + CAST('2025-01-01' AS DATE) AS g, + CONCAT_WS(';', 12, 20, 6) AS h, DATEDIFF(CAST(o_orderdate AS DATETIME), CAST('1993-05-25 12:45:36' AS TIMESTAMP), YEAR) AS i, DATEDIFF(CAST(o_orderdate AS DATETIME), CAST('1993-05-25 12:45:36' AS TIMESTAMP), QUARTER) AS j, DATEDIFF(CAST(o_orderdate AS DATETIME), CAST('1993-05-25 12:45:36' AS TIMESTAMP), MONTH) AS k, diff --git a/tests/test_sql_refsols/smoke_b_mysql.sql b/tests/test_sql_refsols/smoke_b_mysql.sql index cc2d69b28..d7e68338f 100644 --- a/tests/test_sql_refsols/smoke_b_mysql.sql +++ b/tests/test_sql_refsols/smoke_b_mysql.sql @@ -34,15 +34,10 @@ SELECT ) AS DATETIME), INTERVAL '25' HOUR ) AS d, - DATE(CAST('2025-01-01 12:35:13' AS DATETIME)) AS e, - CAST('2025-07-22' AS DATE) AS f, - CAST(CAST('2025-01-01 12:35:13' AS DATETIME) AS DATE) AS g, - CONCAT_WS( - ';', - 12, - MINUTE(CAST('2025-01-01 13:20:13' AS DATETIME)), - SECOND(CAST('2025-01-01 12:35:06' AS DATETIME)) - ) AS h, + CAST('2025-01-01 12:35:00' AS DATETIME) AS e, + CAST('2025-07-22 12:00:00' AS DATETIME) AS f, + CAST('2025-01-01' AS DATE) AS g, + CONCAT_WS(';', 12, 20, 6) AS h, 
YEAR(o_orderdate) - YEAR(CAST('1993-05-25 12:45:36' AS DATETIME)) AS i, ( YEAR(o_orderdate) - YEAR(CAST('1993-05-25 12:45:36' AS DATETIME)) diff --git a/tests/test_sql_refsols/smoke_b_sqlite.sql b/tests/test_sql_refsols/smoke_b_sqlite.sql index 6fbac82a0..a56aa99c8 100644 --- a/tests/test_sql_refsols/smoke_b_sqlite.sql +++ b/tests/test_sql_refsols/smoke_b_sqlite.sql @@ -54,15 +54,10 @@ SELECT ), '25 hour' ) AS d, - STRFTIME('%Y-%m-%d %H:%M:00', DATETIME('2025-01-01 12:35:13')) AS e, - DATETIME(STRFTIME('%Y-%m-%d %H:00:00', DATETIME('2025-01-01 12:35:13')), '6 month', '21 day') AS f, - DATE('2025-01-01 12:35:13', 'start of day') AS g, - CONCAT_WS( - ';', - 12, - CAST(STRFTIME('%M', DATETIME('2025-01-01 12:35:13', '45 minute')) AS INTEGER), - CAST(STRFTIME('%S', DATETIME('2025-01-01 12:35:13', '-7 second')) AS INTEGER) - ) AS h, + '2025-01-01 12:35:00' AS e, + '2025-07-22 12:00:00' AS f, + '2025-01-01' AS g, + CONCAT_WS(';', 12, 20, 6) AS h, CAST(STRFTIME('%Y', o_orderdate) AS INTEGER) - CAST(STRFTIME('%Y', DATETIME('1993-05-25 12:45:36')) AS INTEGER) AS i, ( CAST(STRFTIME('%Y', o_orderdate) AS INTEGER) - CAST(STRFTIME('%Y', DATETIME('1993-05-25 12:45:36')) AS INTEGER) From 81fcd718b7338bbe77baa941ec22ae21d3182797 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Tue, 26 Aug 2025 20:18:34 -0400 Subject: [PATCH 16/16] Adding in additional simplificaiton from other branch --- pydough/conversion/masking_shuttles.py | 5 +++-- pydough/conversion/relational_converter.py | 2 +- tests/test_plan_refsols/cryptbank_filter_count_08.txt | 2 +- tests/test_plan_refsols/cryptbank_filter_count_24.txt | 2 +- tests/test_plan_refsols/cryptbank_filter_count_25.txt | 2 +- tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql | 2 +- tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql | 2 +- tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql | 2 +- 8 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pydough/conversion/masking_shuttles.py 
b/pydough/conversion/masking_shuttles.py index 3dd2f403d..7586c85f2 100644 --- a/pydough/conversion/masking_shuttles.py +++ b/pydough/conversion/masking_shuttles.py @@ -8,6 +8,7 @@ from sqlglot import parse_one import pydough.pydough_operators as pydop +from pydough.configs import PyDoughConfigs from pydough.relational import ( CallExpression, LiteralExpression, @@ -24,8 +25,8 @@ class MaskLiteralComparisonShuttle(RelationalExpressionShuttle): TODO """ - def __init__(self): - self.simplifier: SimplificationShuttle = SimplificationShuttle() + def __init__(self, configs: PyDoughConfigs): + self.simplifier: SimplificationShuttle = SimplificationShuttle(configs) def simplify_masked_literal( self, value: RelationalExpression diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 957eabcbf..2ad4e312c 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -1589,7 +1589,7 @@ def convert_ast_to_relational( # Invoke the optimization procedures on the result to clean up the tree. 
additional_shuttles: list[RelationalExpressionShuttle] = [] if True: - additional_shuttles.append(MaskLiteralComparisonShuttle()) + additional_shuttles.append(MaskLiteralComparisonShuttle(configs)) optimized_result: RelationalRoot = optimize_relational_tree( raw_result, configs, additional_shuttles ) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_08.txt b/tests/test_plan_refsols/cryptbank_filter_count_08.txt index 845253916..649840cb8 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_08.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_08.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=c_birthday == DATETIME('1985-04-12':string, '-472 days':string, 'start of day':string), columns={}) + FILTER(condition=c_birthday == datetime.date(1983, 12, 27):datetime, columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_24.txt b/tests/test_plan_refsols/cryptbank_filter_count_24.txt index ff82a8cc8..2a57465cb 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_24.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_24.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - FILTER(condition=c_birthday == DATETIME('1991-11-15':string, '-472 days':string, 'start of day':string), columns={}) + FILTER(condition=c_birthday == datetime.date(1990, 7, 31):datetime, columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) diff --git a/tests/test_plan_refsols/cryptbank_filter_count_25.txt b/tests/test_plan_refsols/cryptbank_filter_count_25.txt index 1c00e189d..ae826c87f 100644 --- a/tests/test_plan_refsols/cryptbank_filter_count_25.txt +++ b/tests/test_plan_refsols/cryptbank_filter_count_25.txt @@ -1,4 +1,4 @@ ROOT(columns=[('n', n)], orderings=[]) AGGREGATE(keys={}, aggregations={'n': COUNT()}) - 
FILTER(condition=c_birthday != DATETIME('1991-11-15':string, '-472 days':string, 'start of day':string), columns={}) + FILTER(condition=c_birthday != datetime.date(1990, 7, 31):datetime, columns={}) SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) diff --git a/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql index e45688058..f334ffdeb 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_08_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_birthday = DATE(DATETIME('1985-04-12', '-472 day'), 'start of day') + c_birthday = '1983-12-27' diff --git a/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql index 3a3e438d7..199bc6091 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_24_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_birthday = DATE(DATETIME('1991-11-15', '-472 day'), 'start of day') + c_birthday = '1990-07-31' diff --git a/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql b/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql index ad9993371..c26f72089 100644 --- a/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql +++ b/tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql @@ -2,4 +2,4 @@ SELECT COUNT(*) AS n FROM crbnk.customers WHERE - c_birthday <> DATE(DATETIME('1991-11-15', '-472 day'), 'start of day') + c_birthday <> '1990-07-31'