1515#include "utils/lsyscache.h"
1616
1717#include "pg_diffix/auth.h"
18+ #include "pg_diffix/node_funcs.h"
1819#include "pg_diffix/oid_cache.h"
1920#include "pg_diffix/query/allowed_objects.h"
2021#include "pg_diffix/query/validation.h"
@@ -54,6 +55,8 @@ void verify_utility_command(Node *utility_stmt)
5455 case T_DeallocateStmt :
5556 case T_FetchStmt :
5657 case T_ClosePortalStmt :
58+ case T_PrepareStmt :
59+ case T_ExecuteStmt :
5760 break ;
5861 default :
5962 FAILWITH ("Statement requires direct access level." );
@@ -225,7 +228,7 @@ static void verify_bucket_expression(Node *node)
225228
226229 for (int i = 1 ; i < list_length (func_expr -> args ); i ++ )
227230 {
228- if (!IsA (unwrap_cast ((Node * )list_nth (func_expr -> args , i )), Const ))
231+ if (!is_stable_expression (unwrap_cast ((Node * )list_nth (func_expr -> args , i ))))
229232 FAILWITH_LOCATION (func_expr -> location , "Non-primary arguments for a bucket function have to be simple constants." );
230233 }
231234 }
@@ -234,10 +237,9 @@ static void verify_bucket_expression(Node *node)
234237 OpExpr * op_expr = (OpExpr * )node ;
235238 FAILWITH_LOCATION (op_expr -> location , "Use of operators to define buckets is not supported." );
236239 }
237- else if (IsA (node , Const ))
240+ else if (is_stable_expression (node ))
238241 {
239- Const * const_expr = (Const * )node ;
240- FAILWITH_LOCATION (const_expr -> location , "Simple constants are not allowed as bucket expressions." );
242+ FAILWITH_LOCATION (exprLocation (node ), "Simple constants are not allowed as bucket expressions." );
241243 }
242244 else if (IsA (node , RelabelType ))
243245 {
@@ -262,14 +264,17 @@ static void verify_bucket_expression(Node *node)
262264 }
263265}
264266
265- static void verify_substring (FuncExpr * func_expr )
267+ static void verify_substring (FuncExpr * func_expr , ParamListInfo bound_params )
266268{
267269 Node * node = unwrap_cast (list_nth (func_expr -> args , 1 ));
268- Assert (IsA (node , Const )); /* Checked by prior validations */
269- Const * second_arg = (Const * )node ;
270-
271- if (DatumGetUInt32 (second_arg -> constvalue ) != 1 )
272- FAILWITH_LOCATION (second_arg -> location , "Generalization used in the query is not allowed in untrusted access level." );
270+ Assert (is_stable_expression (node )); /* Checked by prior validations */
271+ Oid type ;
272+ Datum value ;
273+ bool isnull ;
274+ get_stable_expression_value (node , bound_params , & type , & value , & isnull );
275+
276+ if (DatumGetUInt32 (value ) != 1 )
277+ FAILWITH_LOCATION (exprLocation (node ), "Generalization used in the query is not allowed in untrusted access level." );
273278}
274279
275280/* money-style numbers, i.e. 1, 2, or 5 preceeded by or followed by zeros: ⟨... 0.1, 0.2, 0.5, 1, 2, 5, 10, ...⟩ */
@@ -286,29 +291,32 @@ static bool is_money_style(double number)
286291}
287292
288293/* Expects the expression being the second argument to `round_by` et al. */
289- static void verify_bin_size (Node * range_expr )
294+ static void verify_bin_size (Node * range_expr , ParamListInfo bound_params )
290295{
291296 Node * range_node = unwrap_cast (range_expr );
292- Assert (IsA (range_node , Const )); /* Checked by prior validations */
293- Const * range_const = (Const * )range_node ;
297+ Assert (is_stable_expression (range_node )); /* Checked by prior validations */
298+ Oid type ;
299+ Datum value ;
300+ bool isnull ;
301+ get_stable_expression_value (range_node , bound_params , & type , & value , & isnull );
294302
295- if (!is_supported_numeric_type (range_const -> consttype ))
296- FAILWITH_LOCATION (range_const -> location , "Unsupported constant type used in generalization." );
303+ if (!is_supported_numeric_type (type ))
304+ FAILWITH_LOCATION (exprLocation ( range_node ) , "Unsupported constant type used in generalization." );
297305
298- if (!is_money_style (numeric_value_to_double (range_const -> consttype , range_const -> constvalue )))
299- FAILWITH_LOCATION (range_const -> location , "Generalization used in the query is not allowed in untrusted access level." );
306+ if (!is_money_style (numeric_value_to_double (type , value )))
307+ FAILWITH_LOCATION (exprLocation ( range_node ) , "Generalization used in the query is not allowed in untrusted access level." );
300308}
301309
302- static void verify_generalization (Node * node )
310+ static void verify_generalization (Node * node , ParamListInfo bound_params )
303311{
304312 if (IsA (node , FuncExpr ))
305313 {
306314 FuncExpr * func_expr = (FuncExpr * )node ;
307315
308316 if (is_substring_builtin (func_expr -> funcid ))
309- verify_substring (func_expr );
317+ verify_substring (func_expr , bound_params );
310318 else if (is_implicit_range_udf_untrusted (func_expr -> funcid ))
311- verify_bin_size ((Node * )list_nth (func_expr -> args , 1 ));
319+ verify_bin_size ((Node * )list_nth (func_expr -> args , 1 ), bound_params );
312320 else if (is_implicit_range_builtin_untrusted (func_expr -> funcid ))
313321 ;
314322 else
@@ -317,7 +325,7 @@ static void verify_generalization(Node *node)
317325}
318326
319327/* Should be run on anonymizing queries only. */
320- void verify_bucket_expressions (Query * query )
328+ void verify_bucket_expressions (Query * query , ParamListInfo bound_params )
321329{
322330 AccessLevel access_level = get_session_access_level ();
323331
@@ -333,7 +341,7 @@ void verify_bucket_expressions(Query *query)
333341 Node * expr = (Node * )lfirst (cell );
334342 verify_bucket_expression (expr );
335343 if (access_level == ACCESS_ANONYMIZED_UNTRUSTED )
336- verify_generalization (expr );
344+ verify_generalization (expr , bound_params );
337345 }
338346}
339347
0 commit comments