Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b48d3f5
Initial quarter simplification
knassre-bodo Aug 14, 2025
948bf6f
Added initial simplification tests for new patterns
knassre-bodo Aug 15, 2025
fdc92ae
Added datetime literal extraction simplification and tests
knassre-bodo Aug 15, 2025
c067a00
Added month/day/hour/minute/second edge comparison tests
knassre-bodo Aug 15, 2025
8671cb6
Added relational filter/join not-null inference and null-literal pro…
knassre-bodo Aug 15, 2025
0a5402f
Fixing KEEP_IF bug [RUN CI]
knassre-bodo Aug 15, 2025
41373f1
Adding documentation
knassre-bodo Aug 15, 2025
22bc152
Resolving conflicts [RUN CI]
knassre-bodo Aug 21, 2025
19ac326
Update pydough/conversion/relational_simplification.py
knassre-bodo Aug 22, 2025
3b98ae2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2025
7a6ed03
Adding DOW/DAYNAME simplification [RUN CI] [RUN MYSQL]
knassre-bodo Aug 25, 2025
436927d
Merge remote-tracking branch 'origin/kian/more_simp' into kian/more_simp
knassre-bodo Aug 25, 2025
4bf584c
Adding DOW/DAYNAME simplification [RUN CI] [RUN MYSQL]
knassre-bodo Aug 25, 2025
66a19e9
Add logic for simplifying literals that are masked
knassre-bodo Aug 26, 2025
df8e741
Initial implementation of mask simplification by converting sqlglot b…
knassre-bodo Aug 26, 2025
71236f1
Datetime manipulation WIP
knassre-bodo Aug 26, 2025
e6251e2
Added DATETIME chain simplification
knassre-bodo Aug 27, 2025
927c797
Resolving conflicts [RUN ALL]
knassre-bodo Aug 27, 2025
82721d0
Merge branch 'kian/mask_literal_rewrite' into kian/mask_literal_simp
knassre-bodo Aug 27, 2025
cfb951b
Merge branch 'kian/more_simp' into kian/mask_literal_simp
knassre-bodo Aug 27, 2025
81fcd71
Adding in additional simplification from other branch
knassre-bodo Aug 27, 2025
2500aa3
Resolving conflicts
knassre-bodo Sep 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 55 additions & 19 deletions pydough/conversion/masking_shuttles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,59 @@

__all__ = ["MaskLiteralComparisonShuttle"]

from sqlglot import expressions as sqlglot_expressions
from sqlglot import parse_one

import pydough.pydough_operators as pydop
from pydough.configs import PyDoughConfigs
from pydough.relational import (
CallExpression,
LiteralExpression,
RelationalExpression,
RelationalExpressionShuttle,
)
from pydough.sqlglot import convert_sqlglot_to_relational

from .relational_simplification import SimplificationShuttle


class MaskLiteralComparisonShuttle(RelationalExpressionShuttle):
"""
TODO
"""

def is_unprotect_call(self, expr: RelationalExpression) -> bool:
def __init__(self, configs: PyDoughConfigs):
self.simplifier: SimplificationShuttle = SimplificationShuttle(configs)

def simplify_masked_literal(
self, value: RelationalExpression
) -> RelationalExpression:
"""
TODO
"""
return (
isinstance(expr, CallExpression)
and isinstance(expr.op, pydop.MaskedExpressionFunctionOperator)
and expr.op.is_unprotect
)
if (
not isinstance(value, CallExpression)
or not isinstance(value.op, pydop.MaskedExpressionFunctionOperator)
or value.op.is_unprotect
or len(value.inputs) != 1
or not isinstance(value.inputs[0], LiteralExpression)
):
return value
try:
arg_sql_str: str = sqlglot_expressions.convert(value.inputs[0].value).sql()
total_sql_str: str = value.op.format_string.format(arg_sql_str)
glot_expr: sqlglot_expressions.Expression = parse_one(total_sql_str)
new_expr: RelationalExpression | None = convert_sqlglot_to_relational(
glot_expr
)
if new_expr is not None:
return new_expr
self.simplifier.reset()
return value.accept_shuttle(self.simplifier)
except Exception:
return value

return value

def protect_literal_comparison(
self,
Expand All @@ -53,22 +83,23 @@ def protect_literal_comparison(
call_arg.data_type,
[literal_arg],
)
masked_literal = self.simplify_masked_literal(masked_literal)
elif original_call.op == pydop.ISIN and isinstance(
literal_arg.value, (list, tuple)
):
masked_literal = LiteralExpression(
[
CallExpression(
pydop.MaskedExpressionFunctionOperator(
call_arg.op.masking_metadata, False
),
call_arg.data_type,
[LiteralExpression(v, literal_arg.data_type)],
)
for v in literal_arg.value
],
original_call.data_type,
)
new_elems: list[RelationalExpression] = []
for val in literal_arg.value:
new_val: RelationalExpression = CallExpression(
pydop.MaskedExpressionFunctionOperator(
call_arg.op.masking_metadata, False
),
call_arg.data_type,
[LiteralExpression(val, literal_arg.data_type)],
)
new_val = self.simplify_masked_literal(new_val)
new_elems.append(new_val)

masked_literal = LiteralExpression(new_elems, original_call.data_type)
else:
return original_call

Expand All @@ -82,6 +113,8 @@ def visit_call_expression(
self, call_expression: CallExpression
) -> RelationalExpression:
if call_expression.op in (pydop.EQU, pydop.NEQ):
# UNMASK(expr) = literal --> expr = MASK(literal)
# UNMASK(expr) != literal --> expr != MASK(literal)
if isinstance(call_expression.inputs[0], CallExpression) and isinstance(
call_expression.inputs[1], LiteralExpression
):
Expand All @@ -90,6 +123,8 @@ def visit_call_expression(
call_expression.inputs[0],
call_expression.inputs[1],
)
# literal = UNMASK(expr) --> MASK(literal) = expr
# literal != UNMASK(expr) --> MASK(literal) != expr
if isinstance(call_expression.inputs[1], CallExpression) and isinstance(
call_expression.inputs[0], LiteralExpression
):
Expand All @@ -98,6 +133,7 @@ def visit_call_expression(
call_expression.inputs[1],
call_expression.inputs[0],
)
# UNMASK(expr) IN (x, y, z) --> expr IN (MASK(x), MASK(y), MASK(z))
if (
call_expression.op == pydop.ISIN
and isinstance(call_expression.inputs[0], CallExpression)
Expand Down
2 changes: 1 addition & 1 deletion pydough/conversion/relational_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,7 +1608,7 @@ def convert_ast_to_relational(
# Invoke the optimization procedures on the result to clean up the tree.
additional_shuttles: list[RelationalExpressionShuttle] = []
if True:
additional_shuttles.append(MaskLiteralComparisonShuttle())
additional_shuttles.append(MaskLiteralComparisonShuttle(configs))
optimized_result: RelationalRoot = optimize_relational_tree(
raw_result, configs, additional_shuttles
)
Expand Down
12 changes: 12 additions & 0 deletions pydough/conversion/relational_simplification.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ def visit_literal_expression(
output_predicates.not_negative = True
if literal_expression.value > 0:
output_predicates.positive = True
if isinstance(literal_expression.value, (list, tuple)):
new_elems: list = []
for val in literal_expression.value:
if isinstance(val, RelationalExpression):
new_elems.append(val.accept_shuttle(self))
self.stack.pop()
else:
new_elems.append(val)
if new_elems != literal_expression.value:
literal_expression = LiteralExpression(
new_elems, literal_expression.data_type
)
self.stack.append(output_predicates)
return literal_expression

Expand Down
2 changes: 2 additions & 0 deletions pydough/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"SQLGlotRelationalVisitor",
"convert_dialect_to_sqlglot",
"convert_relation_to_sql",
"convert_sqlglot_to_relational",
"execute_df",
"find_identifiers",
"find_identifiers_in_list",
Expand All @@ -23,3 +24,4 @@
from .sqlglot_identifier_finder import find_identifiers, find_identifiers_in_list
from .sqlglot_relational_expression_visitor import SQLGlotRelationalExpressionVisitor
from .sqlglot_relational_visitor import SQLGlotRelationalVisitor
from .sqlglot_to_relational import convert_sqlglot_to_relational
98 changes: 98 additions & 0 deletions pydough/sqlglot/sqlglot_to_relational.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
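"""
Logic for converting sqlglot expressions into PyDough relational
expressions. Only a limited set of sqlglot expression types is supported;
conversion of anything else is reported as a failure.
"""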

__all__ = ["convert_sqlglot_to_relational"]


from sqlglot import expressions as sqlglot_expressions

import pydough.pydough_operators as pydop
from pydough.relational import (
CallExpression,
LiteralExpression,
RelationalExpression,
)
from pydough.types import (
BooleanType,
DatetimeType,
NumericType,
StringType,
UnknownType,
)


class GlotRelFail(Exception):
    """
    Exception raised by `glot_to_rel` when a sqlglot expression cannot be
    converted into a relational expression (an unsupported expression type,
    or a literal whose Python value is of an unsupported type). It is caught
    by `convert_sqlglot_to_relational`, which reports the failure by
    returning None.
    """


def glot_to_rel(glot_expr: sqlglot_expressions.Expression) -> RelationalExpression:
"""
TODO
"""

# Convert all of the sub-expressions to a relational expression. This
# step is stored in a "thunk" so it only happens under certain conditions.
def sub_rels() -> list[RelationalExpression]:
return [glot_to_rel(e) for e in glot_expr.iter_expressions()]

flushed_args: list[RelationalExpression]
match glot_expr:
case sqlglot_expressions.TimeStrToTime():
return CallExpression(
pydop.DATETIME,
DatetimeType(),
sub_rels(),
)
case sqlglot_expressions.Literal():
literal_expr = glot_expr.to_py()
if isinstance(literal_expr, str):
return LiteralExpression(literal_expr, StringType())
elif isinstance(literal_expr, bool):
return LiteralExpression(literal_expr, BooleanType())
elif isinstance(literal_expr, int) or isinstance(literal_expr, float):
return LiteralExpression(literal_expr, NumericType())
else:
raise GlotRelFail()
return LiteralExpression(
glot_expr.this,
StringType(),
)
case sqlglot_expressions.Null():
return LiteralExpression(None, UnknownType())
case sqlglot_expressions.Lower():
return CallExpression(pydop.LOWER, StringType(), sub_rels())
case sqlglot_expressions.Upper():
return CallExpression(pydop.UPPER, StringType(), sub_rels())
case sqlglot_expressions.Length():
return CallExpression(pydop.LENGTH, NumericType(), sub_rels())
case sqlglot_expressions.Abs():
return CallExpression(pydop.ABS, StringType(), sub_rels())
case sqlglot_expressions.Datetime():
return CallExpression(pydop.DATETIME, DatetimeType(), sub_rels())
case sqlglot_expressions.Date():
flushed_args = sub_rels()
flushed_args.append(LiteralExpression("start of day", StringType()))
return CallExpression(pydop.DATETIME, DatetimeType(), flushed_args)
case _:
raise GlotRelFail()


def convert_sqlglot_to_relational(
    glot_expr: sqlglot_expressions.Expression,
) -> RelationalExpression | None:
    """
    Try to translate a sqlglot expression into its relational equivalent.

    Args:
        `glot_expr`: The sqlglot expression that should be translated.

    Returns:
        The relational expression produced by `glot_to_rel`, or None when
        the translation gave up by raising `GlotRelFail`.
    """
    converted: RelationalExpression | None
    try:
        converted = glot_to_rel(glot_expr)
    except GlotRelFail:
        # An unsupported expression is reported to the caller as None
        # instead of propagating the exception.
        converted = None
    return converted
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_01.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=c_lname == MASK::(UPPER(['lee':string])), columns={})
FILTER(condition=c_lname == 'LEE':string, columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_02.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=c_lname != MASK::(UPPER(['lee':string])), columns={})
FILTER(condition=c_lname != 'LEE':string, columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_03.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=ISIN(c_lname, [Call(op=Function[MASK], inputs=[Literal(value='lee', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='smith', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='rodriguez', type=ArrayType(UnknownType()))], return_type=StringType())]:bool), columns={})
FILTER(condition=ISIN(c_lname, [Literal(value='LEE', type=StringType()), Literal(value='SMITH', type=StringType()), Literal(value='RODRIGUEZ', type=StringType())]:bool), columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_04.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=NOT(ISIN(c_lname, [Call(op=Function[MASK], inputs=[Literal(value='lee', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='smith', type=ArrayType(UnknownType()))], return_type=StringType()), Call(op=Function[MASK], inputs=[Literal(value='rodriguez', type=ArrayType(UnknownType()))], return_type=StringType())]:bool)), columns={})
FILTER(condition=NOT(ISIN(c_lname, [Literal(value='LEE', type=StringType()), Literal(value='SMITH', type=StringType()), Literal(value='RODRIGUEZ', type=StringType())]:bool)), columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_08.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=c_birthday == MASK::(DATE(['1985-04-12':string], '-472 days')), columns={})
FILTER(condition=c_birthday == datetime.date(1983, 12, 27):datetime, columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ ROOT(columns=[('n', n_rows)], orderings=[])
SCAN(table=CRBNK.TRANSACTIONS, columns={'t_sourceaccount': t_sourceaccount})
JOIN(condition=t0.a_custkey == UNMASK::((42 - ([t1.c_key]))), type=INNER, cardinality=SINGULAR_FILTER, columns={'a_key': t0.a_key})
SCAN(table=CRBNK.ACCOUNTS, columns={'a_custkey': a_custkey, 'a_key': a_key})
FILTER(condition=c_fname == MASK::(UPPER(['alice':string])), columns={'c_key': c_key})
FILTER(condition=c_fname == 'ALICE':string, columns={'c_key': c_key})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_fname': c_fname, 'c_key': c_key})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_24.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=c_birthday == MASK::(DATE(['1991-11-15':string], '-472 days')), columns={})
FILTER(condition=c_birthday == datetime.date(1990, 7, 31):datetime, columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_25.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) | c_birthday != MASK::(DATE(['1991-11-15':string], '-472 days')), columns={})
FILTER(condition=ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) | c_birthday != datetime.date(1990, 7, 31):datetime, columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday})
2 changes: 1 addition & 1 deletion tests/test_plan_refsols/cryptbank_filter_count_27.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ROOT(columns=[('n', n_rows)], orderings=[])
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
FILTER(condition=PRESENT(UNMASK::(SUBSTRING([c_addr], -1) || SUBSTRING([c_addr], 1, LENGTH([c_addr]) - 1))) & PRESENT(UNMASK::(DATE([c_birthday], '+472 days'))) & c_lname != MASK::(UPPER(['lopez':string])) & ENDSWITH(UNMASK::(LOWER([c_fname])), 'a':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 'e':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 's':string) | ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) & ENDSWITH(UNMASK::(REPLACE(REPLACE(REPLACE([c_phone], '9', '*'), '0', '9'), '*', '0')), '5':string), columns={})
FILTER(condition=PRESENT(UNMASK::(SUBSTRING([c_addr], -1) || SUBSTRING([c_addr], 1, LENGTH([c_addr]) - 1))) & PRESENT(UNMASK::(DATE([c_birthday], '+472 days'))) & c_lname != 'LOPEZ':string & ENDSWITH(UNMASK::(LOWER([c_fname])), 'a':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 'e':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 's':string) | ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) & ENDSWITH(UNMASK::(REPLACE(REPLACE(REPLACE([c_phone], '9', '*'), '0', '9'), '*', '0')), '5':string), columns={})
SCAN(table=CRBNK.CUSTOMERS, columns={'c_addr': c_addr, 'c_birthday': c_birthday, 'c_fname': c_fname, 'c_lname': c_lname, 'c_phone': c_phone})
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ SELECT
COUNT(*) AS n
FROM crbnk.customers
WHERE
c_lname = UPPER('lee')
c_lname = 'LEE'
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ SELECT
COUNT(*) AS n
FROM crbnk.customers
WHERE
c_lname <> UPPER('lee')
c_lname <> 'LEE'
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ SELECT
COUNT(*) AS n
FROM crbnk.customers
WHERE
c_lname IN (UPPER('lee'), UPPER('smith'), UPPER('rodriguez'))
c_lname IN ('LEE', 'SMITH', 'RODRIGUEZ')
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ SELECT
COUNT(*) AS n
FROM crbnk.customers
WHERE
NOT c_lname IN (UPPER('lee'), UPPER('smith'), UPPER('rodriguez'))
NOT c_lname IN ('LEE', 'SMITH', 'RODRIGUEZ')
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ SELECT
COUNT(*) AS n
FROM crbnk.customers
WHERE
c_birthday = DATE('1985-04-12', '-472 days')
c_birthday = '1983-12-27'
3 changes: 1 addition & 2 deletions tests/test_sql_refsols/cryptbank_filter_count_11_sqlite.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@ JOIN crbnk.accounts AS accounts
JOIN crbnk.customers AS customers
ON accounts.a_custkey = (
42 - customers.c_key
)
AND customers.c_fname = UPPER('alice')
) AND customers.c_fname = 'ALICE'
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ SELECT
COUNT(*) AS n
FROM crbnk.customers
WHERE
c_birthday = DATE('1991-11-15', '-472 days')
c_birthday = '1990-07-31'
3 changes: 1 addition & 2 deletions tests/test_sql_refsols/cryptbank_filter_count_25_sqlite.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ SELECT
COUNT(*) AS n
FROM crbnk.customers
WHERE
DATE(c_birthday, '+472 days') IS NULL
OR c_birthday <> DATE('1991-11-15', '-472 days')
DATE(c_birthday, '+472 days') IS NULL OR c_birthday <> '1990-07-31'
4 changes: 2 additions & 2 deletions tests/test_sql_refsols/cryptbank_filter_count_27_sqlite.sql
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ WHERE
) IS NULL
)
AND (
DATE(c_birthday, '+472 days') IS NULL OR c_lname <> UPPER('lopez')
DATE(c_birthday, '+472 days') IS NULL OR c_lname <> 'LOPEZ'
)
AND (
LOWER(c_fname) LIKE '%a'
Expand All @@ -35,5 +35,5 @@ WHERE
)
AND (
REPLACE(REPLACE(REPLACE(c_phone, '9', '*'), '0', '9'), '*', '0') LIKE '%5'
OR c_lname <> UPPER('lopez')
OR c_lname <> 'LOPEZ'
)