11#include "postgres.h"
22
33#include "access/sysattr.h"
4+ #include "catalog/pg_type.h"
5+ #include "utils/builtins.h"
46#include "utils/fmgroids.h"
57#include "utils/fmgrtab.h"
68#include "utils/lsyscache.h"
@@ -39,7 +41,9 @@ static const FunctionByName g_allowed_builtins[] = {
3941 (FunctionByName ){.name = "dtoi2" , .primary_arg = 0 },
4042 (FunctionByName ){.name = "dtoi4" , .primary_arg = 0 },
4143 (FunctionByName ){.name = "dtoi8" , .primary_arg = 0 },
44+ (FunctionByName ){.name = "numeric_int2" , .primary_arg = 0 },
4245 (FunctionByName ){.name = "numeric_int4" , .primary_arg = 0 },
46+ (FunctionByName ){.name = "numeric_int8" , .primary_arg = 0 },
4347 /* substring */
4448 (FunctionByName ){.name = "text_substr" , .primary_arg = 0 },
4549 (FunctionByName ){.name = "text_substr_no_len" , .primary_arg = 0 },
@@ -124,6 +128,22 @@ static AllowedCols g_pg_catalog_allowed_cols[] = {
124128 /**/
125129};
126130
131+ static const char * const g_decimal_integer_casts [] = {
132+ "numeric_int2" , "numeric_int4" , "numeric_int8" , "dtoi2" , "dtoi4" , "dtoi8" , "ftoi2" , "ftoi4" , "ftoi8" ,
133+ /**/
134+ };
135+
136+ static const char * const g_extract_functions [] = {
137+ "extract_date" , "extract_timestamp" , "extract_timestamptz" , "timestamp_part" , "timestamptz_part" ,
138+ /**/
139+ };
140+
141+ static const char * const g_integer_extract_fields [] = {
142+ "minute" , "hour" , "day" , "dow" , "isodow" , "doy" , "week" , "month" , "quarter" , "year" , "isoyear" , "decade" ,
143+ "century" , "millennium" ,
144+ /**/
145+ };
146+
127147static void prepare_pg_catalog_allowed (Oid relation_oid , AllowedCols * allowed_cols )
128148{
129149 MemoryContext old_context = MemoryContextSwitchTo (TopMemoryContext );
@@ -166,6 +186,16 @@ static const FmgrBuiltin *fmgr_isbuiltin(Oid id)
166186 return & fmgr_builtins [index ];
167187}
168188
189+ static bool is_member_of (const char * s , const char * const array [], int length )
190+ {
191+ for (int i = 0 ; i < length ; i ++ )
192+ {
193+ if (strcmp (array [i ], s ) == 0 )
194+ return true;
195+ }
196+ return false;
197+ }
198+
169199static bool is_func_member_of (Oid funcoid , const FunctionByName func_array [], int length )
170200{
171201 const FmgrBuiltin * fmgr_builtin = fmgr_isbuiltin (funcoid );
@@ -185,13 +215,7 @@ static bool is_funcname_member_of(Oid funcoid, const char *const name_array[], i
185215{
186216 const FmgrBuiltin * fmgr_builtin = fmgr_isbuiltin (funcoid );
187217 if (fmgr_builtin != NULL )
188- {
189- for (int i = 0 ; i < length ; i ++ )
190- {
191- if (strcmp (name_array [i ], fmgr_builtin -> funcName ) == 0 )
192- return true;
193- }
194- }
218+ return is_member_of (fmgr_builtin -> funcName , name_array , length );
195219
196220 return false;
197221}
@@ -224,9 +248,38 @@ int primary_arg_index(Oid funcoid)
224248 FAILWITH ("Cannot identify the primary argument position for funcid %u." , funcoid );
225249}
226250
227- bool is_allowed_cast (Oid funcoid )
251+ bool is_allowed_cast (const FuncExpr * func_expr )
228252{
229- return is_funcname_member_of (funcoid , g_allowed_casts , ARRAY_LENGTH (g_allowed_casts ));
253+ if (is_funcname_member_of (func_expr -> funcid , g_allowed_casts , ARRAY_LENGTH (g_allowed_casts )))
254+ {
255+ return true;
256+ }
257+ else if (is_funcname_member_of (func_expr -> funcid , g_decimal_integer_casts , ARRAY_LENGTH (g_decimal_integer_casts )))
258+ {
259+ /*
260+ * Special case, where a `numeric_int4` cast is called on variants of `extract` which return
261+ * integer numbers, e.g. `cast(extract(minute from ...) as integer)`.
262+ */
263+ Node * cast_arg = linitial (func_expr -> args );
264+ if (IsA (cast_arg , FuncExpr ))
265+ {
266+ FuncExpr * cast_arg_expr = (FuncExpr * )cast_arg ;
267+ if (is_funcname_member_of (cast_arg_expr -> funcid , g_extract_functions , ARRAY_LENGTH (g_extract_functions )) ||
268+ cast_arg_expr -> funcid == F_DATE_PART_TEXT_DATE )
269+ {
270+ Node * extract_field = linitial (cast_arg_expr -> args );
271+ if (IsA (extract_field , Const ))
272+ {
273+ Const * extract_field_const = (Const * )extract_field ;
274+ Assert (extract_field_const -> consttype == TEXTOID );
275+ const char * field = TextDatumGetCString (extract_field_const -> constvalue );
276+
277+ return is_member_of (field , g_integer_extract_fields , ARRAY_LENGTH (g_integer_extract_fields ));
278+ }
279+ }
280+ }
281+ }
282+ return false;
230283}
231284
232285bool is_implicit_range_udf_untrusted (Oid funcoid )
@@ -279,13 +332,10 @@ bool is_allowed_pg_catalog_rte(Oid relation_oid, const Bitmapset *selected_cols)
279332 char * rel_name = get_rel_name (relation_oid );
280333
281334 /* Then check if the entire relation is allowed. */
282- for ( int i = 0 ; i < ARRAY_LENGTH (g_pg_catalog_allowed_rels ); i ++ )
335+ if ( is_member_of ( rel_name , g_pg_catalog_allowed_rels , ARRAY_LENGTH (g_pg_catalog_allowed_rels )) )
283336 {
284- if (strcmp (g_pg_catalog_allowed_rels [i ], rel_name ) == 0 )
285- {
286- pfree (rel_name );
287- return true;
288- }
337+ pfree (rel_name );
338+ return true;
289339 }
290340
291341 /* Otherwise specific selected columns must be checked against the allow-list. */
0 commit comments