Skip to content

Commit ceefe6a

Browse files
committed
Allow parametrized generalization expressions
1 parent 42c0a2f commit ceefe6a

File tree

8 files changed

+148
-39
lines changed

8 files changed

+148
-39
lines changed

pg_diffix/query/anonymization.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ typedef struct AnonQueryLinks AnonQueryLinks;
1616
* Transforms subqueries accessing personal relations into anonymizing subqueries.
1717
* Returned data is used during plan rewrite.
1818
*/
19-
extern AnonQueryLinks *compile_query(Query *query, List *personal_relations);
19+
extern AnonQueryLinks *compile_query(Query *query, List *personal_relations, ParamListInfo bound_params);
2020

2121
/*
2222
* Calls `rewrite_plan` for each item in a list of Plan nodes.

pg_diffix/query/validation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ extern void verify_anonymization_requirements(Query *query);
2727
* Verifies restrictions on bucket expressions, operates on an anonymizing query.
2828
* If requirements are not met, an error is reported and execution is halted.
2929
*/
30-
extern void verify_bucket_expressions(Query *query);
30+
extern void verify_bucket_expressions(Query *query, ParamListInfo bound_params);
3131

3232
/*
3333
* Returns `true` if the given list of `RangeTblEntry` from `ExecutorCheckPerms` does not access `pg_catalog`

pg_diffix/utils.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#ifndef PG_DIFFIX_UTILS_H
22
#define PG_DIFFIX_UTILS_H
33

4+
#include "nodes/params.h"
45
#include "nodes/pg_list.h"
6+
#include "nodes/primnodes.h"
57
#include "utils/datum.h"
68

79
/*-------------------------------------------------------------------------
@@ -119,4 +121,68 @@ static inline List *hash_set_union(List *dst_set, const List *src_set)
119121

120122
#endif
121123

124+
/*-------------------------------------------------------------------------
125+
* Node utils
126+
*-------------------------------------------------------------------------
127+
*/
128+
129+
static inline ParamExternData *get_param_data(ParamListInfo bound_params, int one_based_paramid)
130+
{
131+
#if PG_MAJORVERSION_NUM == 13
132+
int paramid = one_based_paramid;
133+
#else
134+
int paramid = one_based_paramid - 1;
135+
#endif
136+
if (bound_params->paramFetch != NULL)
137+
return bound_params->paramFetch(bound_params, paramid, true, NULL);
138+
else
139+
return &bound_params->params[paramid];
140+
}
141+
142+
static inline bool is_simple_constant(Node *node)
143+
{
144+
return IsA(node, Const) || (IsA(node, Param) && ((Param *)node)->paramkind == PARAM_EXTERN);
145+
}
146+
147+
static inline void get_simple_constant_typed_value(Node *node, ParamListInfo bound_params, Oid *type, Datum *value, bool *isnull)
148+
{
149+
if (IsA(node, Const))
150+
{
151+
Const *const_expr = (Const *)node;
152+
*type = const_expr->consttype;
153+
*value = const_expr->constvalue;
154+
*isnull = const_expr->constisnull;
155+
}
156+
else if (IsA(node, Param) && ((Param *)node)->paramkind == PARAM_EXTERN)
157+
{
158+
Param *param_expr = (Param *)node;
159+
ParamExternData *param_data = get_param_data(bound_params, param_expr->paramid);
160+
*type = param_data->ptype;
161+
*value = param_data->value;
162+
*isnull = param_data->isnull;
163+
}
164+
else
165+
{
166+
FAILWITH("Attempted to get simple constant value of non-Const, non-PARAM_EXTERN node");
167+
}
168+
}
169+
170+
static inline int get_simple_constant_location(Node *node)
171+
{
172+
if (IsA(node, Const))
173+
{
174+
Const *const_expr = (Const *)node;
175+
return const_expr->location;
176+
}
177+
else if (IsA(node, Param) && ((Param *)node)->paramkind == PARAM_EXTERN)
178+
{
179+
Param *param_expr = (Param *)node;
180+
return param_expr->location;
181+
}
182+
else
183+
{
184+
FAILWITH("Attempted to get simple constant value of non-Const, non-PARAM_EXTERN node");
185+
}
186+
}
187+
122188
#endif /* PG_DIFFIX_UTILS_H */

src/hooks.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ static void pg_diffix_post_parse_analyze(ParseState *pstate, Query *query, Jumbl
4747
}
4848
#endif
4949

50-
static AnonQueryLinks *prepare_query(Query *query)
50+
static AnonQueryLinks *prepare_query(Query *query, ParamListInfo bound_params)
5151
{
5252
/* Do nothing for sessions with direct access. */
5353
if (get_session_access_level() == ACCESS_DIRECT)
@@ -70,7 +70,7 @@ static AnonQueryLinks *prepare_query(Query *query)
7070
*/
7171
config_validate();
7272

73-
AnonQueryLinks *links = compile_query(query, personal_relations);
73+
AnonQueryLinks *links = compile_query(query, personal_relations, bound_params);
7474

7575
DEBUG_LOG("Compiled query (Query ID=%lu) (User ID=%u) %s", query->queryId, GetSessionUserId(), nodeToString(query));
7676

@@ -88,7 +88,7 @@ static PlannedStmt *pg_diffix_planner(
8888

8989
DEBUG_LOG("Statement (Query ID=%lu) (User ID=%u): %s", query->queryId, GetSessionUserId(), query_string);
9090

91-
AnonQueryLinks *links = prepare_query(query);
91+
AnonQueryLinks *links = prepare_query(query, boundParams);
9292

9393
planner_hook_type planner = (prev_planner_hook ? prev_planner_hook : standard_planner);
9494
PlannedStmt *plan = planner(query, query_string, cursorOptions, boundParams);

src/query/anonymization.c

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ static void append_seed_material(
316316
typedef struct CollectMaterialContext
317317
{
318318
Query *query;
319+
ParamListInfo bound_params;
319320
char material[MAX_SEED_MATERIAL_SIZE];
320321
} CollectMaterialContext;
321322

@@ -356,14 +357,17 @@ static bool collect_seed_material(Node *node, CollectMaterialContext *context)
356357
append_seed_material(context->material, attribute_name, '.');
357358
}
358359

359-
if (IsA(node, Const))
360+
if (is_simple_constant(node))
360361
{
361-
Const *const_expr = (Const *)node;
362+
Oid type;
363+
Datum value;
364+
bool isnull;
365+
get_simple_constant_typed_value(node, context->bound_params, &type, &value, &isnull);
362366

363-
if (!is_supported_numeric_type(const_expr->consttype))
364-
FAILWITH_LOCATION(const_expr->location, "Unsupported constant type used in bucket definition!");
367+
if (!is_supported_numeric_type(type))
368+
FAILWITH_LOCATION(get_simple_constant_location(node), "Unsupported constant type used in bucket definition!");
365369

366-
double const_as_double = numeric_value_to_double(const_expr->consttype, const_expr->constvalue);
370+
double const_as_double = numeric_value_to_double(type, value);
367371
char const_as_string[DOUBLE_SHORTEST_DECIMAL_LEN];
368372
double_to_shortest_decimal_buf(const_as_double, const_as_string);
369373
append_seed_material(context->material, const_as_string, ',');
@@ -377,7 +381,7 @@ static bool collect_seed_material(Node *node, CollectMaterialContext *context)
377381
* Computes the SQL part of the bucket seed by combining the unique grouping expressions' seed material hashes.
378382
* Grouping clause (if any) must be made explicit before calling this.
379383
*/
380-
static seed_t prepare_bucket_seeds(Query *query)
384+
static seed_t prepare_bucket_seeds(Query *query, ParamListInfo bound_params)
381385
{
382386
List *seed_material_hash_set = NULL;
383387

@@ -388,7 +392,7 @@ static seed_t prepare_bucket_seeds(Query *query)
388392
Node *expr = lfirst(cell);
389393

390394
/* Start from empty string and append material pieces for each non-cast expression. */
391-
CollectMaterialContext collect_context = {.query = query, .material = ""};
395+
CollectMaterialContext collect_context = {.query = query, .bound_params = bound_params, .material = ""};
392396
collect_seed_material(expr, &collect_context);
393397

394398
/* Keep materials with unique hashes to avoid them cancelling each other. */
@@ -578,15 +582,15 @@ static void wrap_having_qual(Query *query)
578582
COERCE_EXPLICIT_CALL);
579583
}
580584

581-
static void compile_anonymizing_query(Query *query, List *personal_relations, AnonQueryLinks *anon_links)
585+
static void compile_anonymizing_query(Query *query, List *personal_relations, AnonQueryLinks *anon_links, ParamListInfo bound_params)
582586
{
583587
verify_anonymization_requirements(query);
584588

585589
AnonymizationContext *anon_context = make_query_anonymizing(query, personal_relations);
586590

587-
verify_bucket_expressions(query);
591+
verify_bucket_expressions(query, bound_params);
588592

589-
anon_context->sql_seed = prepare_bucket_seeds(query);
593+
anon_context->sql_seed = prepare_bucket_seeds(query, bound_params);
590594

591595
link_anon_context(query, anon_links, anon_context);
592596

@@ -613,6 +617,7 @@ typedef struct QueryCompileContext
613617
{
614618
List *personal_relations;
615619
AnonQueryLinks *anon_links;
620+
ParamListInfo bound_params;
616621
} QueryCompileContext;
617622

618623
static bool compile_query_walker(Node *node, QueryCompileContext *context)
@@ -651,19 +656,20 @@ static bool compile_query_walker(Node *node, QueryCompileContext *context)
651656
{
652657
Query *query = (Query *)node;
653658
if (is_anonymizing_query(query, context->personal_relations))
654-
compile_anonymizing_query(query, context->personal_relations, context->anon_links);
659+
compile_anonymizing_query(query, context->personal_relations, context->anon_links, context->bound_params);
655660
else
656661
query_tree_walker(query, compile_query_walker, context, QTW_EXAMINE_RTES_AFTER);
657662
}
658663

659664
return expression_tree_walker(node, compile_query_walker, context);
660665
}
661666

662-
AnonQueryLinks *compile_query(Query *query, List *personal_relations)
667+
AnonQueryLinks *compile_query(Query *query, List *personal_relations, ParamListInfo bound_params)
663668
{
664669
QueryCompileContext context = {
665670
.personal_relations = personal_relations,
666671
.anon_links = palloc0(sizeof(AnonQueryLinks)),
672+
.bound_params = bound_params,
667673
};
668674

669675
compile_query_walker((Node *)query, &context);

src/query/validation.c

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ void verify_utility_command(Node *utility_stmt)
5454
case T_DeallocateStmt:
5555
case T_FetchStmt:
5656
case T_ClosePortalStmt:
57+
case T_PrepareStmt:
58+
case T_ExecuteStmt:
5759
break;
5860
default:
5961
FAILWITH("Statement requires direct access level.");
@@ -225,7 +227,7 @@ static void verify_bucket_expression(Node *node)
225227

226228
for (int i = 1; i < list_length(func_expr->args); i++)
227229
{
228-
if (!IsA(unwrap_cast((Node *)list_nth(func_expr->args, i)), Const))
230+
if (!is_simple_constant(unwrap_cast((Node *)list_nth(func_expr->args, i))))
229231
FAILWITH_LOCATION(func_expr->location, "Non-primary arguments for a bucket function have to be simple constants.");
230232
}
231233
}
@@ -234,10 +236,9 @@ static void verify_bucket_expression(Node *node)
234236
OpExpr *op_expr = (OpExpr *)node;
235237
FAILWITH_LOCATION(op_expr->location, "Use of operators to define buckets is not supported.");
236238
}
237-
else if (IsA(node, Const))
239+
else if (is_simple_constant(node))
238240
{
239-
Const *const_expr = (Const *)node;
240-
FAILWITH_LOCATION(const_expr->location, "Simple constants are not allowed as bucket expressions.");
241+
FAILWITH_LOCATION(get_simple_constant_location(node), "Simple constants are not allowed as bucket expressions.");
241242
}
242243
else if (IsA(node, RelabelType))
243244
{
@@ -262,14 +263,17 @@ static void verify_bucket_expression(Node *node)
262263
}
263264
}
264265

265-
static void verify_substring(FuncExpr *func_expr)
266+
static void verify_substring(FuncExpr *func_expr, ParamListInfo bound_params)
266267
{
267268
Node *node = unwrap_cast(list_nth(func_expr->args, 1));
268-
Assert(IsA(node, Const)); /* Checked by prior validations */
269-
Const *second_arg = (Const *)node;
270-
271-
if (DatumGetUInt32(second_arg->constvalue) != 1)
272-
FAILWITH_LOCATION(second_arg->location, "Generalization used in the query is not allowed in untrusted access level.");
269+
Assert(is_simple_constant(node)); /* Checked by prior validations */
270+
Oid type;
271+
Datum value;
272+
bool isnull;
273+
get_simple_constant_typed_value(node, bound_params, &type, &value, &isnull);
274+
275+
if (DatumGetUInt32(value) != 1)
276+
FAILWITH_LOCATION(get_simple_constant_location(node), "Generalization used in the query is not allowed in untrusted access level.");
273277
}
274278

275279
/* money-style numbers, i.e. 1, 2, or 5 preceeded by or followed by zeros: ⟨... 0.1, 0.2, 0.5, 1, 2, 5, 10, ...⟩ */
@@ -286,29 +290,32 @@ static bool is_money_style(double number)
286290
}
287291

288292
/* Expects the expression being the second argument to `round_by` et al. */
289-
static void verify_bin_size(Node *range_expr)
293+
static void verify_bin_size(Node *range_expr, ParamListInfo bound_params)
290294
{
291295
Node *range_node = unwrap_cast(range_expr);
292-
Assert(IsA(range_node, Const)); /* Checked by prior validations */
293-
Const *range_const = (Const *)range_node;
296+
Assert(is_simple_constant(range_node)); /* Checked by prior validations */
297+
Oid type;
298+
Datum value;
299+
bool isnull;
300+
get_simple_constant_typed_value(range_node, bound_params, &type, &value, &isnull);
294301

295-
if (!is_supported_numeric_type(range_const->consttype))
296-
FAILWITH_LOCATION(range_const->location, "Unsupported constant type used in generalization.");
302+
if (!is_supported_numeric_type(type))
303+
FAILWITH_LOCATION(get_simple_constant_location(range_node), "Unsupported constant type used in generalization.");
297304

298-
if (!is_money_style(numeric_value_to_double(range_const->consttype, range_const->constvalue)))
299-
FAILWITH_LOCATION(range_const->location, "Generalization used in the query is not allowed in untrusted access level.");
305+
if (!is_money_style(numeric_value_to_double(type, value)))
306+
FAILWITH_LOCATION(get_simple_constant_location(range_node), "Generalization used in the query is not allowed in untrusted access level.");
300307
}
301308

302-
static void verify_generalization(Node *node)
309+
static void verify_generalization(Node *node, ParamListInfo bound_params)
303310
{
304311
if (IsA(node, FuncExpr))
305312
{
306313
FuncExpr *func_expr = (FuncExpr *)node;
307314

308315
if (is_substring_builtin(func_expr->funcid))
309-
verify_substring(func_expr);
316+
verify_substring(func_expr, bound_params);
310317
else if (is_implicit_range_udf_untrusted(func_expr->funcid))
311-
verify_bin_size((Node *)list_nth(func_expr->args, 1));
318+
verify_bin_size((Node *)list_nth(func_expr->args, 1), bound_params);
312319
else if (is_implicit_range_builtin_untrusted(func_expr->funcid))
313320
;
314321
else
@@ -317,7 +324,7 @@ static void verify_generalization(Node *node)
317324
}
318325

319326
/* Should be run on anonymizing queries only. */
320-
void verify_bucket_expressions(Query *query)
327+
void verify_bucket_expressions(Query *query, ParamListInfo bound_params)
321328
{
322329
AccessLevel access_level = get_session_access_level();
323330

@@ -333,7 +340,7 @@ void verify_bucket_expressions(Query *query)
333340
Node *expr = (Node *)lfirst(cell);
334341
verify_bucket_expression(expr);
335342
if (access_level == ACCESS_ANONYMIZED_UNTRUSTED)
336-
verify_generalization(expr);
343+
verify_generalization(expr, bound_params);
337344
}
338345
}
339346

test/expected/validation.out

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ SELECT city, 'aaaa' FROM test_validation GROUP BY 1, 2;
361361
ERROR: [PG_DIFFIX] Simple constants are not allowed as bucket expressions.
362362
LINE 1: SELECT city, 'aaaa' FROM test_validation GROUP BY 1, 2;
363363
^
364+
PREPARE prepared_param_as_label(text) AS SELECT city, $1 FROM test_validation GROUP BY 1, 2;
365+
EXECUTE prepared_param_as_label('aaaa');
366+
ERROR: [PG_DIFFIX] Simple constants are not allowed as bucket expressions.
364367
SELECT COUNT(*) FROM test_validation GROUP BY round(floor(id));
365368
ERROR: [PG_DIFFIX] Primary argument for a bucket function has to be a simple column reference.
366369
LINE 1: SELECT COUNT(*) FROM test_validation GROUP BY round(floor(id...
@@ -586,3 +589,20 @@ SELECT diffix.ceil_by(discount, 2) from test_validation;
586589
ERROR: [PG_DIFFIX] Generalization used in the query is not allowed in untrusted access level.
587590
LINE 1: SELECT diffix.ceil_by(discount, 2) from test_validation;
588591
^
592+
-- Allow prepared statements with generalization constants as params, and validate them
593+
PREPARE prepared_floor_by(numeric) AS SELECT diffix.floor_by(discount, $1) FROM test_validation GROUP BY 1;
594+
EXECUTE prepared_floor_by(2.0);
595+
floor_by
596+
----------
597+
(0 rows)
598+
599+
EXECUTE prepared_floor_by(2.1);
600+
ERROR: [PG_DIFFIX] Generalization used in the query is not allowed in untrusted access level.
601+
PREPARE prepared_substring(int, int) AS SELECT substring(city, $1, $2) FROM test_validation GROUP BY 1;
602+
EXECUTE prepared_substring(1, 2);
603+
substring
604+
-----------
605+
(0 rows)
606+
607+
EXECUTE prepared_substring(2, 3);
608+
ERROR: [PG_DIFFIX] Generalization used in the query is not allowed in untrusted access level.

test/sql/validation.sql

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ SELECT COUNT(*) FROM test_validation GROUP BY LENGTH(city);
187187
SELECT COUNT(*) FROM test_validation GROUP BY city || 'xxx';
188188
SELECT LENGTH(city) FROM test_validation;
189189
SELECT city, 'aaaa' FROM test_validation GROUP BY 1, 2;
190+
PREPARE prepared_param_as_label(text) AS SELECT city, $1 FROM test_validation GROUP BY 1, 2;
191+
EXECUTE prepared_param_as_label('aaaa');
190192
SELECT COUNT(*) FROM test_validation GROUP BY round(floor(id));
191193
SELECT COUNT(*) FROM test_validation GROUP BY floor(cast(discount AS integer));
192194
SELECT COUNT(*) FROM test_validation GROUP BY substr(city, 1, id);
@@ -280,3 +282,11 @@ SELECT diffix.floor_by(discount, 5000000000.1) from test_validation;
280282
SELECT width_bucket(discount, 2, 200, 5) from test_validation;
281283
SELECT ceil(discount) from test_validation;
282284
SELECT diffix.ceil_by(discount, 2) from test_validation;
285+
286+
-- Allow prepared statements with generalization constants as params, and validate them
287+
PREPARE prepared_floor_by(numeric) AS SELECT diffix.floor_by(discount, $1) FROM test_validation GROUP BY 1;
288+
EXECUTE prepared_floor_by(2.0);
289+
EXECUTE prepared_floor_by(2.1);
290+
PREPARE prepared_substring(int, int) AS SELECT substring(city, $1, $2) FROM test_validation GROUP BY 1;
291+
EXECUTE prepared_substring(1, 2);
292+
EXECUTE prepared_substring(2, 3);

0 commit comments

Comments
 (0)