-
Notifications
You must be signed in to change notification settings - Fork 3
Add masked table column relational unprotect rewrite #417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
83aaa52
3930b3d
4a02e2a
32e678e
957f010
2596a00
f2e46c4
86036b7
f896f14
c1315e3
48b892b
040d725
16de5a3
5241003
2a27c99
c88d2f0
442621e
e5b8ab8
ffbe3fe
d352175
aa2ee68
f09d0e7
76a16c2
be00e58
fa2d869
2189fb3
0f6e59f
98a9c4c
bf2b075
a883759
5d273c3
2d69928
bc09e3f
ab08ce4
a36fb2b
df477e7
cccbe19
600492a
82e9691
2a24514
aea501f
a733022
a5c3b9c
bdde458
4a58775
66f2193
630c7cc
f4c318f
478b2f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
Special operators containing logic to mask or unmask data based on a masked | ||
table column's metadata. | ||
""" | ||
|
||
__all__ = ["MaskedExpressionFunctionOperator"] | ||
|
||
|
||
from pydough.metadata.properties import MaskedTableColumnMetadata | ||
from pydough.pydough_operators.type_inference import ( | ||
ConstantType, | ||
ExpressionTypeDeducer, | ||
RequireNumArgs, | ||
TypeVerifier, | ||
) | ||
from pydough.types import PyDoughType | ||
|
||
from .expression_function_operators import ExpressionFunctionOperator | ||
|
||
|
||
class MaskedExpressionFunctionOperator(ExpressionFunctionOperator):
    """
    A special expression function operator that masks or unmasks data based on
    a masked table column's metadata. The operator stores the metadata for the
    column and can represent either a masking or an unmasking operation,
    depending on the `is_unmask` flag.
    """

    def __init__(
        self,
        masking_metadata: MaskedTableColumnMetadata,
        is_unmask: bool,
    ):
        # All masking/unmasking operations are unary (the operator is always
        # invoked on a single column operand), so a fixed-arity verifier
        # requiring exactly one argument suffices.
        arity_check: TypeVerifier = RequireNumArgs(1)

        # The output type is fully determined by the metadata: unmasking
        # yields the unprotected type, masking yields the stored (protected)
        # type, so a constant deducer is enough.
        if is_unmask:
            result_type: PyDoughType = masking_metadata.unprotected_data_type
        else:
            result_type = masking_metadata.data_type
        type_deducer: ExpressionTypeDeducer = ConstantType(result_type)

        super().__init__(
            "UNMASK" if is_unmask else "MASK", False, arity_check, type_deducer, False
        )
        self._masking_metadata: MaskedTableColumnMetadata = masking_metadata
        self._is_unmask: bool = is_unmask

    @property
    def masking_metadata(self) -> MaskedTableColumnMetadata:
        """
        The metadata for the masked column.
        """
        return self._masking_metadata

    @property
    def is_unmask(self) -> bool:
        """
        Whether this operator is unprotecting (True) or protecting (False).
        """
        return self._is_unmask

    @property
    def format_string(self) -> str:
        """
        The format string used by this operator to either mask or unmask the
        operand.
        """
        if self.is_unmask:
            return self.masking_metadata.unprotect_protocol
        return self.masking_metadata.protect_protocol

    def to_string(self, arg_strings: list[str]) -> str:
        # Wrap each argument in brackets before injecting it into the
        # mask/unmask format string.
        prefix: str = "UNMASK" if self.is_unmask else "MASK"
        bracketed: list[str] = [f"[{text}]" for text in arg_strings]
        return f"{prefix}::({self.format_string.format(*bracketed)})"

    def equals(self, other: object) -> bool:
        # Two operators are equal only if they share the column metadata, the
        # mask/unmask direction, and the base-operator state.
        if not isinstance(other, MaskedExpressionFunctionOperator):
            return False
        return (
            self.masking_metadata == other.masking_metadata
            and self.is_unmask == other.is_unmask
            and super().equals(other)
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -172,12 +172,24 @@ def convert_call_to_sqlglot( | |
return sqlglot_expressions.Anonymous( | ||
this=operator.sql_function_alias, expressions=args | ||
) | ||
if isinstance(operator, pydop.SqlMacroExpressionFunctionOperator): | ||
if isinstance( | ||
operator, | ||
( | ||
pydop.MaskedExpressionFunctionOperator, | ||
pydop.SqlMacroExpressionFunctionOperator, | ||
), | ||
): | ||
Comment on lines
+175
to
+181
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just extending the logic we already use for UDFs with macro text, but now with the new |
||
# For user defined operators that are a macro for SQL text, convert | ||
# the arguments to SQL text strings then inject them into the macro | ||
# as a format string, then re-parse it. | ||
# as a format string, then re-parse it. The same idea works for the | ||
# masking/unmasking operators. | ||
arg_strings: list[str] = [arg.sql() for arg in args] | ||
combined_string: str = operator.macro_text.format(*arg_strings) | ||
fmt_string: str | ||
if isinstance(operator, pydop.MaskedExpressionFunctionOperator): | ||
fmt_string = operator.format_string | ||
else: | ||
fmt_string = operator.macro_text | ||
combined_string: str = fmt_string.format(*arg_strings) | ||
return parse_one(combined_string) | ||
match operator: | ||
case pydop.NOT: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,14 +37,14 @@ CREATE TABLE TRANSACTIONS ( | |
); | ||
|
||
INSERT INTO CUSTOMERS (c_key, c_fname, c_lname, c_phone, c_email, c_addr, c_birthday) | ||
SELECT * | ||
-- 42 - column1, -- ARITHMETIC SHIFT: 42 | ||
-- UPPER(column2), -- UPPERCASE | ||
-- UPPER(column3), -- UPPERCASE | ||
-- REPLACE(REPLACE(REPLACE(column4, '0', '*'), '9', '0'), '*', '9'), -- DIGIT SWITCH: 0 <-> 9 | ||
-- SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE | ||
-- SUBSTRING(column6, 2) || SUBSTRING(column6, 1, 1), -- FIRST CHAR TRANSPOSE | ||
-- DATE(column7, '-472 days') -- DAY SHIFT: 472 | ||
SELECT | ||
42 - column1, -- ARITHMETIC SHIFT: 42 | ||
UPPER(column2), -- UPPERCASE | ||
UPPER(column3), -- UPPERCASE | ||
REPLACE(REPLACE(REPLACE(column4, '0', '*'), '9', '0'), '*', '9'), -- DIGIT SWITCH: 0 <-> 9 | ||
SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE | ||
SUBSTRING(column6, 2) || SUBSTRING(column6, 1, 1), -- FIRST CHAR TRANSPOSE | ||
DATE(column7, '-472 days') -- DAY SHIFT: 472 | ||
Comment on lines
-40
to
+47
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forgot to restore encryption in the original PR, but it's fine because the E2E tests were being skipped until now |
||
FROM ( | ||
VALUES | ||
(1, 'alice', 'johnson', '555-123-4567', '[email protected]', '123 Maple St;Portland;OR;97205', '1985-04-12'), | ||
|
@@ -81,13 +81,13 @@ INSERT INTO BRANCHES (b_key, b_name, b_addr) VALUES | |
; | ||
|
||
INSERT INTO ACCOUNTS (a_key, a_custkey, a_branchkey, a_balance, a_type, a_open_ts) | ||
SELECT * | ||
-- CAST(CAST(column1 as TEXT) || CAST(column1 as TEXT) AS INTEGER), | ||
-- column2, | ||
-- column3, | ||
-- column4 * column4, -- GEOMETRIC SHIFT | ||
-- SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE | ||
-- DATETIME(column6, '-123456789 seconds') -- SECOND SHIFT: 123456789 | ||
SELECT | ||
CAST(CAST(column1 as TEXT) || CAST(column1 as TEXT) AS INTEGER), | ||
column2, | ||
column3, | ||
column4 * column4, -- GEOMETRIC SHIFT | ||
SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE | ||
DATETIME(column6, '-123456789 seconds') -- SECOND SHIFT: 123456789 | ||
FROM ( | ||
VALUES | ||
-- Customer 1 (alice johnson, OR) - 3 accounts | ||
|
@@ -189,12 +189,11 @@ VALUES | |
|
||
INSERT INTO TRANSACTIONS (t_key, t_sourceaccount, t_destaccount, t_amount, t_ts) | ||
SELECT | ||
* | ||
-- column1, | ||
-- column2, | ||
-- column3, | ||
-- 1025.67 - column4, -- ARITHMETIC SHIFT: 1025.67 | ||
-- DATETIME(column5, '-54321 seconds') -- SECOND SHIFT: 54321 | ||
column1, | ||
column2, | ||
column3, | ||
1025.67 - column4, -- ARITHMETIC SHIFT: 1025.67 | ||
DATETIME(column5, '-54321 seconds') -- SECOND SHIFT: 54321 | ||
FROM ( | ||
VALUES | ||
(1, 41, 8, 2753.92, '2019-11-11 18:00:52'), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -624,12 +624,15 @@ def test_pipeline_until_relational_cryptbank( | |
masked_graphs: graph_fetcher, | ||
get_plan_test_filename: Callable[[str], str], | ||
update_tests: bool, | ||
enable_mask_rewrites: str, | ||
) -> None: | ||
""" | ||
Tests the conversion of the PyDough queries on the custom cryptbank dataset | ||
into relational plans. | ||
""" | ||
file_path: str = get_plan_test_filename(cryptbank_pipeline_test_data.test_name) | ||
file_path: str = get_plan_test_filename( | ||
f"{cryptbank_pipeline_test_data.test_name}_{enable_mask_rewrites}" | ||
) | ||
cryptbank_pipeline_test_data.run_relational_test( | ||
masked_graphs, file_path, update_tests | ||
) | ||
|
@@ -641,13 +644,15 @@ def test_pipeline_until_sql_cryptbank( | |
sqlite_tpch_db_context: DatabaseContext, | ||
get_sql_test_filename: Callable[[str, DatabaseDialect], str], | ||
update_tests: bool, | ||
enable_mask_rewrites: str, | ||
): | ||
""" | ||
Tests the conversion of the PyDough queries on the custom cryptbank dataset | ||
into SQL text. | ||
""" | ||
file_path: str = get_sql_test_filename( | ||
cryptbank_pipeline_test_data.test_name, sqlite_tpch_db_context.dialect | ||
f"{cryptbank_pipeline_test_data.test_name}_{enable_mask_rewrites}", | ||
sqlite_tpch_db_context.dialect, | ||
) | ||
cryptbank_pipeline_test_data.run_sql_test( | ||
masked_graphs, | ||
|
@@ -657,14 +662,12 @@ def test_pipeline_until_sql_cryptbank( | |
) | ||
|
||
|
||
# @pytest.mark.skip( | ||
# reason="Skipping until masked table column relational handling is implemented" | ||
# ) | ||
Comment on lines
-660
to
-662
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forgot to restore the skip marker previously, but this was fine because the data wasn't being encrypted before. Now we are making it so it is encrypted (see |
||
@pytest.mark.execute | ||
def test_pipeline_e2e_cryptbank( | ||
cryptbank_pipeline_test_data: PyDoughPandasTest, | ||
masked_graphs: graph_fetcher, | ||
sqlite_cryptbank_connection: DatabaseContext, | ||
enable_mask_rewrites: str, | ||
): | ||
""" | ||
Test executing the custom queries with the custom cryptbank dataset | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1226,5 +1226,4 @@ | |
} | ||
] | ||
} | ||
|
||
] |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
ROOT(columns=[('n', ROUND(avg_unmask_t_amount, 2:numeric))], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'avg_unmask_t_amount': AVG(UNMASK::((1025.67 - ([t_amount]))))}) | ||
FILTER(condition=MONTH(UNMASK::(DATETIME([t_ts], '+54321 seconds'))) == 6:numeric & YEAR(UNMASK::(DATETIME([t_ts], '+54321 seconds'))) == 2022:numeric, columns={'t_amount': t_amount}) | ||
SCAN(table=CRBNK.TRANSACTIONS, columns={'t_amount': t_amount, 't_ts': t_ts}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
ROOT(columns=[('n', ROUND(avg_unmask_t_amount, 2:numeric))], orderings=[]) | ||
AGGREGATE(keys={}, aggregations={'avg_unmask_t_amount': AVG(UNMASK::((1025.67 - ([t_amount]))))}) | ||
FILTER(condition=MONTH(UNMASK::(DATETIME([t_ts], '+54321 seconds'))) == 6:numeric & YEAR(UNMASK::(DATETIME([t_ts], '+54321 seconds'))) == 2022:numeric, columns={'t_amount': t_amount}) | ||
SCAN(table=CRBNK.TRANSACTIONS, columns={'t_amount': t_amount, 't_ts': t_ts}) |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
ROOT(columns=[('account_type', unmask_a_type), ('n', n_rows), ('avg_bal', ROUND(avg_unmask_a_balance, 2:numeric))], orderings=[]) | ||
AGGREGATE(keys={'unmask_a_type': UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1))}, aggregations={'avg_unmask_a_balance': AVG(UNMASK::(SQRT([a_balance]))), 'n_rows': COUNT()}) | ||
SCAN(table=CRBNK.ACCOUNTS, columns={'a_balance': a_balance, 'a_type': a_type}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
ROOT(columns=[('account_type', unmask_a_type), ('n', n_rows), ('avg_bal', ROUND(avg_unmask_a_balance, 2:numeric))], orderings=[]) | ||
AGGREGATE(keys={'unmask_a_type': UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1))}, aggregations={'avg_unmask_a_balance': AVG(UNMASK::(SQRT([a_balance]))), 'n_rows': COUNT()}) | ||
SCAN(table=CRBNK.ACCOUNTS, columns={'a_balance': a_balance, 'a_type': a_type}) |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
ROOT(columns=[('account_type', UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1))), ('balance', UNMASK::(SQRT([a_balance]))), ('name', JOIN_STRINGS(' ':string, UNMASK::(LOWER([c_fname])), UNMASK::(LOWER([c_lname]))))], orderings=[]) | ||
FILTER(condition=RANKING(args=[], partition=[UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1))], order=[(UNMASK::(SQRT([a_balance]))):desc_first], allow_ties=False) == 1:numeric, columns={'a_balance': a_balance, 'a_type': a_type, 'c_fname': c_fname, 'c_lname': c_lname}) | ||
JOIN(condition=t0.a_custkey == UNMASK::((42 - ([t1.c_key]))), type=INNER, cardinality=SINGULAR_ACCESS, columns={'a_balance': t0.a_balance, 'a_type': t0.a_type, 'c_fname': t1.c_fname, 'c_lname': t1.c_lname}) | ||
SCAN(table=CRBNK.ACCOUNTS, columns={'a_balance': a_balance, 'a_custkey': a_custkey, 'a_type': a_type}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_fname': c_fname, 'c_key': c_key, 'c_lname': c_lname}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
ROOT(columns=[('account_type', UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1))), ('balance', UNMASK::(SQRT([a_balance]))), ('name', JOIN_STRINGS(' ':string, UNMASK::(LOWER([c_fname])), UNMASK::(LOWER([c_lname]))))], orderings=[]) | ||
FILTER(condition=RANKING(args=[], partition=[UNMASK::(SUBSTRING([a_type], -1) || SUBSTRING([a_type], 1, LENGTH([a_type]) - 1))], order=[(UNMASK::(SQRT([a_balance]))):desc_first], allow_ties=False) == 1:numeric, columns={'a_balance': a_balance, 'a_type': a_type, 'c_fname': c_fname, 'c_lname': c_lname}) | ||
JOIN(condition=t0.a_custkey == UNMASK::((42 - ([t1.c_key]))), type=INNER, cardinality=SINGULAR_ACCESS, columns={'a_balance': t0.a_balance, 'a_type': t0.a_type, 'c_fname': t1.c_fname, 'c_lname': t1.c_lname}) | ||
SCAN(table=CRBNK.ACCOUNTS, columns={'a_balance': a_balance, 'a_custkey': a_custkey, 'a_type': a_type}) | ||
SCAN(table=CRBNK.CUSTOMERS, columns={'c_fname': c_fname, 'c_key': c_key, 'c_lname': c_lname}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
ROOT(columns=[('branch_key', b_key), ('pct_total_wealth', ROUND(DEFAULT_TO(sum_a_balance, 0:numeric) / RELSUM(args=[DEFAULT_TO(sum_a_balance, 0:numeric)], partition=[], order=[]), 2:numeric))], orderings=[]) | ||
JOIN(condition=t0.b_key == t1.a_branchkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'b_key': t0.b_key, 'sum_a_balance': t1.sum_a_balance}) | ||
ROOT(columns=[('branch_key', b_key), ('pct_total_wealth', ROUND(DEFAULT_TO(sum_unmask_a_balance, 0:numeric) / RELSUM(args=[DEFAULT_TO(sum_unmask_a_balance, 0:numeric)], partition=[], order=[]), 2:numeric))], orderings=[]) | ||
JOIN(condition=t0.b_key == t1.a_branchkey, type=INNER, cardinality=SINGULAR_ACCESS, columns={'b_key': t0.b_key, 'sum_unmask_a_balance': t1.sum_unmask_a_balance}) | ||
SCAN(table=CRBNK.BRANCHES, columns={'b_key': b_key}) | ||
AGGREGATE(keys={'a_branchkey': a_branchkey}, aggregations={'sum_a_balance': SUM(a_balance)}) | ||
AGGREGATE(keys={'a_branchkey': a_branchkey}, aggregations={'sum_unmask_a_balance': SUM(UNMASK::(SQRT([a_balance])))}) | ||
SCAN(table=CRBNK.ACCOUNTS, columns={'a_balance': a_balance, 'a_branchkey': a_branchkey}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these asserts required? We are here because
is_masked_column(hybrid_expr) == true
. Am I missing something?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, they are required for
mypy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh, you are right.
Would something like
def is_masked_column(self, expr: HybridExpr) -> TypeGuard[HybridColumnExpr]:
for the return type be OK with mypy, so we don't need the extra assert code just for that? I mean, it is OK since both asserts will always be true... I was just thinking there should be a way to do that for mypy.