@@ -54,6 +54,8 @@ void verify_utility_command(Node *utility_stmt)
5454 case T_DeallocateStmt :
5555 case T_FetchStmt :
5656 case T_ClosePortalStmt :
57+ case T_PrepareStmt :
58+ case T_ExecuteStmt :
5759 break ;
5860 default :
5961 FAILWITH ("Statement requires direct access level." );
@@ -225,7 +227,7 @@ static void verify_bucket_expression(Node *node)
225227
226228 for (int i = 1 ; i < list_length (func_expr -> args ); i ++ )
227229 {
228- if (!IsA (unwrap_cast ((Node * )list_nth (func_expr -> args , i )), Const ))
230+ if (!is_simple_constant (unwrap_cast ((Node * )list_nth (func_expr -> args , i ))))
229231 FAILWITH_LOCATION (func_expr -> location , "Non-primary arguments for a bucket function have to be simple constants." );
230232 }
231233 }
@@ -234,10 +236,9 @@ static void verify_bucket_expression(Node *node)
234236 OpExpr * op_expr = (OpExpr * )node ;
235237 FAILWITH_LOCATION (op_expr -> location , "Use of operators to define buckets is not supported." );
236238 }
237- else if (IsA (node , Const ))
239+ else if (is_simple_constant (node ))
238240 {
239- Const * const_expr = (Const * )node ;
240- FAILWITH_LOCATION (const_expr -> location , "Simple constants are not allowed as bucket expressions." );
241+ FAILWITH_LOCATION (get_simple_constant_location (node ), "Simple constants are not allowed as bucket expressions." );
241242 }
242243 else if (IsA (node , RelabelType ))
243244 {
@@ -262,14 +263,17 @@ static void verify_bucket_expression(Node *node)
262263 }
263264}
264265
265- static void verify_substring (FuncExpr * func_expr )
266+ static void verify_substring (FuncExpr * func_expr , ParamListInfo bound_params )
266267{
267268 Node * node = unwrap_cast (list_nth (func_expr -> args , 1 ));
268- Assert (IsA (node , Const )); /* Checked by prior validations */
269- Const * second_arg = (Const * )node ;
270-
271- if (DatumGetUInt32 (second_arg -> constvalue ) != 1 )
272- FAILWITH_LOCATION (second_arg -> location , "Generalization used in the query is not allowed in untrusted access level." );
269+ Assert (is_simple_constant (node )); /* Checked by prior validations */
270+ Oid type ;
271+ Datum value ;
272+ bool isnull ;
273+ get_simple_constant_typed_value (node , bound_params , & type , & value , & isnull );
274+
275+ if (DatumGetUInt32 (value ) != 1 )
276+ FAILWITH_LOCATION (get_simple_constant_location (node ), "Generalization used in the query is not allowed in untrusted access level." );
273277}
274278
275279/* money-style numbers, i.e. 1, 2, or 5 preceeded by or followed by zeros: ⟨... 0.1, 0.2, 0.5, 1, 2, 5, 10, ...⟩ */
@@ -286,29 +290,32 @@ static bool is_money_style(double number)
286290}
287291
288292/* Expects the expression being the second argument to `round_by` et al. */
289- static void verify_bin_size (Node * range_expr )
293+ static void verify_bin_size (Node * range_expr , ParamListInfo bound_params )
290294{
291295 Node * range_node = unwrap_cast (range_expr );
292- Assert (IsA (range_node , Const )); /* Checked by prior validations */
293- Const * range_const = (Const * )range_node ;
296+ Assert (is_simple_constant (range_node )); /* Checked by prior validations */
297+ Oid type ;
298+ Datum value ;
299+ bool isnull ;
300+ get_simple_constant_typed_value (range_node , bound_params , & type , & value , & isnull );
294301
295- if (!is_supported_numeric_type (range_const -> consttype ))
296- FAILWITH_LOCATION (range_const -> location , "Unsupported constant type used in generalization." );
302+ if (!is_supported_numeric_type (type ))
303+ FAILWITH_LOCATION (get_simple_constant_location ( range_node ) , "Unsupported constant type used in generalization." );
297304
298- if (!is_money_style (numeric_value_to_double (range_const -> consttype , range_const -> constvalue )))
299- FAILWITH_LOCATION (range_const -> location , "Generalization used in the query is not allowed in untrusted access level." );
305+ if (!is_money_style (numeric_value_to_double (type , value )))
306+ FAILWITH_LOCATION (get_simple_constant_location ( range_node ) , "Generalization used in the query is not allowed in untrusted access level." );
300307}
301308
302- static void verify_generalization (Node * node )
309+ static void verify_generalization (Node * node , ParamListInfo bound_params )
303310{
304311 if (IsA (node , FuncExpr ))
305312 {
306313 FuncExpr * func_expr = (FuncExpr * )node ;
307314
308315 if (is_substring_builtin (func_expr -> funcid ))
309- verify_substring (func_expr );
316+ verify_substring (func_expr , bound_params );
310317 else if (is_implicit_range_udf_untrusted (func_expr -> funcid ))
311- verify_bin_size ((Node * )list_nth (func_expr -> args , 1 ));
318+ verify_bin_size ((Node * )list_nth (func_expr -> args , 1 ), bound_params );
312319 else if (is_implicit_range_builtin_untrusted (func_expr -> funcid ))
313320 ;
314321 else
@@ -317,7 +324,7 @@ static void verify_generalization(Node *node)
317324}
318325
319326/* Should be run on anonymizing queries only. */
320- void verify_bucket_expressions (Query * query )
327+ void verify_bucket_expressions (Query * query , ParamListInfo bound_params )
321328{
322329 AccessLevel access_level = get_session_access_level ();
323330
@@ -333,7 +340,7 @@ void verify_bucket_expressions(Query *query)
333340 Node * expr = (Node * )lfirst (cell );
334341 verify_bucket_expression (expr );
335342 if (access_level == ACCESS_ANONYMIZED_UNTRUSTED )
336- verify_generalization (expr );
343+ verify_generalization (expr , bound_params );
337344 }
338345}
339346
0 commit comments