-
Notifications
You must be signed in to change notification settings - Fork 3
Add masked table column literal comparison masking rewrite #418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 66 commits
83aaa52
3930b3d
4a02e2a
32e678e
957f010
2596a00
f2e46c4
86036b7
f896f14
c1315e3
48b892b
040d725
16de5a3
5241003
2a27c99
9e50717
c88d2f0
442621e
1fe90af
e5b8ab8
2313b57
ffbe3fe
d352175
aa2ee68
f09d0e7
76a16c2
be00e58
fa2d869
fed46d5
655054f
874dcad
2189fb3
28b58ce
4536992
e2fe6b7
0f6e59f
98a9c4c
46b1c36
bf2b075
a883759
5d273c3
2d69928
bc09e3f
ab08ce4
a36fb2b
df477e7
cccbe19
61194bb
1dfe201
600492a
82e9691
2a24514
aea501f
a733022
a5c3b9c
bdde458
4a58775
66f2193
630c7cc
f4c318f
c8ade74
857c39e
89fa4a6
ef3ae04
85376d7
4bf43ea
db3ba6f
1b601ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
""" | ||
Logic for replacing `UNMASK(x) == literal` (and similar expressions) with | ||
`x == MASK(literal)`. | ||
""" | ||
|
||
__all__ = ["MaskLiteralComparisonShuttle"] | ||
|
||
import pydough.pydough_operators as pydop | ||
from pydough.relational import ( | ||
CallExpression, | ||
LiteralExpression, | ||
RelationalExpression, | ||
RelationalExpressionShuttle, | ||
) | ||
from pydough.types import ArrayType, PyDoughType, UnknownType | ||
|
||
|
||
class MaskLiteralComparisonShuttle(RelationalExpressionShuttle):
    """
    A shuttle that recursively performs the following replacements:
    - `UNMASK(x) == literal` -> `x == MASK(literal)`
    - `literal == UNMASK(x)` -> `x == MASK(literal)`
    - `UNMASK(x) != literal` -> `x != MASK(literal)`
    - `literal != UNMASK(x)` -> `x != MASK(literal)`
    - `UNMASK(x) IN (literal1, ..., literalN)` -> `x IN (MASK(literal1), ..., MASK(literalN))`

    Note that when the literal is on the left-hand side of an equality or
    inequality, the rewritten comparison always places the previously
    unmasked expression on the left and the masked literal on the right.
    """

    @staticmethod
    def _build_mask_call(
        unmask_op: pydop.MaskedExpressionFunctionOperator,
        return_type: PyDoughType,
        literal: LiteralExpression,
    ) -> CallExpression:
        """
        Builds a call to MASK on a single literal, using the same masking
        metadata as an existing UNMASK operator.

        Args:
            `unmask_op`: The UNMASK operator whose masking metadata is reused.
            `return_type`: The type assigned to the new MASK call.
            `literal`: The literal expression to be wrapped in MASK.

        Returns:
            A new call expression applying MASK to `literal`.
        """
        # Passing False as the second argument toggles is_unmask off, which
        # makes the new operator a MASK rather than an UNMASK.
        return CallExpression(
            pydop.MaskedExpressionFunctionOperator(unmask_op.masking_metadata, False),
            return_type,
            [literal],
        )

    def protect_literal_comparison(
        self,
        original_call: CallExpression,
        call_arg: CallExpression,
        literal_arg: LiteralExpression,
    ) -> CallExpression:
        """
        Performs a rewrite of a comparison between a call to UNMASK and a
        literal, which is either equality, inequality, or containment.

        Args:
            `original_call`: The original call expression representing the
            comparison.
            `call_arg`: The argument to the comparison that is a call to
            UNMASK, which is treated as the left-hand side of the comparison.
            `literal_arg`: The argument to the comparison that is a literal,
            which is treated as the right-hand side of the comparison.

        Returns:
            A new call expression representing the rewritten comparison, or
            the original call expression if no rewrite was performed.
        """
        # Verify that the call argument is indeed an UNMASK operation,
        # otherwise fall back to the original.
        unmask_op = call_arg.op
        if (
            not isinstance(unmask_op, pydop.MaskedExpressionFunctionOperator)
            or not unmask_op.is_unmask
        ):
            return original_call

        masked_literal: RelationalExpression

        if original_call.op in (pydop.EQU, pydop.NEQ):
            # If the operation is equality or inequality, we can simply wrap
            # the literal in a call to MASK.
            masked_literal = self._build_mask_call(
                unmask_op, call_arg.data_type, literal_arg
            )
        elif original_call.op == pydop.ISIN and isinstance(
            literal_arg.value, (list, tuple)
        ):
            # If the operation is containment, and the literal is a
            # list/tuple, build a new list literal where each element is
            # individually wrapped in a MASK call.
            inner_type: PyDoughType
            if isinstance(literal_arg.data_type, ArrayType):
                inner_type = literal_arg.data_type.elem_type
            else:
                inner_type = UnknownType()
            masked_literal = LiteralExpression(
                [
                    self._build_mask_call(
                        unmask_op,
                        call_arg.data_type,
                        LiteralExpression(v, inner_type),
                    )
                    for v in literal_arg.value
                ],
                # The rewritten list is still a list literal, so it keeps the
                # type of the original list literal rather than taking the
                # (boolean) result type of the ISIN call.
                literal_arg.data_type,
            )
        else:
            # Otherwise, return the original.
            return original_call

        # Now that we have the masked literal, we can return a new call
        # expression with the same operator as before, but where the left hand
        # side argument is the expression that was being unmasked, and the
        # right hand side is the masked literal.
        return CallExpression(
            original_call.op,
            original_call.data_type,
            [call_arg.inputs[0], masked_literal],
        )

    def visit_call_expression(
        self, call_expression: CallExpression
    ) -> RelationalExpression:
        """
        Attempts the UNMASK-vs-literal rewrite on a call expression, then
        defers to the parent shuttle's handling of the (possibly rewritten)
        call.

        Args:
            `call_expression`: The call expression being visited.

        Returns:
            The result of the parent shuttle's `visit_call_expression` on
            either the rewritten comparison or the original call.
        """
        if call_expression.op in (pydop.EQU, pydop.NEQ):
            # If the call expression is equality or inequality, dispatch to
            # the rewrite logic if one argument is a call expression and the
            # other is a literal (in either order).
            lhs = call_expression.inputs[0]
            rhs = call_expression.inputs[1]
            if isinstance(lhs, CallExpression) and isinstance(rhs, LiteralExpression):
                call_expression = self.protect_literal_comparison(
                    call_expression, lhs, rhs
                )
            elif isinstance(rhs, CallExpression) and isinstance(
                lhs, LiteralExpression
            ):
                call_expression = self.protect_literal_comparison(
                    call_expression, rhs, lhs
                )
        elif (
            call_expression.op == pydop.ISIN
            and isinstance(call_expression.inputs[0], CallExpression)
            and isinstance(call_expression.inputs[1], LiteralExpression)
        ):
            # If the call expression is containment, dispatch to the rewrite
            # logic if the first argument is a call expression and the second
            # is a literal.
            call_expression = self.protect_literal_comparison(
                call_expression, call_expression.inputs[0], call_expression.inputs[1]
            )

        # Regardless of whether the rewrite occurred or not, invoke the regular
        # logic which will recursively transform the arguments.
        return super().visit_call_expression(call_expression)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -312,9 +312,22 @@ def visit_window_expression(self, window_expression: WindowCallExpression) -> No | |
def visit_literal_expression(self, literal_expression: LiteralExpression) -> None: | ||
# Note: This assumes each literal has an associated type that can be parsed | ||
# and types do not represent implicit casts. | ||
literal: SQLGlotExpression = sqlglot_expressions.convert( | ||
literal_expression.value | ||
) | ||
literal: SQLGlotExpression | ||
if isinstance(literal_expression.value, (tuple, list)): | ||
# If the literal is a list or tuple, convert each element | ||
# individually and create an array literal. | ||
Comment on lines
+317
to
+318
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is needed because now, with the recent changes, a "list literal" can contain non-literal expressions, e.g. a list of function calls |
||
elements: list[SQLGlotExpression] = [] | ||
for element in literal_expression.value: | ||
element_expr: SQLGlotExpression | ||
if isinstance(element, RelationalExpression): | ||
element.accept(self) | ||
element_expr = self._stack.pop() | ||
else: | ||
element_expr = sqlglot_expressions.convert(element) | ||
elements.append(element_expr) | ||
literal = sqlglot_expressions.Array(expressions=elements) | ||
else: | ||
literal = sqlglot_expressions.convert(literal_expression.value) | ||
|
||
# Special handling: insert cast calls for ansi casting of date/time | ||
# instead of relying on SQLGlot conversion functions. This is because | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -218,7 +218,7 @@ | |
"data type": "string", | ||
"server masked": true, | ||
"unprotect protocol": "PTY_UNPROTECT({}, 'deName')", | ||
"protect protocol": "PTY_PROTECT({}, 'deName)", | ||
"protect protocol": "PTY_PROTECT({}, 'deName')", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missed earlier, causes a bug in this PR because |
||
"description": "The first name of the customer", | ||
"sample values": ["Julie", "Melissa", "Gary"], | ||
"synonyms": ["customer first name", "given name"] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=UNMASK::(LOWER([c_lname])) == 'lee':string, columns={}) | ||
FILTER(condition=c_lname == MASK::(UPPER(['lee':string])), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=UNMASK::(LOWER([c_lname])) != 'lee':string, columns={}) | ||
FILTER(condition=c_lname != MASK::(UPPER(['lee':string])), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=ISIN(UNMASK::(LOWER([c_lname])), ['lee', 'smith', 'rodriguez']:array[unknown]), columns={}) | ||
FILTER(condition=ISIN(c_lname, [Call(op=MASK, inputs=[Literal(value='lee', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='smith', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='rodriguez', type=UnknownType())], return_type=StringType())]:bool), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=NOT(ISIN(UNMASK::(LOWER([c_lname])), ['lee', 'smith', 'rodriguez']:array[unknown])), columns={}) | ||
FILTER(condition=NOT(ISIN(c_lname, [Call(op=MASK, inputs=[Literal(value='lee', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='smith', type=UnknownType())], return_type=StringType()), Call(op=MASK, inputs=[Literal(value='rodriguez', type=UnknownType())], return_type=StringType())]:bool)), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_lname': c_lname}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=UNMASK::(DATE([c_birthday], '+472 days')) == '1985-04-12':string, columns={}) | ||
FILTER(condition=c_birthday == MASK::(DATE(['1985-04-12':string], '-472 days')), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=UNMASK::(DATE([c_birthday], '+472 days')) == '1991-11-15':string, columns={}) | ||
FILTER(condition=c_birthday == MASK::(DATE(['1991-11-15':string], '-472 days')), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) | UNMASK::(DATE([c_birthday], '+472 days')) != '1991-11-15':string, columns={}) | ||
FILTER(condition=ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) | c_birthday != MASK::(DATE(['1991-11-15':string], '-472 days')), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_birthday': c_birthday}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=UNMASK::(REPLACE(REPLACE(REPLACE([c_phone], '9', '*'), '0', '9'), '*', '0')) == '555-123-456':string, columns={}) | ||
FILTER(condition=c_phone == MASK::(REPLACE(REPLACE(REPLACE(['555-123-456':string], '0', '*'), '9', '0'), '*', '9')), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_phone': c_phone}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
FILTER(condition=PRESENT(UNMASK::(SUBSTRING([c_addr], -1) || SUBSTRING([c_addr], 1, LENGTH([c_addr]) - 1))) & PRESENT(UNMASK::(DATE([c_birthday], '+472 days'))) & UNMASK::(LOWER([c_lname])) != 'lopez':string & ENDSWITH(UNMASK::(LOWER([c_fname])), 'a':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 'e':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 's':string) | ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) & ENDSWITH(UNMASK::(REPLACE(REPLACE(REPLACE([c_phone], '9', '*'), '0', '9'), '*', '0')), '5':string), columns={}) | ||
FILTER(condition=PRESENT(UNMASK::(SUBSTRING([c_addr], -1) || SUBSTRING([c_addr], 1, LENGTH([c_addr]) - 1))) & PRESENT(UNMASK::(DATE([c_birthday], '+472 days'))) & c_lname != MASK::(UPPER(['lopez':string])) & ENDSWITH(UNMASK::(LOWER([c_fname])), 'a':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 'e':string) | ENDSWITH(UNMASK::(LOWER([c_fname])), 's':string) | ABSENT(UNMASK::(DATE([c_birthday], '+472 days'))) & ENDSWITH(UNMASK::(REPLACE(REPLACE(REPLACE([c_phone], '9', '*'), '0', '9'), '*', '0')), '5':string), columns={}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_addr': c_addr, 'c_birthday': c_birthday, 'c_fname': c_fname, 'c_lname': c_lname, 'c_phone': c_phone}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
JOIN(condition=t0.a_custkey == UNMASK::((42 - ([t1.c_key]))), type=INNER, cardinality=SINGULAR_FILTER, columns={}) | ||
FILTER(condition=YEAR(UNMASK::(DATETIME([a_open_ts], '+123456789 seconds'))) < 2020:numeric & UNMASK::(SQRT([a_balance])) >= 5000:numeric & UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1)) == 'retirement':string | UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1)) == 'savings':string, columns={'a_custkey': a_custkey}) | ||
FILTER(condition=YEAR(UNMASK::(DATETIME([a_open_ts], '+123456789 seconds'))) < 2020:numeric & UNMASK::(SQRT([a_balance])) >= 5000:numeric & a_type == MASK::(SUBSTRING(['retirement':string], 2) || SUBSTRING(['retirement':string], 1, 1)) | a_type == MASK::(SUBSTRING(['savings':string], 2) || SUBSTRING(['savings':string], 1, 1)), columns={'a_custkey': a_custkey}) | ||
SCAN(table=CRBNK.ACCOUNTS, columns={'a_balance': a_balance, 'a_custkey': a_custkey, 'a_open_ts': a_open_ts, 'a_type': a_type}) | ||
FILTER(condition=CONTAINS(UNMASK::(SUBSTRING([c_email], -1) || SUBSTRING([c_email], 1, LENGTH([c_email]) - 1)), 'outlook':string) | CONTAINS(UNMASK::(SUBSTRING([c_email], -1) || SUBSTRING([c_email], 1, LENGTH([c_email]) - 1)), 'gmail':string), columns={'c_key': c_key}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_email': c_email, 'c_key': c_key}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
ROOT(columns=[('n', n_rows)], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) | ||
JOIN(condition=UNMASK::(PTY_UNPROTECT_ACCOUNT([t0.customerid])) == UNMASK::(PTY_UNPROTECT([t1.customerid], 'deAccount')), type=INNER, cardinality=SINGULAR_FILTER, columns={}) | ||
FILTER(condition=UNMASK::(PTY_UNPROTECT_ACCOUNT([currency])) != 'GBP':string & balance < 20000:numeric, columns={'customerid': customerid}) | ||
FILTER(condition=currency != MASK::(PTY_PROTECT(['GBP':string], 'deAccount')) & balance < 20000:numeric, columns={'customerid': customerid}) | ||
SCAN(table=bodo.fsi.accounts, columns={'balance': balance, 'currency': currency, 'customerid': customerid}) | ||
FILTER(condition=UNMASK::(PTY_UNPROTECT([state], 'deAddress')) == 'California':string, columns={'customerid': customerid}) | ||
FILTER(condition=state == MASK::(PTY_PROTECT(['California':string], 'deAddress')), columns={'customerid': customerid}) | ||
SCAN(table=bodo.fsi.protected_customers, columns={'customerid': customerid, 'state': state}) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Questions:
1- What would happen if the literal value is `None`?
2- Is it possible to have nested calls, e.g. `UNMASK(UNMASK(x)) == literal`? If yes, then how would that be handled?
3- Why only equality/inequality? What about other comparison operations (>, >=, <, <=)?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1- Same as any other literal: `UNMASK(x) == None` -> `x == MASK(None)`. However, this shouldn't be done in practice since `ANYTHING == None` always returns NULL.
2- Shouldn't be, since unmasking is only done to convert the scanned table data (which is protected) to its unprotected form. If such a pattern occurs, then that means a bug has happened somewhere.
3- Because masking/unmasking preserves equality/inequality, which isn't true for the other comparisons. For example:
- `UNMASK(column) == 16` is equivalent to `column == MASK(16)`
- `UNMASK(column) < 16` is NOT the same as `column < MASK(16)`, since the mask/unmask does not preserve the ordering of values