Skip to content

Commit 1be9e87

Browse files
authored
Add masked table column relational unprotect rewrite (#417)
Adding crucial rewrite step for masked table columns: During relational conversion, if a column is a masked table column, place a `PROJECT` on top of the `SCAN` node fetching data from the table. This `PROJECT` node will invoke an `UNMASK` operator (containing information from the metadata for the masked table column) which will transform the masked columns into their unmasked forms. Updates existing tests to account for this transformation, ensuring that the `UNMASK` calls are injected into the underlying SQL (and pulled up as late as possible by projection pullup). All the tests in `test_masked_sqlite.py` will now run with the _correct_ e2e answers, as the decryption of the underlying data is now undone by the generated query.
1 parent 64d9f84 commit 1be9e87

File tree

380 files changed

+2460
-955
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

380 files changed

+2460
-955
lines changed

pydough/conversion/relational_converter.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pydough.metadata import (
1616
CartesianProductMetadata,
1717
GeneralJoinMetadata,
18+
MaskedTableColumnMetadata,
1819
SimpleJoinMetadata,
1920
SimpleTableMetadata,
2021
)
@@ -765,6 +766,21 @@ def handle_children(
765766
context.expressions[hybrid_ref] = context.expressions[key_expr]
766767
return context
767768

769+
def is_masked_column(self, expr: HybridExpr) -> bool:
770+
"""
771+
Checks if a given expression is a masked column expression.
772+
773+
Args:
774+
`expr`: the expression to check.
775+
776+
Returns:
777+
True if the expression is a masked column expression, False
778+
otherwise.
779+
"""
780+
return isinstance(expr, HybridColumnExpr) and isinstance(
781+
expr.column.column_property, MaskedTableColumnMetadata
782+
)
783+
768784
def build_simple_table_scan(
769785
self, node: HybridCollectionAccess
770786
) -> TranslationOutput:
@@ -806,7 +822,31 @@ def build_simple_table_scan(
806822
assert isinstance(expr, ColumnReference)
807823
real_names.add(expr.name)
808824
uniqueness.add(frozenset(real_names))
809-
answer = Scan(node.collection.collection.table_path, scan_columns, uniqueness)
825+
answer: RelationalNode = Scan(
826+
node.collection.collection.table_path, scan_columns, uniqueness
827+
)
828+
829+
# If any of the columns are masked, insert a projection on top to unmask
830+
# them.
831+
if any(self.is_masked_column(expr) for expr in node.terms.values()):
832+
unmask_columns: dict[str, RelationalExpression] = {}
833+
for name, hybrid_expr in node.terms.items():
834+
if self.is_masked_column(hybrid_expr):
835+
assert isinstance(hybrid_expr, HybridColumnExpr)
836+
assert isinstance(
837+
hybrid_expr.column.column_property, MaskedTableColumnMetadata
838+
)
839+
unmask_columns[name] = CallExpression(
840+
pydop.MaskedExpressionFunctionOperator(
841+
hybrid_expr.column.column_property, True
842+
),
843+
hybrid_expr.column.column_property.unprotected_data_type,
844+
[ColumnReference(name, hybrid_expr.typ)],
845+
)
846+
else:
847+
unmask_columns[name] = ColumnReference(name, hybrid_expr.typ)
848+
answer = Project(answer, unmask_columns)
849+
810850
return TranslationOutput(answer, out_columns)
811851

812852
def translate_sub_collection(

pydough/pydough_operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"MONOTONIC",
6262
"MONTH",
6363
"MUL",
64+
"MaskedExpressionFunctionOperator",
6465
"NDISTINCT",
6566
"NEQ",
6667
"NEXT",
@@ -207,6 +208,7 @@
207208
ExpressionFunctionOperator,
208209
ExpressionWindowOperator,
209210
KeywordBranchingExpressionFunctionOperator,
211+
MaskedExpressionFunctionOperator,
210212
PyDoughExpressionOperator,
211213
SqlAliasExpressionFunctionOperator,
212214
SqlMacroExpressionFunctionOperator,

pydough/pydough_operators/expression_operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"MONOTONIC",
5959
"MONTH",
6060
"MUL",
61+
"MaskedExpressionFunctionOperator",
6162
"NDISTINCT",
6263
"NEQ",
6364
"NEXT",
@@ -107,6 +108,7 @@
107108
from .expression_operator import PyDoughExpressionOperator
108109
from .expression_window_operators import ExpressionWindowOperator
109110
from .keyword_branching_operators import KeywordBranchingExpressionFunctionOperator
111+
from .masked_expression_function_operator import MaskedExpressionFunctionOperator
110112
from .registered_expression_operators import (
111113
ABS,
112114
ABSENT,
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""
2+
Special operators containing logic to mask or unmask data based on a masked
3+
table column's metadata.
4+
"""
5+
6+
__all__ = ["MaskedExpressionFunctionOperator"]
7+
8+
9+
from pydough.metadata.properties import MaskedTableColumnMetadata
10+
from pydough.pydough_operators.type_inference import (
11+
ConstantType,
12+
ExpressionTypeDeducer,
13+
RequireNumArgs,
14+
TypeVerifier,
15+
)
16+
from pydough.types import PyDoughType
17+
18+
from .expression_function_operators import ExpressionFunctionOperator
19+
20+
21+
class MaskedExpressionFunctionOperator(ExpressionFunctionOperator):
22+
"""
23+
A special expression function operator that masks or unmasks data based on
24+
a masked table column's metadata. The operator contains the metadata for
25+
the column, but can represent either a masking or unmasking operation
26+
depending on the `is_unmask` flag.
27+
"""
28+
29+
def __init__(
30+
self,
31+
masking_metadata: MaskedTableColumnMetadata,
32+
is_unmask: bool,
33+
):
34+
# Create a dummy verifier that requires exactly one argument, since all
35+
# masking/unmasking operations are unary.
36+
verifier: TypeVerifier = RequireNumArgs(1)
37+
38+
# Create a dummy deducer that always returns the appropriate data type
39+
# from the metadata based on whether this is a masking or unmasking
40+
# operation.
41+
target_type: PyDoughType = (
42+
masking_metadata.unprotected_data_type
43+
if is_unmask
44+
else masking_metadata.data_type
45+
)
46+
deducer: ExpressionTypeDeducer = ConstantType(target_type)
47+
48+
super().__init__(
49+
"UNMASK" if is_unmask else "MASK", False, verifier, deducer, False
50+
)
51+
self._masking_metadata: MaskedTableColumnMetadata = masking_metadata
52+
self._is_unmask: bool = is_unmask
53+
54+
@property
55+
def masking_metadata(self) -> MaskedTableColumnMetadata:
56+
"""
57+
The metadata for the masked column.
58+
"""
59+
return self._masking_metadata
60+
61+
@property
62+
def is_unmask(self) -> bool:
63+
"""
64+
Whether this operator is unprotecting (True) or protecting (False).
65+
"""
66+
return self._is_unmask
67+
68+
@property
69+
def format_string(self) -> str:
70+
"""
71+
The format string to use for this operator to either mask or unmask the
72+
operand.
73+
"""
74+
return (
75+
self.masking_metadata.unprotect_protocol
76+
if self.is_unmask
77+
else self.masking_metadata.protect_protocol
78+
)
79+
80+
def to_string(self, arg_strings: list[str]) -> str:
81+
name: str = "UNMASK" if self.is_unmask else "MASK"
82+
arg_strings = [f"[{s}]" for s in arg_strings]
83+
return f"{name}::({self.format_string.format(*arg_strings)})"
84+
85+
def equals(self, other: object) -> bool:
86+
return (
87+
isinstance(other, MaskedExpressionFunctionOperator)
88+
and self.masking_metadata == other.masking_metadata
89+
and self.is_unmask == other.is_unmask
90+
and super().equals(other)
91+
)

pydough/sqlglot/transform_bindings/base_transform_bindings.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,24 @@ def convert_call_to_sqlglot(
172172
return sqlglot_expressions.Anonymous(
173173
this=operator.sql_function_alias, expressions=args
174174
)
175-
if isinstance(operator, pydop.SqlMacroExpressionFunctionOperator):
175+
if isinstance(
176+
operator,
177+
(
178+
pydop.MaskedExpressionFunctionOperator,
179+
pydop.SqlMacroExpressionFunctionOperator,
180+
),
181+
):
176182
# For user defined operators that are a macro for SQL text, convert
177183
# the arguments to SQL text strings then inject them into the macro
178-
# as a format string, then re-parse it.
184+
# as a format string, then re-parse it. The same idea works for the
185+
# masking/unmasking operators
179186
arg_strings: list[str] = [arg.sql() for arg in args]
180-
combined_string: str = operator.macro_text.format(*arg_strings)
187+
fmt_string: str
188+
if isinstance(operator, pydop.MaskedExpressionFunctionOperator):
189+
fmt_string = operator.format_string
190+
else:
191+
fmt_string = operator.macro_text
192+
combined_string: str = fmt_string.format(*arg_strings)
181193
return parse_one(combined_string)
182194
match operator:
183195
case pydop.NOT:

tests/gen_data/init_cryptbank.sql

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ CREATE TABLE TRANSACTIONS (
3737
);
3838

3939
INSERT INTO CUSTOMERS (c_key, c_fname, c_lname, c_phone, c_email, c_addr, c_birthday)
40-
SELECT *
41-
-- 42 - column1, -- ARITHMETIC SHIFT: 42
42-
-- UPPER(column2), -- UPPERCASE
43-
-- UPPER(column3), -- UPPERCASE
44-
-- REPLACE(REPLACE(REPLACE(column4, '0', '*'), '9', '0'), '*', '9'), -- DIGIT SWITCH: 0 <-> 9
45-
-- SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
46-
-- SUBSTRING(column6, 2) || SUBSTRING(column6, 1, 1), -- FIRST CHAR TRANSPOSE
47-
-- DATE(column7, '-472 days') -- DAY SHIFT: 472
40+
SELECT
41+
42 - column1, -- ARITHMETIC SHIFT: 42
42+
UPPER(column2), -- UPPERCASE
43+
UPPER(column3), -- UPPERCASE
44+
REPLACE(REPLACE(REPLACE(column4, '0', '*'), '9', '0'), '*', '9'), -- DIGIT SWITCH: 0 <-> 9
45+
SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
46+
SUBSTRING(column6, 2) || SUBSTRING(column6, 1, 1), -- FIRST CHAR TRANSPOSE
47+
DATE(column7, '-472 days') -- DAY SHIFT: 472
4848
FROM (
4949
VALUES
5050
(1, 'alice', 'johnson', '555-123-4567', '[email protected]', '123 Maple St;Portland;OR;97205', '1985-04-12'),
@@ -81,13 +81,13 @@ INSERT INTO BRANCHES (b_key, b_name, b_addr) VALUES
8181
;
8282

8383
INSERT INTO ACCOUNTS (a_key, a_custkey, a_branchkey, a_balance, a_type, a_open_ts)
84-
SELECT *
85-
-- CAST(CAST(column1 as TEXT) || CAST(column1 as TEXT) AS INTEGER),
86-
-- column2,
87-
-- column3,
88-
-- column4 * column4, -- GEOMETRIC SHIFT
89-
-- SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
90-
-- DATETIME(column6, '-123456789 seconds') -- SECOND SHIFT: 123456789
84+
SELECT
85+
CAST(CAST(column1 as TEXT) || CAST(column1 as TEXT) AS INTEGER),
86+
column2,
87+
column3,
88+
column4 * column4, -- GEOMETRIC SHIFT
89+
SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
90+
DATETIME(column6, '-123456789 seconds') -- SECOND SHIFT: 123456789
9191
FROM (
9292
VALUES
9393
-- Customer 1 (alice johnson, OR) - 3 accounts
@@ -189,12 +189,11 @@ VALUES
189189

190190
INSERT INTO TRANSACTIONS (t_key, t_sourceaccount, t_destaccount, t_amount, t_ts)
191191
SELECT
192-
*
193-
-- column1,
194-
-- column2,
195-
-- column3,
196-
-- 1025.67 - column4, -- ARITHMETIC SHIFT: 1025.67
197-
-- DATETIME(column5, '-54321 seconds') -- SECOND SHIFT: 54321
192+
column1,
193+
column2,
194+
column3,
195+
1025.67 - column4, -- ARITHMETIC SHIFT: 1025.67
196+
DATETIME(column5, '-54321 seconds') -- SECOND SHIFT: 54321
198197
FROM (
199198
VALUES
200199
(1, 41, 8, 2753.92, '2019-11-11 18:00:52'),

tests/test_masked_sf.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -870,9 +870,6 @@ def test_pipeline_until_sql_masked_sf(
870870
)
871871

872872

873-
@pytest.mark.skip(
874-
reason="Skipping until masked table column relational handling is implemented"
875-
)
876873
@pytest.mark.execute
877874
@pytest.mark.sf_masked
878875
@pytest.mark.parametrize("account_type", ["NONE", "PARTIAL", "FULL"])

tests/test_masked_sqlite.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -624,12 +624,15 @@ def test_pipeline_until_relational_cryptbank(
624624
masked_graphs: graph_fetcher,
625625
get_plan_test_filename: Callable[[str], str],
626626
update_tests: bool,
627+
enable_mask_rewrites: str,
627628
) -> None:
628629
"""
629630
Tests the conversion of the PyDough queries on the custom cryptbank dataset
630631
into relational plans.
631632
"""
632-
file_path: str = get_plan_test_filename(cryptbank_pipeline_test_data.test_name)
633+
file_path: str = get_plan_test_filename(
634+
f"{cryptbank_pipeline_test_data.test_name}_{enable_mask_rewrites}"
635+
)
633636
cryptbank_pipeline_test_data.run_relational_test(
634637
masked_graphs, file_path, update_tests
635638
)
@@ -641,13 +644,15 @@ def test_pipeline_until_sql_cryptbank(
641644
sqlite_tpch_db_context: DatabaseContext,
642645
get_sql_test_filename: Callable[[str, DatabaseDialect], str],
643646
update_tests: bool,
647+
enable_mask_rewrites: str,
644648
):
645649
"""
646650
Tests the conversion of the PyDough queries on the custom cryptbank dataset
647651
into SQL text.
648652
"""
649653
file_path: str = get_sql_test_filename(
650-
cryptbank_pipeline_test_data.test_name, sqlite_tpch_db_context.dialect
654+
f"{cryptbank_pipeline_test_data.test_name}_{enable_mask_rewrites}",
655+
sqlite_tpch_db_context.dialect,
651656
)
652657
cryptbank_pipeline_test_data.run_sql_test(
653658
masked_graphs,
@@ -657,14 +662,12 @@ def test_pipeline_until_sql_cryptbank(
657662
)
658663

659664

660-
# @pytest.mark.skip(
661-
# reason="Skipping until masked table column relational handling is implemented"
662-
# )
663665
@pytest.mark.execute
664666
def test_pipeline_e2e_cryptbank(
665667
cryptbank_pipeline_test_data: PyDoughPandasTest,
666668
masked_graphs: graph_fetcher,
667669
sqlite_cryptbank_connection: DatabaseContext,
670+
enable_mask_rewrites: str,
668671
):
669672
"""
670673
Test executing the the custom queries with the custom cryptbank dataset

tests/test_metadata/sf_masked_examples.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,5 +1226,4 @@
12261226
}
12271227
]
12281228
}
1229-
12301229
]

tests/test_plan_refsols/cryptbank_agg_01.txt

Lines changed: 0 additions & 4 deletions
This file was deleted.

0 commit comments

Comments
 (0)