Skip to content

Commit e8e86e0

Browse files
authored
Add masked table column literal comparison masking rewrite (#418)
Followup to #417, performs the following rewrites on relational algebra expressions involving UNMASK operations: - `UNMASK(expr) == literal` -> `expr == MASK(literal)` - `UNMASK(expr) != literal` -> `expr != MASK(literal)` - `ISIN(UNMASK(expr), [literal1, literal2, ...])` -> `ISIN(expr, [MASK(literal1), MASK(literal2), ...])` These rewrites are done through an additional shuttle run during relational simplification, which is only activated if the environment variable `PYDOUGH_ENABLE_MASK_REWRITES` is set to 1.
1 parent e11abaa commit e8e86e0

File tree

52 files changed

+248
-82
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+248
-82
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""
2+
Logic for replacing `UNMASK(x) == literal` (and similar expressions) with
3+
`x == MASK(literal)`.
4+
"""
5+
6+
__all__ = ["MaskLiteralComparisonShuttle"]
7+
8+
import pydough.pydough_operators as pydop
9+
from pydough.relational import (
10+
CallExpression,
11+
LiteralExpression,
12+
RelationalExpression,
13+
RelationalExpressionShuttle,
14+
)
15+
from pydough.types import ArrayType, PyDoughType, UnknownType
16+
17+
18+
class MaskLiteralComparisonShuttle(RelationalExpressionShuttle):
    """
    A shuttle that recursively performs the following replacements:
    - `UNMASK(x) == literal` -> `x == MASK(literal)`
    - `literal == UNMASK(x)` -> `x == MASK(literal)`
    - `UNMASK(x) != literal` -> `x != MASK(literal)`
    - `literal != UNMASK(x)` -> `x != MASK(literal)`
    - `UNMASK(x) IN (literal1, ..., literalN)` -> `x IN (MASK(literal1), ..., MASK(literalN))`

    Note that for equality and inequality the rewritten comparison always
    places the previously-unmasked expression on the left and the masked
    literal on the right, regardless of the original operand order.
    """

    def _build_mask_call(
        self, unmask_call: CallExpression, literal: LiteralExpression
    ) -> CallExpression:
        """
        Wraps `literal` in a MASK call that mirrors the UNMASK call
        `unmask_call`.

        Args:
            `unmask_call`: The UNMASK call whose masking metadata and data
            type are reused for the new MASK call.
            `literal`: The literal expression to be masked.

        Returns:
            A call expression applying the MASK operator to `literal`.
        """
        # The MASK operator is built from the same metadata as the UNMASK
        # operator, with the is_unmask flag toggled to False.
        # NOTE(review): the return type reuses the UNMASK call's data type;
        # confirm whether the type of the masked column would be more precise.
        return CallExpression(
            pydop.MaskedExpressionFunctionOperator(
                unmask_call.op.masking_metadata, False
            ),
            unmask_call.data_type,
            [literal],
        )

    def rewrite_masked_literal_comparison(
        self,
        original_call: CallExpression,
        call_arg: CallExpression,
        literal_arg: LiteralExpression,
    ) -> CallExpression:
        """
        Performs a rewrite of a comparison between a call to UNMASK and a
        literal, which is either equality, inequality, or containment.

        Args:
            `original_call`: The original call expression representing the
            comparison.
            `call_arg`: The argument to the comparison that is a call to
            UNMASK, which is treated as the left-hand side of the comparison.
            `literal_arg`: The argument to the comparison that is a literal,
            which is treated as the right-hand side of the comparison.

        Returns:
            A new call expression representing the rewritten comparison, or
            the original call expression if no rewrite was performed.
        """

        # Verify that the call argument is indeed an UNMASK operation, otherwise
        # fall back to the original.
        if (
            not isinstance(call_arg.op, pydop.MaskedExpressionFunctionOperator)
            or not call_arg.op.is_unmask
        ):
            return original_call

        masked_literal: RelationalExpression

        if original_call.op in (pydop.EQU, pydop.NEQ):
            # If the operation is equality or inequality, we can simply wrap the
            # literal in a call to MASK.
            masked_literal = self._build_mask_call(call_arg, literal_arg)
        elif original_call.op == pydop.ISIN and isinstance(
            literal_arg.value, (list, tuple)
        ):
            # If the operation is containment, and the literal is a list/tuple,
            # we need to build a list by wrapping each element of the tuple in
            # a MASK call.
            inner_type: PyDoughType
            if isinstance(literal_arg.data_type, ArrayType):
                inner_type = literal_arg.data_type.elem_type
            else:
                inner_type = UnknownType()
            masked_literal = LiteralExpression(
                [
                    self._build_mask_call(call_arg, LiteralExpression(v, inner_type))
                    for v in literal_arg.value
                ],
                # The rewritten list keeps the type of the original list
                # literal. Using the type of the ISIN call itself would label
                # the list of values with the comparison's result type.
                literal_arg.data_type,
            )
        else:
            # Otherwise, return the original.
            return original_call

        # Now that we have the masked literal, we can return a new call
        # expression with the same operators as before, but where the left hand
        # side argument is the expression that was being unmasked, and the right
        # hand side is the masked literal.
        return CallExpression(
            original_call.op,
            original_call.data_type,
            [call_arg.inputs[0], masked_literal],
        )

    def visit_call_expression(
        self, call_expression: CallExpression
    ) -> RelationalExpression:
        """
        Attempts the UNMASK/literal comparison rewrite on `call_expression`
        and then delegates to the parent shuttle's handling of the (possibly
        rewritten) call.

        Args:
            `call_expression`: The call expression being visited.

        Returns:
            The result of the parent class's `visit_call_expression` applied
            to the possibly rewritten call.
        """
        if call_expression.op in (pydop.EQU, pydop.NEQ):
            # For equality or inequality, dispatch to the rewrite logic if one
            # argument is a call expression and the other is a literal. The
            # two operand orders are mutually exclusive, so at most one
            # rewrite attempt is made.
            lhs = call_expression.inputs[0]
            rhs = call_expression.inputs[1]
            if isinstance(lhs, CallExpression) and isinstance(rhs, LiteralExpression):
                call_expression = self.rewrite_masked_literal_comparison(
                    call_expression, lhs, rhs
                )
            elif isinstance(rhs, CallExpression) and isinstance(
                lhs, LiteralExpression
            ):
                call_expression = self.rewrite_masked_literal_comparison(
                    call_expression, rhs, lhs
                )
        elif (
            call_expression.op == pydop.ISIN
            and isinstance(call_expression.inputs[0], CallExpression)
            and isinstance(call_expression.inputs[1], LiteralExpression)
        ):
            # For containment, dispatch to the rewrite logic only if the first
            # argument is a call expression and the second is a literal.
            call_expression = self.rewrite_masked_literal_comparison(
                call_expression, call_expression.inputs[0], call_expression.inputs[1]
            )

        # Regardless of whether the rewrite occurred or not, invoke the regular
        # logic which will recursively transform the arguments.
        return super().visit_call_expression(call_expression)

pydough/conversion/relational_converter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
__all__ = ["convert_ast_to_relational"]
77

88

9+
import os
910
from collections.abc import Iterable
1011
from dataclasses import dataclass
1112

@@ -84,6 +85,7 @@
8485
)
8586
from .hybrid_translator import HybridTranslator
8687
from .hybrid_tree import HybridTree
88+
from .masking_shuttles import MaskLiteralComparisonShuttle
8789
from .merge_projects import merge_projects
8890
from .projection_pullup import pullup_projections
8991
from .relational_simplification import simplify_expressions
@@ -1661,6 +1663,10 @@ def convert_ast_to_relational(
16611663

16621664
# Invoke the optimization procedures on the result to clean up the tree.
16631665
additional_shuttles: list[RelationalExpressionShuttle] = []
1666+
# Add the mask literal comparison shuttle if the environment variable
1667+
# PYDOUGH_ENABLE_MASK_REWRITES is set to 1.
1668+
if os.getenv("PYDOUGH_ENABLE_MASK_REWRITES") == "1":
1669+
additional_shuttles.append(MaskLiteralComparisonShuttle())
16641670
optimized_result: RelationalRoot = optimize_relational_tree(
16651671
raw_result, session, additional_shuttles
16661672
)

pydough/sqlglot/sqlglot_relational_expression_visitor.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,22 @@ def visit_window_expression(self, window_expression: WindowCallExpression) -> No
310310
def visit_literal_expression(self, literal_expression: LiteralExpression) -> None:
311311
# Note: This assumes each literal has an associated type that can be parsed
312312
# and types do not represent implicit casts.
313-
literal: SQLGlotExpression = sqlglot_expressions.convert(
314-
literal_expression.value
315-
)
313+
literal: SQLGlotExpression
314+
if isinstance(literal_expression.value, (tuple, list)):
315+
# If the literal is a list or tuple, convert each element
316+
# individually and create an array literal.
317+
elements: list[SQLGlotExpression] = []
318+
for element in literal_expression.value:
319+
element_expr: SQLGlotExpression
320+
if isinstance(element, RelationalExpression):
321+
element.accept(self)
322+
element_expr = self._stack.pop()
323+
else:
324+
element_expr = sqlglot_expressions.convert(element)
325+
elements.append(element_expr)
326+
literal = sqlglot_expressions.Array(expressions=elements)
327+
else:
328+
literal = sqlglot_expressions.convert(literal_expression.value)
316329

317330
# Special handling: insert cast calls for ansi casting of date/time
318331
# instead of relying on SQLGlot conversion functions. This is because

tests/test_metadata/sf_masked_examples.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@
218218
"data type": "string",
219219
"server masked": true,
220220
"unprotect protocol": "PTY_UNPROTECT({}, 'deName')",
221-
"protect protocol": "PTY_PROTECT({}, 'deName)",
221+
"protect protocol": "PTY_PROTECT({}, 'deName')",
222222
"description": "The first name of the customer",
223223
"sample values": ["Julie", "Melissa", "Gary"],
224224
"synonyms": ["customer first name", "given name"]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
ROOT(columns=[('n', n_rows)], orderings=[])
22
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
3-
FILTER(condition=UNMASK::(LOWER([c_lname])) == 'lee':string, columns={})
3+
FILTER(condition=c_lname == MASK::(UPPER(['lee':string])), columns={})
44
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
ROOT(columns=[('n', n_rows)], orderings=[])
22
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
3-
FILTER(condition=UNMASK::(LOWER([c_lname])) != 'lee':string, columns={})
3+
FILTER(condition=c_lname != MASK::(UPPER(['lee':string])), columns={})
44
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
ROOT(columns=[('n', n_rows)], orderings=[])
22
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
3-
FILTER(condition=ISIN(UNMASK::(LOWER([c_lname])), ['lee', 'smith', 'rodriguez']:array[unknown]), columns={})
3+
FILTER(condition=ISIN(c_lname, [Call(op=MASK, inputs=[Literal(value='lee', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='smith', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='rodriguez', type=UnknownType())], return_type=StringType())]:bool), columns={})
44
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
ROOT(columns=[('n', n_rows)], orderings=[])
22
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
3-
FILTER(condition=NOT(ISIN(UNMASK::(LOWER([c_lname])), ['lee', 'smith', 'rodriguez']:array[unknown])), columns={})
3+
FILTER(condition=NOT(ISIN(c_lname, [Call(op=MASK, inputs=[Literal(value='lee', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='smith', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='rodriguez', type=UnknownType())], return_type=StringType())]:bool)), columns={})
44
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname})
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
ROOT(columns=[('n', n_rows)], orderings=[])
22
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
3-
FILTER(condition=UNMASK::(DATE([c_birthday], '+472 days')) == '1985-04-12':string, columns={})
3+
FILTER(condition=c_birthday == MASK::(DATE(['1985-04-12':string], '-472 days')), columns={})
44
SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday})

tests/test_plan_refsols/cryptbank_filter_count_11_rewrite.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ ROOT(columns=[('n', n_rows)], orderings=[])
44
SCAN(table=CRBNK.TRANSACTIONS, columns={'t_sourceaccount': t_sourceaccount})
55
JOIN(condition=t0.a_custkey == UNMASK::((42 - ([t1.c_key]))), type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'a_key': t0.a_key})
66
SCAN(table=CRBNK.ACCOUNTS, columns={'a_custkey': a_custkey, 'a_key': a_key})
7-
FILTER(condition=UNMASK::(LOWER([c_fname])) == 'alice':string, columns={'c_key': c_key})
7+
FILTER(condition=c_fname == MASK::(UPPER(['alice':string])), columns={'c_key': c_key})
88
SCAN(table=CRBNK.CUSTOMERS, columns={'c_fname': c_fname, 'c_key': c_key})

0 commit comments

Comments
 (0)