Skip to content

Commit ccff7af

Browse files
Adding PyDough client & Mock Server (#431)
Co-authored-by: juankx-bodo <[email protected]>
1 parent ea8ce44 commit ccff7af

File tree

10 files changed

+950
-0
lines changed

10 files changed

+950
-0
lines changed

pydough/mask_server/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""
2+
Mask server API client.
3+
"""
4+
5+
__all__ = [
6+
"MaskServerInfo",
7+
"MaskServerInput",
8+
"MaskServerOutput",
9+
"MaskServerResponse",
10+
"RequestMethod",
11+
"ServerConnection",
12+
"ServerRequest",
13+
]
14+
15+
from .mask_server import (
16+
MaskServerInfo,
17+
MaskServerInput,
18+
MaskServerOutput,
19+
MaskServerResponse,
20+
)
21+
from .server_connection import (
22+
RequestMethod,
23+
ServerConnection,
24+
ServerRequest,
25+
)

pydough/mask_server/mask_server.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
"""
2+
Interface for the mask server. This API includes the MaskServerInfo class and related
3+
data structures including the MaskServerInput and MaskServerOutput dataclasses.
4+
"""
5+
6+
__all__ = [
7+
"MaskServerInfo",
8+
"MaskServerInput",
9+
"MaskServerOutput",
10+
"MaskServerResponse",
11+
]
12+
13+
from dataclasses import dataclass
14+
from enum import Enum
15+
from typing import Any
16+
17+
from pydough.mask_server.server_connection import (
18+
RequestMethod,
19+
ServerConnection,
20+
ServerRequest,
21+
)
22+
23+
24+
class MaskServerResponse(Enum):
    """
    Classification of a single answer returned by the mask server.
    """

    # The mask server answered with an "IN" operator.
    IN_ARRAY = "IN_ARRAY"

    # The mask server answered with a "NOT_IN" operator.
    NOT_IN_ARRAY = "NOT_IN_ARRAY"

    # The mask server answered "UNSUPPORTED", or answered with anything that
    # is not one of the supported cases above.
    UNSUPPORTED = "UNSUPPORTED"
44+
45+
46+
@dataclass
47+
class MaskServerInput:
48+
"""
49+
Input data structure for the MaskServer.
50+
"""
51+
52+
table_path: str
53+
"""
54+
The fully qualified SQL table path, given from the metadata.
55+
"""
56+
57+
column_name: str
58+
"""
59+
The SQL column name, given from the metadata.
60+
"""
61+
62+
expression: list[str | int | float | None | bool]
63+
"""
64+
The linear serialization of the predicate expression.
65+
"""
66+
67+
68+
@dataclass
class MaskServerOutput:
    """
    One answer produced by the mask server for a submitted predicate.

    When the server returns an unsupported value, the output carries the
    UNSUPPORTED response case together with a None payload.

    Attributes:
        `response_case`: The type of response from the server.
        `payload`: The payload of the response: either the result of the
        predicate evaluation, or None if an error occurred.
    """

    # Which kind of answer the server gave.
    response_case: MaskServerResponse

    # Result of the predicate evaluation, or None if an error occurred.
    payload: Any
87+
88+
89+
class MaskServerInfo:
90+
"""
91+
The MaskServeraInfo class is responsible for evaluating predicates against a
92+
given table and column. It interacts with an external mask server to
93+
perform the evaluation.
94+
"""
95+
96+
def __init__(self, base_url: str, token: str | None = None):
97+
"""
98+
Initialize the MaskServerInfo with the given server URL.
99+
100+
Args:
101+
`base_url`: The URL of the mask server.
102+
`token`: Optional authentication token for the server.
103+
"""
104+
self.connection: ServerConnection = ServerConnection(
105+
base_url=base_url, token=token
106+
)
107+
108+
def get_server_response_case(self, server_case: str) -> MaskServerResponse:
109+
"""
110+
Mapping from server response strings to MaskServerResponse enum values.
111+
112+
Args:
113+
`server_case`: The response string from the server.
114+
Returns:
115+
The corresponding MaskServerResponse enum value.
116+
"""
117+
match server_case:
118+
case "IN":
119+
return MaskServerResponse.IN_ARRAY
120+
case "NOT_IN":
121+
return MaskServerResponse.NOT_IN_ARRAY
122+
case _:
123+
return MaskServerResponse.UNSUPPORTED
124+
125+
def simplify_simple_expression_batch(
126+
self, batch: list[MaskServerInput]
127+
) -> list[MaskServerOutput]:
128+
"""
129+
Sends a batch of predicate expressions to the mask server for evaluation.
130+
131+
Each input in the batch specifies a table, column, and predicate
132+
expression.The method constructs a request, sends it to the server,
133+
and parses the response into a list of MaskServerOutput objects, each
134+
indicating the server's decision for the corresponding input.
135+
136+
Args:
137+
`batch`: The list of inputs to be sent to the server.
138+
139+
Returns:
140+
An output list containing the response case and payload.
141+
"""
142+
assert batch != [], "Batch cannot be empty."
143+
144+
path: str = "v1/predicates/batch-evaluate"
145+
method: RequestMethod = RequestMethod.POST
146+
147+
request: ServerRequest = self.generate_request(batch, path, method)
148+
149+
response_json = self.connection.send_server_request(request)
150+
result: list[MaskServerOutput] = self.generate_result(response_json)
151+
152+
return result
153+
154+
def generate_request(
155+
self, batch: list[MaskServerInput], path: str, method: RequestMethod
156+
) -> ServerRequest:
157+
"""
158+
Generate a server request from the given batch of server inputs and path.
159+
160+
Args:
161+
`batch`: A list of MaskServerInput objects.
162+
`path`: The API endpoint path.
163+
164+
Returns:
165+
A server request including payload to be sent.
166+
167+
Example payload:
168+
{
169+
"items": [
170+
{
171+
"column_reference": "srv.db.tbl.col",
172+
"predicate": ["EQUAL", 2, "__col__", 1],
173+
"mode": "dynamic",
174+
"dry_run": false
175+
},
176+
...
177+
],
178+
"expression_format": {"name": "linear", "version": "0.2.0"}
179+
}
180+
"""
181+
182+
payload: dict = {
183+
"items": [],
184+
"expression_format": {"name": "linear", "version": "0.2.0"},
185+
}
186+
187+
for item in batch:
188+
evaluate_request: dict = {
189+
"column_reference": f"{item.table_path}.{item.column_name}",
190+
"predicate": item.expression,
191+
"mode": "dynamic",
192+
"dry_run": False,
193+
}
194+
payload["items"].append(evaluate_request)
195+
196+
return ServerRequest(path=path, payload=payload, method=method)
197+
198+
def generate_result(self, response: dict) -> list[MaskServerOutput]:
199+
"""
200+
Generate a list of server outputs from the server response.
201+
202+
Args:
203+
`response`: The response from the mask server.
204+
205+
Returns:
206+
A list of server outputs objects.
207+
208+
Example response:
209+
{
210+
"result": "SUCCESS",
211+
"items": [
212+
{
213+
"index": 0,
214+
"result": "SUCCESS",
215+
"decision": {"strategy": "values", "reason": "mock"},
216+
"predicate_hash": "hash0",
217+
"encryption_mode": "clear",
218+
"materialization": {
219+
"type": "literal",
220+
"operator": "IN",
221+
"values": [0],
222+
"count": 1
223+
}
224+
},
225+
...
226+
]
227+
}
228+
"""
229+
result: list[MaskServerOutput] = []
230+
231+
for item in response.get("items", []):
232+
"""
233+
Case on whether operator is ERROR or not
234+
If ERROR, then response_case is unsupported and payload is None
235+
Otherwise, call self.get_server_response(operator) to get the enum, store in a variable, then case on this variable to obtain the payload (use item.get("materialization", {}).get("values", []) if it is IN_ARRAY or NOT_IN_ARRAY, otherwise None)
236+
"""
237+
if item.get("result") == "ERROR":
238+
result.append(
239+
MaskServerOutput(
240+
response_case=MaskServerResponse.UNSUPPORTED,
241+
payload=None,
242+
)
243+
)
244+
else:
245+
materialization: dict = item.get("materialization", {})
246+
response_case: MaskServerResponse = self.get_server_response_case(
247+
materialization.get("operator", "ERROR")
248+
)
249+
250+
payload: Any = None
251+
252+
if response_case in (
253+
MaskServerResponse.IN_ARRAY,
254+
MaskServerResponse.NOT_IN_ARRAY,
255+
):
256+
payload = materialization.get("values", [])
257+
258+
result.append(
259+
MaskServerOutput(
260+
response_case=response_case,
261+
payload=payload,
262+
)
263+
)
264+
265+
return result

0 commit comments

Comments
 (0)