11#include "postgres.h"
22
33#include "access/sysattr.h"
4+ #include "catalog/pg_type.h"
5+ #include "utils/builtins.h"
46#include "utils/fmgroids.h"
57#include "utils/fmgrtab.h"
68#include "utils/lsyscache.h"
1315static const char * const g_allowed_casts [] = {
1416 "i2tod" , "i2tof" , "i2toi4" , "i4toi2" , "i4tod" , "i4tof" , "i8tod" , "i8tof" , "int48" , "int84" ,
1517 "ftod" , "dtof" ,
16- "int4_numeric" , "float4_numeric" , "float8_numeric" ,
18+ "int2_numeric" , " int4_numeric" , "int8_numeric " , "float4_numeric" , "float8_numeric" ,
1719 "numeric_float4" , "numeric_float8" ,
1820 "date_timestamptz" ,
1921 /**/
@@ -39,7 +41,9 @@ static const FunctionByName g_allowed_builtins[] = {
3941 (FunctionByName ){.name = "dtoi2" , .primary_arg = 0 },
4042 (FunctionByName ){.name = "dtoi4" , .primary_arg = 0 },
4143 (FunctionByName ){.name = "dtoi8" , .primary_arg = 0 },
44+ (FunctionByName ){.name = "numeric_int2" , .primary_arg = 0 },
4245 (FunctionByName ){.name = "numeric_int4" , .primary_arg = 0 },
46+ (FunctionByName ){.name = "numeric_int8" , .primary_arg = 0 },
4347 /* substring */
4448 (FunctionByName ){.name = "text_substr" , .primary_arg = 0 },
4549 (FunctionByName ){.name = "text_substr_no_len" , .primary_arg = 0 },
@@ -124,6 +128,16 @@ static AllowedCols g_pg_catalog_allowed_cols[] = {
124128 /**/
125129};
126130
131+ static const char * const g_decimal_integer_casts [] = {
132+ "numeric_int2" , "numeric_int4" , "numeric_int8" , "dtoi2" , "dtoi4" , "dtoi8" , "ftoi2" , "ftoi4" , "ftoi8" ,
133+ /**/
134+ };
135+
136+ static const char * const g_extract_functions [] = {
137+ "extract_date" , "extract_timestamp" , "extract_timestamptz" , "timestamp_part" , "timestamptz_part" ,
138+ /**/
139+ };
140+
127141static void prepare_pg_catalog_allowed (Oid relation_oid , AllowedCols * allowed_cols )
128142{
129143 MemoryContext old_context = MemoryContextSwitchTo (TopMemoryContext );
@@ -166,6 +180,16 @@ static const FmgrBuiltin *fmgr_isbuiltin(Oid id)
166180 return & fmgr_builtins [index ];
167181}
168182
183+ static bool is_member_of (const char * s , const char * const array [], int length )
184+ {
185+ for (int i = 0 ; i < length ; i ++ )
186+ {
187+ if (strcmp (array [i ], s ) == 0 )
188+ return true;
189+ }
190+ return false;
191+ }
192+
169193static bool is_func_member_of (Oid funcoid , const FunctionByName func_array [], int length )
170194{
171195 const FmgrBuiltin * fmgr_builtin = fmgr_isbuiltin (funcoid );
@@ -185,13 +209,7 @@ static bool is_funcname_member_of(Oid funcoid, const char *const name_array[], i
185209{
186210 const FmgrBuiltin * fmgr_builtin = fmgr_isbuiltin (funcoid );
187211 if (fmgr_builtin != NULL )
188- {
189- for (int i = 0 ; i < length ; i ++ )
190- {
191- if (strcmp (name_array [i ], fmgr_builtin -> funcName ) == 0 )
192- return true;
193- }
194- }
212+ return is_member_of (fmgr_builtin -> funcName , name_array , length );
195213
196214 return false;
197215}
@@ -224,9 +242,25 @@ int primary_arg_index(Oid funcoid)
224242 FAILWITH ("Cannot identify the primary argument position for funcid %u." , funcoid );
225243}
226244
227- bool is_allowed_cast (Oid funcoid )
245+ bool is_allowed_cast (const FuncExpr * func_expr )
228246{
229- return is_funcname_member_of (funcoid , g_allowed_casts , ARRAY_LENGTH (g_allowed_casts ));
247+ if (is_funcname_member_of (func_expr -> funcid , g_allowed_casts , ARRAY_LENGTH (g_allowed_casts )))
248+ {
249+ return true;
250+ }
251+ else if (is_funcname_member_of (func_expr -> funcid , g_decimal_integer_casts , ARRAY_LENGTH (g_decimal_integer_casts )))
252+ {
253+ /* Handle cases like `cast(extract(minute from ...) as integer)`. */
254+ Node * cast_arg = linitial (func_expr -> args );
255+ if (IsA (cast_arg , FuncExpr ))
256+ {
257+ FuncExpr * cast_arg_expr = (FuncExpr * )cast_arg ;
258+ if (is_funcname_member_of (cast_arg_expr -> funcid , g_extract_functions , ARRAY_LENGTH (g_extract_functions )) ||
259+ cast_arg_expr -> funcid == F_DATE_PART_TEXT_DATE )
260+ return true;
261+ }
262+ }
263+ return false;
230264}
231265
232266bool is_implicit_range_udf_untrusted (Oid funcoid )
@@ -279,13 +313,10 @@ bool is_allowed_pg_catalog_rte(Oid relation_oid, const Bitmapset *selected_cols)
279313 char * rel_name = get_rel_name (relation_oid );
280314
281315 /* Then check if the entire relation is allowed. */
282- for ( int i = 0 ; i < ARRAY_LENGTH (g_pg_catalog_allowed_rels ); i ++ )
316+ if ( is_member_of ( rel_name , g_pg_catalog_allowed_rels , ARRAY_LENGTH (g_pg_catalog_allowed_rels )) )
283317 {
284- if (strcmp (g_pg_catalog_allowed_rels [i ], rel_name ) == 0 )
285- {
286- pfree (rel_name );
287- return true;
288- }
318+ pfree (rel_name );
319+ return true;
289320 }
290321
291322 /* Otherwise specific selected columns must be checked against the allow-list. */
0 commit comments