Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
83aaa52
Initial setup started
knassre-bodo Aug 21, 2025
3930b3d
Added all metadata except the protect/unprotect protocols for CRYPTBANK
knassre-bodo Aug 21, 2025
4a02e2a
Added basic tests before inclusion of encryption
knassre-bodo Aug 22, 2025
32e678e
Added more test files
knassre-bodo Aug 22, 2025
957f010
Renamed tests
knassre-bodo Aug 22, 2025
2596a00
Added new tests [RUN CI]
knassre-bodo Aug 22, 2025
f2e46c4
Re-enabling encryption of CRYPTBANK data and skipping e2e tests until…
knassre-bodo Aug 22, 2025
86036b7
Adidng more tests [RUN CI]
knassre-bodo Aug 25, 2025
f896f14
[RUN CI]
knassre-bodo Aug 25, 2025
c1315e3
Added initial relational setup with operator for unmasking
knassre-bodo Aug 26, 2025
48b892b
Fixing naming bug
knassre-bodo Aug 26, 2025
040d725
Added cryptbank SQL support with encryptions injected
knassre-bodo Aug 26, 2025
16de5a3
[RUN CI]
knassre-bodo Aug 26, 2025
5241003
Fixing JSON file [RUN CI]
knassre-bodo Aug 26, 2025
2a27c99
Merge branch 'kian/sqlite_masked_tests' into kian/masked_relational_r…
knassre-bodo Aug 26, 2025
9e50717
Initial implementation in progress
knassre-bodo Aug 26, 2025
c88d2f0
Resolving conflicts
knassre-bodo Aug 27, 2025
442621e
Merge branch 'kian/sqlite_masked_tests' into kian/masked_relational_r…
knassre-bodo Aug 27, 2025
1fe90af
Merge branch 'kian/masked_relational_rewrite' into kian/mask_literal_…
knassre-bodo Aug 27, 2025
e5b8ab8
Resolving conflicts [RUN CI]
knassre-bodo Sep 8, 2025
2313b57
Resolving conflicts [RUN CI]
knassre-bodo Sep 8, 2025
ffbe3fe
Merge branch 'main' into kian/masked_relational_rewrite
knassre-bodo Sep 8, 2025
d352175
add rest
hadia206 Sep 8, 2025
aa2ee68
sf_masked_examples.json
hadia206 Sep 8, 2025
f09d0e7
Revisions [RUN CI]
knassre-bodo Sep 9, 2025
76a16c2
Merge branch 'main' into kian/masked_relational_rewrite
knassre-bodo Sep 10, 2025
be00e58
Merge branch 'main' into kian/masked_relational_rewrite
knassre-bodo Sep 15, 2025
fa2d869
[RUN CI]
knassre-bodo Sep 15, 2025
fed46d5
Merge branch 'kian/masked_relational_rewrite' into kian/mask_literal_…
knassre-bodo Sep 15, 2025
655054f
Resolving conflicts
knassre-bodo Sep 15, 2025
874dcad
Revisions WIP
knassre-bodo Sep 15, 2025
2189fb3
Merge branch 'main' into kian/masked_relational_rewrite
knassre-bodo Sep 17, 2025
28b58ce
Merge branch 'main' into kian/mask_literal_rewrite
knassre-bodo Sep 17, 2025
4536992
Merge branch 'kian/masked_relational_rewrite' into kian/mask_literal_…
knassre-bodo Sep 17, 2025
e2fe6b7
Adding environment variable and doubling cryptbank tests to case on it
knassre-bodo Sep 17, 2025
0f6e59f
Adding environment variable
knassre-bodo Sep 17, 2025
98a9c4c
[RUN CI]
knassre-bodo Sep 17, 2025
46b1c36
Resolving conflicts [RUN CI]
knassre-bodo Sep 17, 2025
bf2b075
add sql and relational files and tests
hadia206 Sep 19, 2025
a883759
use other version in some metadata and skip tests
hadia206 Sep 19, 2025
5d273c3
add import deleted by ruff
hadia206 Sep 19, 2025
2d69928
merge
hadia206 Sep 22, 2025
bc09e3f
Github action
hadia206 Sep 22, 2025
ab08ce4
Merge branch 'main' into kian/masked_relational_rewrite
knassre-bodo Sep 24, 2025
a36fb2b
[run CI] address comments (remove test and add type hints)
hadia206 Sep 24, 2025
df477e7
Revisions
knassre-bodo Sep 24, 2025
cccbe19
[RUN CI]
knassre-bodo Sep 24, 2025
61194bb
Merge branch 'kian/masked_relational_rewrite' into kian/mask_literal_…
knassre-bodo Sep 24, 2025
1dfe201
revisions
knassre-bodo Sep 25, 2025
600492a
Merge remote-tracking branch 'origin/Hadia/sf_masked_tests' into kian…
knassre-bodo Sep 25, 2025
82e9691
Resolving conflicts, adding raw vs rewrite
knassre-bodo Sep 25, 2025
2a24514
Adding raw vs rewrite
knassre-bodo Sep 25, 2025
aea501f
Fixing SQL handling and fixtures
knassre-bodo Sep 25, 2025
a733022
Resolving conflicts
knassre-bodo Sep 25, 2025
a5c3b9c
WIP
knassre-bodo Sep 25, 2025
bdde458
Adding more tests
knassre-bodo Sep 29, 2025
4a58775
Adding more tests
knassre-bodo Sep 29, 2025
66f2193
Resolving test updates
knassre-bodo Sep 29, 2025
630c7cc
Adding more tests
knassre-bodo Sep 29, 2025
f4c318f
Resolving conflicts [RUN ALL]
knassre-bodo Sep 29, 2025
c8ade74
Merge branch 'kian/masked_relational_rewrite' into kian/mask_literal_…
knassre-bodo Sep 29, 2025
857c39e
Updating files
knassre-bodo Sep 29, 2025
89fa4a6
Updating other fails
knassre-bodo Sep 29, 2025
ef3ae04
Resolving conflicts
knassre-bodo Sep 30, 2025
85376d7
[RUN CI][RUN SF_MASKED]
knassre-bodo Sep 30, 2025
4bf43ea
Merge branch 'main' into kian/mask_literal_rewrite
knassre-bodo Oct 1, 2025
db3ba6f
Resolving conflicts
knassre-bodo Oct 6, 2025
1b601ac
Revision
knassre-bodo Oct 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions pydough/conversion/masking_shuttles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
TODO
"""

__all__ = ["MaskLiteralComparisonShuttle"]

import pydough.pydough_operators as pydop
from pydough.relational import (
CallExpression,
LiteralExpression,
RelationalExpression,
RelationalExpressionShuttle,
)


class MaskLiteralComparisonShuttle(RelationalExpressionShuttle):
"""
TODO
"""

def is_unprotect_call(self, expr: RelationalExpression) -> bool:
"""
TODO
"""
return (
isinstance(expr, CallExpression)
and isinstance(expr.op, pydop.MaskedExpressionFunctionOperator)
and expr.op.is_unprotect
)

def protect_literal_comparison(
self,
original_call: CallExpression,
call_arg: CallExpression,
literal_arg: LiteralExpression,
) -> CallExpression:
"""
TODO
"""
if (
not isinstance(call_arg.op, pydop.MaskedExpressionFunctionOperator)
or not call_arg.op.is_unprotect
):
return original_call

masked_literal: RelationalExpression

if original_call.op in (pydop.EQU, pydop.NEQ):
masked_literal = CallExpression(
pydop.MaskedExpressionFunctionOperator(
call_arg.op.masking_metadata, False
),
call_arg.data_type,
[literal_arg],
)
elif original_call.op == pydop.ISIN and isinstance(
literal_arg.value, (list, tuple)
):
masked_literal = LiteralExpression(
[
CallExpression(
pydop.MaskedExpressionFunctionOperator(
call_arg.op.masking_metadata, False
),
call_arg.data_type,
[LiteralExpression(v, literal_arg.data_type)],
)
for v in literal_arg.value
],
original_call.data_type,
)
else:
return original_call

return CallExpression(
original_call.op,
original_call.data_type,
[call_arg.inputs[0], masked_literal],
)

def visit_call_expression(
self, call_expression: CallExpression
) -> RelationalExpression:
if call_expression.op in (pydop.EQU, pydop.NEQ):
if isinstance(call_expression.inputs[0], CallExpression) and isinstance(
call_expression.inputs[1], LiteralExpression
):
call_expression = self.protect_literal_comparison(
call_expression,
call_expression.inputs[0],
call_expression.inputs[1],
)
if isinstance(call_expression.inputs[1], CallExpression) and isinstance(
call_expression.inputs[0], LiteralExpression
):
call_expression = self.protect_literal_comparison(
call_expression,
call_expression.inputs[1],
call_expression.inputs[0],
)
if (
call_expression.op == pydop.ISIN
and isinstance(call_expression.inputs[0], CallExpression)
and isinstance(call_expression.inputs[1], LiteralExpression)
):
call_expression = self.protect_literal_comparison(
call_expression, call_expression.inputs[0], call_expression.inputs[1]
)
return super().visit_call_expression(call_expression)
32 changes: 31 additions & 1 deletion pydough/conversion/relational_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pydough.metadata import (
CartesianProductMetadata,
GeneralJoinMetadata,
MaskedTableColumnMetadata,
SimpleJoinMetadata,
SimpleTableMetadata,
)
Expand Down Expand Up @@ -84,6 +85,7 @@
)
from .hybrid_translator import HybridTranslator
from .hybrid_tree import HybridTree
from .masking_shuttles import MaskLiteralComparisonShuttle
from .merge_projects import merge_projects
from .projection_pullup import pullup_projections
from .relational_simplification import simplify_expressions
Expand Down Expand Up @@ -803,7 +805,33 @@ def build_simple_table_scan(
assert isinstance(expr, ColumnReference)
real_names.add(expr.name)
uniqueness.add(frozenset(real_names))
answer = Scan(node.collection.collection.table_path, scan_columns, uniqueness)
answer: RelationalNode = Scan(
node.collection.collection.table_path, scan_columns, uniqueness
)

# If any of the columns are masked, insert a projection on top to unmask
# them.
if any(
isinstance(expr, HybridColumnExpr)
and isinstance(expr.column.column_property, MaskedTableColumnMetadata)
for expr in node.terms.values()
):
unmask_columns: dict[str, RelationalExpression] = {}
for name, hybrid_expr in node.terms.items():
if isinstance(hybrid_expr, HybridColumnExpr) and isinstance(
hybrid_expr.column.column_property, MaskedTableColumnMetadata
):
unmask_columns[name] = CallExpression(
pydop.MaskedExpressionFunctionOperator(
hybrid_expr.column.column_property, True
),
hybrid_expr.column.column_property.unprotected_data_type,
[ColumnReference(name, hybrid_expr.typ)],
)
else:
unmask_columns[name] = ColumnReference(name, hybrid_expr.typ)
answer = Project(answer, unmask_columns)

return TranslationOutput(answer, out_columns)

def translate_sub_collection(
Expand Down Expand Up @@ -1579,6 +1607,8 @@ def convert_ast_to_relational(

# Invoke the optimization procedures on the result to clean up the tree.
additional_shuttles: list[RelationalExpressionShuttle] = []
if True:
additional_shuttles.append(MaskLiteralComparisonShuttle())
optimized_result: RelationalRoot = optimize_relational_tree(
raw_result, configs, additional_shuttles
)
Expand Down
2 changes: 2 additions & 0 deletions pydough/pydough_operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"MONOTONIC",
"MONTH",
"MUL",
"MaskedExpressionFunctionOperator",
"NDISTINCT",
"NEQ",
"NEXT",
Expand Down Expand Up @@ -207,6 +208,7 @@
ExpressionFunctionOperator,
ExpressionWindowOperator,
KeywordBranchingExpressionFunctionOperator,
MaskedExpressionFunctionOperator,
PyDoughExpressionOperator,
SqlAliasExpressionFunctionOperator,
SqlMacroExpressionFunctionOperator,
Expand Down
2 changes: 2 additions & 0 deletions pydough/pydough_operators/expression_operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"MONOTONIC",
"MONTH",
"MUL",
"MaskedExpressionFunctionOperator",
"NDISTINCT",
"NEQ",
"NEXT",
Expand Down Expand Up @@ -107,6 +108,7 @@
from .expression_operator import PyDoughExpressionOperator
from .expression_window_operators import ExpressionWindowOperator
from .keyword_branching_operators import KeywordBranchingExpressionFunctionOperator
from .masked_expression_function_operator import MaskedExpressionFunctionOperator
from .registered_expression_operators import (
ABS,
ABSENT,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
TODO
"""

__all__ = ["MaskedExpressionFunctionOperator"]


from pydough.metadata.properties import MaskedTableColumnMetadata
from pydough.pydough_operators.type_inference import (
ConstantType,
ExpressionTypeDeducer,
RequireNumArgs,
TypeVerifier,
)

from .expression_function_operators import ExpressionFunctionOperator


class MaskedExpressionFunctionOperator(ExpressionFunctionOperator):
"""
TODO
"""

def __init__(
self,
masking_metadata: MaskedTableColumnMetadata,
is_unprotect: bool,
):
verifier: TypeVerifier = RequireNumArgs(1)
deducer: ExpressionTypeDeducer = ConstantType(
masking_metadata.unprotected_data_type
if is_unprotect
else masking_metadata.data_type
)
super().__init__(
"UNMASK" if is_unprotect else "MASK", False, verifier, deducer, False
)
self._masking_metadata: MaskedTableColumnMetadata = masking_metadata
self._is_unprotect: bool = is_unprotect

@property
def masking_metadata(self) -> MaskedTableColumnMetadata:
"""
The metadata for the masked column.
"""
return self._masking_metadata

@property
def is_unprotect(self) -> bool:
"""
Whether this operator is unprotecting (True) or protecting (False).
"""
return self._is_unprotect

@property
def format_string(self) -> str:
"""
The format string to use for this operator to either mask or unmask the
operand.
"""
return (
self.masking_metadata.unprotect_protocol
if self.is_unprotect
else self.masking_metadata.protect_protocol
)

def to_string(self, arg_strings: list[str]) -> str:
name: str = "UNMASK" if self.is_unprotect else "MASK"
arg_strings = [f"[{s}]" for s in arg_strings]
return f"{name}::({self.format_string.format(*arg_strings)})"

def equals(self, other: object) -> bool:
return (
isinstance(other, MaskedExpressionFunctionOperator)
and self.masking_metadata == other.masking_metadata
and self.is_unprotect == other.is_unprotect
and super().equals(other)
)
19 changes: 16 additions & 3 deletions pydough/sqlglot/sqlglot_relational_expression_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,22 @@ def visit_window_expression(self, window_expression: WindowCallExpression) -> No
def visit_literal_expression(self, literal_expression: LiteralExpression) -> None:
# Note: This assumes each literal has an associated type that can be parsed
# and types do not represent implicit casts.
literal: SQLGlotExpression = sqlglot_expressions.convert(
literal_expression.value
)
literal: SQLGlotExpression
if isinstance(literal_expression.value, (tuple, list)):
# If the literal is a list or tuple, convert each element
# individually and create an array literal.
Comment on lines +317 to +318
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is needed because now, with the recent changes, a "list literal" can contain non-literal expressions, e.g. a list of function calls [MASK("foo"), MASK("bar")]

elements: list[SQLGlotExpression] = []
for element in literal_expression.value:
element_expr: SQLGlotExpression
if isinstance(element, RelationalExpression):
element.accept(self)
element_expr = self._stack.pop()
else:
element_expr = sqlglot_expressions.convert(element)
elements.append(element_expr)
literal = sqlglot_expressions.Array(expressions=elements)
else:
literal = sqlglot_expressions.convert(literal_expression.value)

# Special handling: insert cast calls for ansi casting of date/time
# instead of relying on SQLGlot conversion functions. This is because
Expand Down
18 changes: 15 additions & 3 deletions pydough/sqlglot/transform_bindings/base_transform_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,24 @@ def convert_call_to_sqlglot(
return sqlglot_expressions.Anonymous(
this=operator.sql_function_alias, expressions=args
)
if isinstance(operator, pydop.SqlMacroExpressionFunctionOperator):
if isinstance(
operator,
(
pydop.MaskedExpressionFunctionOperator,
pydop.SqlMacroExpressionFunctionOperator,
),
):
# For user defined operators that are a macro for SQL text, convert
# the arguments to SQL text strings then inject them into the macro
# as a format string, then re-parse it.
# as a format string, then re-parse it. The same idea works for the
# masking/unmasking operators
arg_strings: list[str] = [arg.sql() for arg in args]
combined_string: str = operator.macro_text.format(*arg_strings)
fmt_string: str
if isinstance(operator, pydop.MaskedExpressionFunctionOperator):
fmt_string = operator.format_string
else:
fmt_string = operator.macro_text
combined_string: str = fmt_string.format(*arg_strings)
return parse_one(combined_string)
match operator:
case pydop.NOT:
Expand Down
41 changes: 20 additions & 21 deletions tests/gen_data/init_cryptbank.sql
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ CREATE TABLE TRANSACTIONS (
);

INSERT INTO CUSTOMERS (c_key, c_fname, c_lname, c_phone, c_email, c_addr, c_birthday)
SELECT *
-- 42 - column1, -- ARITHMETIC SHIFT: 42
-- UPPER(column2), -- UPPERCASE
-- UPPER(column3), -- UPPERCASE
-- REPLACE(REPLACE(REPLACE(column4, '0', '*'), '9', '0'), '*', '9'), -- DIGIT SWITCH: 0 <-> 9
-- SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
-- SUBSTRING(column6, 2) || SUBSTRING(column6, 1, 1), -- FIRST CHAR TRANSPOSE
-- DATE(column7, '-472 days') -- DAY SHIFT: 472
SELECT
42 - column1, -- ARITHMETIC SHIFT: 42
UPPER(column2), -- UPPERCASE
UPPER(column3), -- UPPERCASE
REPLACE(REPLACE(REPLACE(column4, '0', '*'), '9', '0'), '*', '9'), -- DIGIT SWITCH: 0 <-> 9
SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
SUBSTRING(column6, 2) || SUBSTRING(column6, 1, 1), -- FIRST CHAR TRANSPOSE
DATE(column7, '-472 days') -- DAY SHIFT: 472
FROM (
VALUES
(1, 'alice', 'johnson', '555-123-4567', '[email protected]', '123 Maple St;Portland;OR;97205', '1985-04-12'),
Expand Down Expand Up @@ -81,13 +81,13 @@ INSERT INTO BRANCHES (b_key, b_name, b_addr) VALUES
;

INSERT INTO ACCOUNTS (a_key, a_custkey, a_branchkey, a_balance, a_type, a_open_ts)
SELECT *
-- CAST(CAST(column1 as TEXT) || CAST(column1 as TEXT) AS INTEGER),
-- column2,
-- column3,
-- column4 * column4, -- GEOMETRIC SHIFT
-- SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
-- DATETIME(column6, '-123456789 seconds') -- SECOND SHIFT: 123456789
SELECT
CAST(CAST(column1 as TEXT) || CAST(column1 as TEXT) AS INTEGER),
column2,
column3,
column4 * column4, -- GEOMETRIC SHIFT
SUBSTRING(column5, 2) || SUBSTRING(column5, 1, 1), -- FIRST CHAR TRANSPOSE
DATETIME(column6, '-123456789 seconds') -- SECOND SHIFT: 123456789
FROM (
VALUES
-- Customer 1 (alice johnson, OR) - 3 accounts
Expand Down Expand Up @@ -189,12 +189,11 @@ VALUES

INSERT INTO TRANSACTIONS (t_key, t_sourceaccount, t_destaccount, t_amount, t_ts)
SELECT
*
-- column1,
-- column2,
-- column3,
-- 1025.67 - column4, -- ARITHMETIC SHIFT: 1025.67
-- DATETIME(column5, '-54321 seconds') -- SECOND SHIFT: 54321
column1,
column2,
column3,
1025.67 - column4, -- ARITHMETIC SHIFT: 1025.67
DATETIME(column5, '-54321 seconds') -- SECOND SHIFT: 54321
FROM (
VALUES
(1, 41, 8, 2753.92, '2019-11-11 18:00:52'),
Expand Down
Loading