Skip to content

Commit 9c3a9cd

Browse files
authored
Merge pull request #406 from diffix/piotr/support-params-elm
Allow parametrized generalization expressions
2 parents e80721e + abd28c7 commit 9c3a9cd

File tree

11 files changed

+160
-39
lines changed

11 files changed

+160
-39
lines changed

pg_diffix/node_funcs.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef PG_DIFFIX_NODE_FUNCS_H
2+
#define PG_DIFFIX_NODE_FUNCS_H
3+
4+
#include "nodes/params.h"
5+
#include "nodes/primnodes.h"
6+
7+
/*
8+
* Returns `true` if the node represents a constant from the pg_diffix perspective, e.g. a literal such as `1` or an external parameter such as `$1`.
9+
*/
10+
extern bool is_stable_expression(Node *node);
11+
12+
/*
13+
* Fills `type`, `value` and `isnull` with what the stable expression `node` holds.
14+
* `bound_params` must be provided, since `node` might be a `Param` node.
15+
*/
16+
extern void get_stable_expression_value(Node *node, ParamListInfo bound_params, Oid *type, Datum *value, bool *isnull);
17+
18+
#endif /* PG_DIFFIX_NODE_FUNCS_H */

pg_diffix/query/anonymization.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "pg_diffix/aggregation/common.h"
88
#include "pg_diffix/aggregation/noise.h"
9+
#include "pg_diffix/node_funcs.h"
910

1011
/*
1112
* Opaque struct containing references to anonymizing (sub)queries.
@@ -16,7 +17,7 @@ typedef struct AnonQueryLinks AnonQueryLinks;
1617
* Transforms subqueries accessing personal relations into anonymizing subqueries.
1718
* Returned data is used during plan rewrite.
1819
*/
19-
extern AnonQueryLinks *compile_query(Query *query, List *personal_relations);
20+
extern AnonQueryLinks *compile_query(Query *query, List *personal_relations, ParamListInfo bound_params);
2021

2122
/*
2223
* Calls `rewrite_plan` for each item in a list of Plan nodes.

pg_diffix/query/validation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ extern void verify_anonymization_requirements(Query *query);
2727
* Verifies restrictions on bucket expressions, operates on an anonymizing query.
2828
* If requirements are not met, an error is reported and execution is halted.
2929
*/
30-
extern void verify_bucket_expressions(Query *query);
30+
extern void verify_bucket_expressions(Query *query, ParamListInfo bound_params);
3131

3232
/*
3333
* Returns `true` if the given list of `RangeTblEntry` from `ExecutorCheckPerms` does not access `pg_catalog`

src/hooks.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ static void pg_diffix_post_parse_analyze(ParseState *pstate, Query *query, Jumbl
4747
}
4848
#endif
4949

50-
static AnonQueryLinks *prepare_query(Query *query)
50+
static AnonQueryLinks *prepare_query(Query *query, ParamListInfo bound_params)
5151
{
5252
/* Do nothing for sessions with direct access. */
5353
if (get_session_access_level() == ACCESS_DIRECT)
@@ -70,7 +70,7 @@ static AnonQueryLinks *prepare_query(Query *query)
7070
*/
7171
config_validate();
7272

73-
AnonQueryLinks *links = compile_query(query, personal_relations);
73+
AnonQueryLinks *links = compile_query(query, personal_relations, bound_params);
7474

7575
DEBUG_LOG("Compiled query (Query ID=%lu) (User ID=%u) %s", query->queryId, GetSessionUserId(), nodeToString(query));
7676

@@ -88,7 +88,7 @@ static PlannedStmt *pg_diffix_planner(
8888

8989
DEBUG_LOG("Statement (Query ID=%lu) (User ID=%u): %s", query->queryId, GetSessionUserId(), query_string);
9090

91-
AnonQueryLinks *links = prepare_query(query);
91+
AnonQueryLinks *links = prepare_query(query, boundParams);
9292

9393
planner_hook_type planner = (prev_planner_hook ? prev_planner_hook : standard_planner);
9494
PlannedStmt *plan = planner(query, query_string, cursorOptions, boundParams);

src/node_funcs.c

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#include "postgres.h"
2+
3+
#include "pg_diffix/node_funcs.h"
4+
#include "pg_diffix/utils.h"
5+
6+
/*
 * Looks up the externally supplied value of parameter `$one_based_paramid` in `bound_params`.
 *
 * Reports an error if no parameter list is available, if the ID is outside the list,
 * or if the list cannot describe the parameter (invalid type OID).
 *
 * The returned pointer may refer to static storage or to storage owned by `bound_params`;
 * callers must copy out the fields they need before the next call or any further access to the list.
 */
static ParamExternData *get_param_data(ParamListInfo bound_params, int one_based_paramid)
{
  /* Scratch space for `paramFetch`; static because the returned pointer may refer to it. */
  static ParamExternData workspace;
  ParamExternData *param_data;

  /* The caller may have no bound values at all (NULL list), so never dereference blindly. */
  if (bound_params == NULL || one_based_paramid <= 0 || one_based_paramid > bound_params->numParams)
    FAILWITH("Value of a query parameter is not available.");

  if (bound_params->paramFetch != NULL)
  {
    /*
     * The hook expects the 1-based parameter ID and a caller-provided workspace.
     * `speculative = true` asks the hook not to risk errors while fetching the value.
     */
    param_data = bound_params->paramFetch(bound_params, one_based_paramid, true, &workspace);
  }
  else
  {
    /* Without a hook the values live in the zero-based `params` array. */
    param_data = &bound_params->params[one_based_paramid - 1];
  }

  /* A speculative fetch reports an unavailable parameter through an invalid type OID. */
  if (param_data == NULL || !OidIsValid(param_data->ptype))
    FAILWITH("Value of a query parameter is not available.");

  return param_data;
}
13+
14+
/*
 * Tells whether `node` is a "stable expression": either a `Const` node,
 * or a `Param` node whose kind is `PARAM_EXTERN`.
 */
bool is_stable_expression(Node *node)
{
  if (IsA(node, Const))
    return true;

  if (!IsA(node, Param))
    return false;

  Param *param_expr = (Param *)node;
  return param_expr->paramkind == PARAM_EXTERN;
}
18+
19+
/*
 * Extracts the type OID, datum and NULL flag held by the stable expression `node`.
 *
 * For a `Const` node the values come from the node itself.
 * For a `PARAM_EXTERN` `Param` node they are looked up in `bound_params` by the node's `paramid`.
 * Any other node is a caller error and is reported via `FAILWITH`.
 */
void get_stable_expression_value(Node *node, ParamListInfo bound_params, Oid *type, Datum *value, bool *isnull)
{
  if (IsA(node, Const))
  {
    Const *constant = (Const *)node;

    *type = constant->consttype;
    *value = constant->constvalue;
    *isnull = constant->constisnull;
    return;
  }

  if (IsA(node, Param) && ((Param *)node)->paramkind == PARAM_EXTERN)
  {
    Param *parameter = (Param *)node;
    ParamExternData *bound = get_param_data(bound_params, parameter->paramid);

    /* Copy the fields out immediately rather than holding on to the looked-up entry. */
    *type = bound->ptype;
    *value = bound->value;
    *isnull = bound->isnull;
    return;
  }

  FAILWITH("Attempted to get simple constant value of non-Const, non-PARAM_EXTERN node");
}

src/query/anonymization.c

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ static void append_seed_material(
337337
typedef struct CollectMaterialContext
338338
{
339339
Query *query;
340+
ParamListInfo bound_params;
340341
char material[MAX_SEED_MATERIAL_SIZE];
341342
} CollectMaterialContext;
342343

@@ -377,14 +378,17 @@ static bool collect_seed_material(Node *node, CollectMaterialContext *context)
377378
append_seed_material(context->material, attribute_name, '.');
378379
}
379380

380-
if (IsA(node, Const))
381+
if (is_stable_expression(node))
381382
{
382-
Const *const_expr = (Const *)node;
383+
Oid type;
384+
Datum value;
385+
bool isnull;
386+
get_stable_expression_value(node, context->bound_params, &type, &value, &isnull);
383387

384-
if (!is_supported_numeric_type(const_expr->consttype))
385-
FAILWITH_LOCATION(const_expr->location, "Unsupported constant type used in bucket definition!");
388+
if (!is_supported_numeric_type(type))
389+
FAILWITH_LOCATION(exprLocation(node), "Unsupported constant type used in bucket definition!");
386390

387-
double const_as_double = numeric_value_to_double(const_expr->consttype, const_expr->constvalue);
391+
double const_as_double = numeric_value_to_double(type, value);
388392
char const_as_string[DOUBLE_SHORTEST_DECIMAL_LEN];
389393
double_to_shortest_decimal_buf(const_as_double, const_as_string);
390394
append_seed_material(context->material, const_as_string, ',');
@@ -398,7 +402,7 @@ static bool collect_seed_material(Node *node, CollectMaterialContext *context)
398402
* Computes the SQL part of the bucket seed by combining the unique grouping expressions' seed material hashes.
399403
* Grouping clause (if any) must be made explicit before calling this.
400404
*/
401-
static seed_t prepare_bucket_seeds(Query *query)
405+
static seed_t prepare_bucket_seeds(Query *query, ParamListInfo bound_params)
402406
{
403407
List *seed_material_hash_set = NULL;
404408

@@ -409,7 +413,7 @@ static seed_t prepare_bucket_seeds(Query *query)
409413
Node *expr = lfirst(cell);
410414

411415
/* Start from empty string and append material pieces for each non-cast expression. */
412-
CollectMaterialContext collect_context = {.query = query, .material = ""};
416+
CollectMaterialContext collect_context = {.query = query, .bound_params = bound_params, .material = ""};
413417
collect_seed_material(expr, &collect_context);
414418

415419
/* Keep materials with unique hashes to avoid them cancelling each other. */
@@ -599,17 +603,17 @@ static void wrap_having_qual(Query *query)
599603
COERCE_EXPLICIT_CALL);
600604
}
601605

602-
static void compile_anonymizing_query(Query *query, List *personal_relations, AnonQueryLinks *anon_links)
606+
static void compile_anonymizing_query(Query *query, List *personal_relations, AnonQueryLinks *anon_links, ParamListInfo bound_params)
603607
{
604608
verify_anonymization_requirements(query);
605609

606610
AnonymizationContext *anon_context = make_query_anonymizing(query, personal_relations);
607611

608612
reject_aid_grouping(query);
609613

610-
verify_bucket_expressions(query);
614+
verify_bucket_expressions(query, bound_params);
611615

612-
anon_context->sql_seed = prepare_bucket_seeds(query);
616+
anon_context->sql_seed = prepare_bucket_seeds(query, bound_params);
613617

614618
link_anon_context(query, anon_links, anon_context);
615619

@@ -636,6 +640,7 @@ typedef struct QueryCompileContext
636640
{
637641
List *personal_relations;
638642
AnonQueryLinks *anon_links;
643+
ParamListInfo bound_params;
639644
} QueryCompileContext;
640645

641646
static bool compile_query_walker(Node *node, QueryCompileContext *context)
@@ -674,19 +679,20 @@ static bool compile_query_walker(Node *node, QueryCompileContext *context)
674679
{
675680
Query *query = (Query *)node;
676681
if (is_anonymizing_query(query, context->personal_relations))
677-
compile_anonymizing_query(query, context->personal_relations, context->anon_links);
682+
compile_anonymizing_query(query, context->personal_relations, context->anon_links, context->bound_params);
678683
else
679684
query_tree_walker(query, compile_query_walker, context, QTW_EXAMINE_RTES_AFTER);
680685
}
681686

682687
return expression_tree_walker(node, compile_query_walker, context);
683688
}
684689

685-
AnonQueryLinks *compile_query(Query *query, List *personal_relations)
690+
AnonQueryLinks *compile_query(Query *query, List *personal_relations, ParamListInfo bound_params)
686691
{
687692
QueryCompileContext context = {
688693
.personal_relations = personal_relations,
689694
.anon_links = palloc0(sizeof(AnonQueryLinks)),
695+
.bound_params = bound_params,
690696
};
691697

692698
compile_query_walker((Node *)query, &context);

src/query/validation.c

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "utils/lsyscache.h"
1616

1717
#include "pg_diffix/auth.h"
18+
#include "pg_diffix/node_funcs.h"
1819
#include "pg_diffix/oid_cache.h"
1920
#include "pg_diffix/query/allowed_objects.h"
2021
#include "pg_diffix/query/validation.h"
@@ -54,6 +55,8 @@ void verify_utility_command(Node *utility_stmt)
5455
case T_DeallocateStmt:
5556
case T_FetchStmt:
5657
case T_ClosePortalStmt:
58+
case T_PrepareStmt:
59+
case T_ExecuteStmt:
5760
break;
5861
default:
5962
FAILWITH("Statement requires direct access level.");
@@ -225,7 +228,7 @@ static void verify_bucket_expression(Node *node)
225228

226229
for (int i = 1; i < list_length(func_expr->args); i++)
227230
{
228-
if (!IsA(unwrap_cast((Node *)list_nth(func_expr->args, i)), Const))
231+
if (!is_stable_expression(unwrap_cast((Node *)list_nth(func_expr->args, i))))
229232
FAILWITH_LOCATION(func_expr->location, "Non-primary arguments for a bucket function have to be simple constants.");
230233
}
231234
}
@@ -234,10 +237,9 @@ static void verify_bucket_expression(Node *node)
234237
OpExpr *op_expr = (OpExpr *)node;
235238
FAILWITH_LOCATION(op_expr->location, "Use of operators to define buckets is not supported.");
236239
}
237-
else if (IsA(node, Const))
240+
else if (is_stable_expression(node))
238241
{
239-
Const *const_expr = (Const *)node;
240-
FAILWITH_LOCATION(const_expr->location, "Simple constants are not allowed as bucket expressions.");
242+
FAILWITH_LOCATION(exprLocation(node), "Simple constants are not allowed as bucket expressions.");
241243
}
242244
else if (IsA(node, RelabelType))
243245
{
@@ -262,14 +264,17 @@ static void verify_bucket_expression(Node *node)
262264
}
263265
}
264266

265-
static void verify_substring(FuncExpr *func_expr)
267+
static void verify_substring(FuncExpr *func_expr, ParamListInfo bound_params)
266268
{
267269
Node *node = unwrap_cast(list_nth(func_expr->args, 1));
268-
Assert(IsA(node, Const)); /* Checked by prior validations */
269-
Const *second_arg = (Const *)node;
270-
271-
if (DatumGetUInt32(second_arg->constvalue) != 1)
272-
FAILWITH_LOCATION(second_arg->location, "Generalization used in the query is not allowed in untrusted access level.");
270+
Assert(is_stable_expression(node)); /* Checked by prior validations */
271+
Oid type;
272+
Datum value;
273+
bool isnull;
274+
get_stable_expression_value(node, bound_params, &type, &value, &isnull);
275+
276+
if (DatumGetUInt32(value) != 1)
277+
FAILWITH_LOCATION(exprLocation(node), "Generalization used in the query is not allowed in untrusted access level.");
273278
}
274279

275280
/* money-style numbers, i.e. 1, 2, or 5 preceded by or followed by zeros: ⟨... 0.1, 0.2, 0.5, 1, 2, 5, 10, ...⟩ */
@@ -286,29 +291,32 @@ static bool is_money_style(double number)
286291
}
287292

288293
/* Expects the expression being the second argument to `round_by` et al. */
289-
static void verify_bin_size(Node *range_expr)
294+
static void verify_bin_size(Node *range_expr, ParamListInfo bound_params)
290295
{
291296
Node *range_node = unwrap_cast(range_expr);
292-
Assert(IsA(range_node, Const)); /* Checked by prior validations */
293-
Const *range_const = (Const *)range_node;
297+
Assert(is_stable_expression(range_node)); /* Checked by prior validations */
298+
Oid type;
299+
Datum value;
300+
bool isnull;
301+
get_stable_expression_value(range_node, bound_params, &type, &value, &isnull);
294302

295-
if (!is_supported_numeric_type(range_const->consttype))
296-
FAILWITH_LOCATION(range_const->location, "Unsupported constant type used in generalization.");
303+
if (!is_supported_numeric_type(type))
304+
FAILWITH_LOCATION(exprLocation(range_node), "Unsupported constant type used in generalization.");
297305

298-
if (!is_money_style(numeric_value_to_double(range_const->consttype, range_const->constvalue)))
299-
FAILWITH_LOCATION(range_const->location, "Generalization used in the query is not allowed in untrusted access level.");
306+
if (!is_money_style(numeric_value_to_double(type, value)))
307+
FAILWITH_LOCATION(exprLocation(range_node), "Generalization used in the query is not allowed in untrusted access level.");
300308
}
301309

302-
static void verify_generalization(Node *node)
310+
static void verify_generalization(Node *node, ParamListInfo bound_params)
303311
{
304312
if (IsA(node, FuncExpr))
305313
{
306314
FuncExpr *func_expr = (FuncExpr *)node;
307315

308316
if (is_substring_builtin(func_expr->funcid))
309-
verify_substring(func_expr);
317+
verify_substring(func_expr, bound_params);
310318
else if (is_implicit_range_udf_untrusted(func_expr->funcid))
311-
verify_bin_size((Node *)list_nth(func_expr->args, 1));
319+
verify_bin_size((Node *)list_nth(func_expr->args, 1), bound_params);
312320
else if (is_implicit_range_builtin_untrusted(func_expr->funcid))
313321
;
314322
else
@@ -317,7 +325,7 @@ static void verify_generalization(Node *node)
317325
}
318326

319327
/* Should be run on anonymizing queries only. */
320-
void verify_bucket_expressions(Query *query)
328+
void verify_bucket_expressions(Query *query, ParamListInfo bound_params)
321329
{
322330
AccessLevel access_level = get_session_access_level();
323331

@@ -333,7 +341,7 @@ void verify_bucket_expressions(Query *query)
333341
Node *expr = (Node *)lfirst(cell);
334342
verify_bucket_expression(expr);
335343
if (access_level == ACCESS_ANONYMIZED_UNTRUSTED)
336-
verify_generalization(expr);
344+
verify_generalization(expr, bound_params);
337345
}
338346
}
339347

test/expected/noiseless.out

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,14 @@ SELECT discount, COUNT(DISTINCT id) FROM test_customers GROUP BY 1;
231231
2 | 4
232232
(4 rows)
233233

234+
----------------------------------------------------------------
235+
-- Prepared statements
236+
----------------------------------------------------------------
237+
PREPARE prepared_floor_by(numeric) AS SELECT diffix.floor_by(discount, $1), count(*) FROM test_customers GROUP BY 1;
238+
EXECUTE prepared_floor_by(2.0);
239+
floor_by | count
240+
----------+-------
241+
0 | 13
242+
2 | 4
243+
(2 rows)
244+

test/expected/validation.out

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ SELECT city, 'aaaa' FROM test_validation GROUP BY 1, 2;
361361
ERROR: [PG_DIFFIX] Simple constants are not allowed as bucket expressions.
362362
LINE 1: SELECT city, 'aaaa' FROM test_validation GROUP BY 1, 2;
363363
^
364+
PREPARE prepared_param_as_label(text) AS SELECT city, $1 FROM test_validation GROUP BY 1, 2;
365+
EXECUTE prepared_param_as_label('aaaa');
366+
ERROR: [PG_DIFFIX] Simple constants are not allowed as bucket expressions.
364367
SELECT COUNT(*) FROM test_validation GROUP BY round(floor(id));
365368
ERROR: [PG_DIFFIX] Primary argument for a bucket function has to be a simple column reference.
366369
LINE 1: SELECT COUNT(*) FROM test_validation GROUP BY round(floor(id...
@@ -605,3 +608,20 @@ SELECT diffix.ceil_by(discount, 2) from test_validation;
605608
ERROR: [PG_DIFFIX] Generalization used in the query is not allowed in untrusted access level.
606609
LINE 1: SELECT diffix.ceil_by(discount, 2) from test_validation;
607610
^
611+
-- Allow prepared statements with generalization constants as params, and validate them
612+
PREPARE prepared_floor_by(numeric) AS SELECT diffix.floor_by(discount, $1) FROM test_validation GROUP BY 1;
613+
EXECUTE prepared_floor_by(2.0);
614+
floor_by
615+
----------
616+
(0 rows)
617+
618+
EXECUTE prepared_floor_by(2.1);
619+
ERROR: [PG_DIFFIX] Generalization used in the query is not allowed in untrusted access level.
620+
PREPARE prepared_substring(int, int) AS SELECT substring(city, $1, $2) FROM test_validation GROUP BY 1;
621+
EXECUTE prepared_substring(1, 2);
622+
substring
623+
-----------
624+
(0 rows)
625+
626+
EXECUTE prepared_substring(2, 3);
627+
ERROR: [PG_DIFFIX] Generalization used in the query is not allowed in untrusted access level.

test/sql/noiseless.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,10 @@ SELECT COUNT(DISTINCT planet) FROM test_customers;
7474
-- `low_count_min_threshold` for queries with GROUP BY
7575
SELECT city, COUNT(DISTINCT city) FROM test_customers GROUP BY 1;
7676
SELECT discount, COUNT(DISTINCT id) FROM test_customers GROUP BY 1;
77+
78+
----------------------------------------------------------------
79+
-- Prepared statements
80+
----------------------------------------------------------------
81+
82+
PREPARE prepared_floor_by(numeric) AS SELECT diffix.floor_by(discount, $1), count(*) FROM test_customers GROUP BY 1;
83+
EXECUTE prepared_floor_by(2.0);

0 commit comments

Comments
 (0)