Skip to content

Commit caa0612

Browse files
committed
Allow cast(extract ...) as integer
1 parent 6711291 commit caa0612

File tree

7 files changed

+111
-19
lines changed

7 files changed

+111
-19
lines changed

pg_diffix/query/allowed_objects.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define PG_DIFFIX_ALLOWED_OBJECTS_H
33

44
#include "nodes/bitmapset.h"
5+
#include "nodes/primnodes.h"
56

67
/*
78
* Returns whether the OID points to a function (or operator) allowed in defining buckets.
@@ -17,7 +18,7 @@ extern int primary_arg_index(Oid funcoid);
1718
/*
1819
* Returns whether the given function expression is a cast allowed in defining buckets.
1920
*/
20-
extern bool is_allowed_cast(Oid funcoid);
21+
extern bool is_allowed_cast(const FuncExpr *func_expr);
2122

2223
/*
2324
* Returns whether the OID points to a UDF being an implicit_range function, e.g. `ceil_by(x, 2.0)`,

src/node_funcs.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Node *unwrap_cast(Node *node)
2020
if (IsA(node, FuncExpr))
2121
{
2222
FuncExpr *func_expr = (FuncExpr *)node;
23-
if (is_allowed_cast(func_expr->funcid))
23+
if (is_allowed_cast(func_expr))
2424
{
2525
Assert(list_length(func_expr->args) == 1); /* All allowed casts require exactly one argument. */
2626
return unwrap_cast(linitial(func_expr->args));

src/query/allowed_objects.c

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "postgres.h"
22

33
#include "access/sysattr.h"
4+
#include "catalog/pg_type.h"
5+
#include "utils/builtins.h"
46
#include "utils/fmgroids.h"
57
#include "utils/fmgrtab.h"
68
#include "utils/lsyscache.h"
@@ -39,7 +41,9 @@ static const FunctionByName g_allowed_builtins[] = {
3941
(FunctionByName){.name = "dtoi2", .primary_arg = 0},
4042
(FunctionByName){.name = "dtoi4", .primary_arg = 0},
4143
(FunctionByName){.name = "dtoi8", .primary_arg = 0},
44+
(FunctionByName){.name = "numeric_int2", .primary_arg = 0},
4245
(FunctionByName){.name = "numeric_int4", .primary_arg = 0},
46+
(FunctionByName){.name = "numeric_int8", .primary_arg = 0},
4347
/* substring */
4448
(FunctionByName){.name = "text_substr", .primary_arg = 0},
4549
(FunctionByName){.name = "text_substr_no_len", .primary_arg = 0},
@@ -124,6 +128,22 @@ static AllowedCols g_pg_catalog_allowed_cols[] = {
124128
/**/
125129
};
126130

131+
/*
 * Names of the builtin functions treated as decimal-to-integer casts. `is_allowed_cast` accepts
 * one of these only when its argument is an `extract`-style call on an integer-valued field.
 */
static const char *const g_decimal_integer_casts[] = {
132+
"numeric_int2", "numeric_int4", "numeric_int8", "dtoi2", "dtoi4", "dtoi8", "ftoi2", "ftoi4", "ftoi8",
133+
/**/
134+
};
135+
136+
/*
 * Names of the builtin functions that `is_allowed_cast` recognizes as variants of `extract`
 * when they appear as the argument of a cast from `g_decimal_integer_casts`.
 */
static const char *const g_extract_functions[] = {
137+
"extract_date", "extract_timestamp", "extract_timestamptz", "timestamp_part", "timestamptz_part",
138+
/**/
139+
};
140+
141+
/*
 * Field names for which `is_allowed_cast` accepts `cast(extract(<field> from ...) as integer)`.
 * Fields such as `epoch` or `second` are deliberately absent; the regression tests describe the
 * cast as rounding for those and expect such queries to be rejected.
 */
static const char *const g_integer_extract_fields[] = {
142+
"minute", "hour", "day", "dow", "isodow", "doy", "week", "month", "quarter", "year", "isoyear", "decade",
143+
"century", "millennium",
144+
/**/
145+
};
146+
127147
static void prepare_pg_catalog_allowed(Oid relation_oid, AllowedCols *allowed_cols)
128148
{
129149
MemoryContext old_context = MemoryContextSwitchTo(TopMemoryContext);
@@ -166,6 +186,16 @@ static const FmgrBuiltin *fmgr_isbuiltin(Oid id)
166186
return &fmgr_builtins[index];
167187
}
168188

189+
/*
 * Returns whether the string `s` is equal to one of the first `length` entries of `array`.
 * Comparison is exact (case-sensitive). A NULL `s` is never a member; callers pass values
 * such as the result of `get_rel_name`, which may be NULL.
 */
static bool is_member_of(const char *s, const char *const array[], int length)
{
  if (s == NULL)
    return false;

  for (int i = 0; i < length; i++)
  {
    if (strcmp(array[i], s) == 0)
      return true;
  }
  return false;
}
198+
169199
static bool is_func_member_of(Oid funcoid, const FunctionByName func_array[], int length)
170200
{
171201
const FmgrBuiltin *fmgr_builtin = fmgr_isbuiltin(funcoid);
@@ -185,13 +215,7 @@ static bool is_funcname_member_of(Oid funcoid, const char *const name_array[], i
185215
{
186216
const FmgrBuiltin *fmgr_builtin = fmgr_isbuiltin(funcoid);
187217
if (fmgr_builtin != NULL)
188-
{
189-
for (int i = 0; i < length; i++)
190-
{
191-
if (strcmp(name_array[i], fmgr_builtin->funcName) == 0)
192-
return true;
193-
}
194-
}
218+
return is_member_of(fmgr_builtin->funcName, name_array, length);
195219

196220
return false;
197221
}
@@ -224,9 +248,38 @@ int primary_arg_index(Oid funcoid)
224248
FAILWITH("Cannot identify the primary argument position for funcid %u.", funcoid);
225249
}
226250

227-
bool is_allowed_cast(Oid funcoid)
251+
bool is_allowed_cast(const FuncExpr *func_expr)
228252
{
229-
return is_funcname_member_of(funcoid, g_allowed_casts, ARRAY_LENGTH(g_allowed_casts));
253+
if (is_funcname_member_of(func_expr->funcid, g_allowed_casts, ARRAY_LENGTH(g_allowed_casts)))
254+
{
255+
return true;
256+
}
257+
else if (is_funcname_member_of(func_expr->funcid, g_decimal_integer_casts, ARRAY_LENGTH(g_decimal_integer_casts)))
258+
{
259+
/*
260+
* Special case, where a `numeric_int4` cast is called on variants of `extract` which return
261+
* integer numbers, e.g. `cast(extract(minute from ...) as integer)`.
262+
*/
263+
Node *cast_arg = linitial(func_expr->args);
264+
if (IsA(cast_arg, FuncExpr))
265+
{
266+
FuncExpr *cast_arg_expr = (FuncExpr *)cast_arg;
267+
if (is_funcname_member_of(cast_arg_expr->funcid, g_extract_functions, ARRAY_LENGTH(g_extract_functions)) ||
268+
cast_arg_expr->funcid == F_DATE_PART_TEXT_DATE)
269+
{
270+
Node *extract_field = linitial(cast_arg_expr->args);
271+
if (IsA(extract_field, Const))
272+
{
273+
Const *extract_field_const = (Const *)extract_field;
274+
Assert(extract_field_const->consttype == TEXTOID);
275+
const char *field = TextDatumGetCString(extract_field_const->constvalue);
276+
277+
return is_member_of(field, g_integer_extract_fields, ARRAY_LENGTH(g_integer_extract_fields));
278+
}
279+
}
280+
}
281+
}
282+
return false;
230283
}
231284

232285
bool is_implicit_range_udf_untrusted(Oid funcoid)
@@ -279,13 +332,10 @@ bool is_allowed_pg_catalog_rte(Oid relation_oid, const Bitmapset *selected_cols)
279332
char *rel_name = get_rel_name(relation_oid);
280333

281334
/* Then check if the entire relation is allowed. */
282-
for (int i = 0; i < ARRAY_LENGTH(g_pg_catalog_allowed_rels); i++)
335+
if (is_member_of(rel_name, g_pg_catalog_allowed_rels, ARRAY_LENGTH(g_pg_catalog_allowed_rels)))
283336
{
284-
if (strcmp(g_pg_catalog_allowed_rels[i], rel_name) == 0)
285-
{
286-
pfree(rel_name);
287-
return true;
288-
}
337+
pfree(rel_name);
338+
return true;
289339
}
290340

291341
/* Otherwise specific selected columns must be checked against the allow-list. */

src/query/anonymization.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ static bool collect_seed_material(Node *node, CollectMaterialContext *context)
591591
if (IsA(node, FuncExpr))
592592
{
593593
FuncExpr *func_expr = (FuncExpr *)node;
594-
if (!is_allowed_cast(func_expr->funcid))
594+
if (!is_allowed_cast(func_expr))
595595
{
596596
char *func_name = get_func_name(func_expr->funcid);
597597
if (func_name)

src/query/validation.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,11 @@ static void verify_bucket_expression(Node *node)
231231
if (IsA(node, FuncExpr))
232232
{
233233
FuncExpr *func_expr = (FuncExpr *)node;
234-
if (is_allowed_cast(func_expr->funcid))
234+
if (is_allowed_cast(func_expr))
235235
{
236236
Assert(list_length(func_expr->args) == 1); /* All allowed casts require exactly one argument. */
237237
verify_bucket_expression(linitial(func_expr->args));
238+
return;
238239
}
239240

240241
if (!is_allowed_function(func_expr->funcid))

test/expected/datetime.out

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,32 @@ SELECT count(*) FROM test_datetime WHERE date_part('century', ts) = 21;
9797
9
9898
(1 row)
9999

100+
-- Datetime cast normalization
101+
-- Allowed because the cast is a noop
102+
SELECT cast(extract(minute from ts) as integer) as extract FROM test_datetime GROUP BY 1;
103+
extract
104+
---------
105+
0
106+
(1 row)
107+
108+
-- Disallowed because the cast is rounding
109+
SELECT cast(extract(epoch from ts) as integer) FROM test_datetime GROUP BY 1;
110+
ERROR: [PG_DIFFIX] Primary argument for a generalization function has to be a simple column reference.
111+
LINE 1: SELECT cast(extract(epoch from ts) as integer) FROM test_dat...
112+
^
113+
SELECT cast(extract(second from ts) as integer) FROM test_datetime GROUP BY 1;
114+
ERROR: [PG_DIFFIX] Primary argument for a generalization function has to be a simple column reference.
115+
LINE 1: SELECT cast(extract(second from ts) as integer) FROM test_da...
116+
^
117+
SELECT cast(extract(millisecond from ts) as integer) FROM test_datetime GROUP BY 1;
118+
ERROR: [PG_DIFFIX] Primary argument for a generalization function has to be a simple column reference.
119+
LINE 1: SELECT cast(extract(millisecond from ts) as integer) FROM te...
120+
^
121+
SELECT cast(extract(microsecond from ts) as integer) FROM test_datetime GROUP BY 1;
122+
ERROR: [PG_DIFFIX] Primary argument for a generalization function has to be a simple column reference.
123+
LINE 1: SELECT cast(extract(microsecond from ts) as integer) FROM te...
124+
^
125+
SELECT cast(extract(julian from ts) as integer) FROM test_datetime GROUP BY 1;
126+
ERROR: [PG_DIFFIX] Primary argument for a generalization function has to be a simple column reference.
127+
LINE 1: SELECT cast(extract(julian from ts) as integer) FROM test_da...
128+
^

test/sql/datetime.sql

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,14 @@ SELECT tz, count(*) FROM test_datetime GROUP BY 1;
5858
SELECT count(*) FROM test_datetime WHERE date_trunc('year', ts) = '2012-01-01'::timestamp;
5959
SELECT count(*) FROM test_datetime WHERE extract(century from ts) = 21;
6060
SELECT count(*) FROM test_datetime WHERE date_part('century', ts) = 21;
61+
62+
-- Datetime cast normalization
63+
-- Allowed because the cast is a noop
64+
SELECT cast(extract(minute from ts) as integer) as extract FROM test_datetime GROUP BY 1;
65+
66+
-- Disallowed because the cast is rounding
67+
SELECT cast(extract(epoch from ts) as integer) FROM test_datetime GROUP BY 1;
68+
SELECT cast(extract(second from ts) as integer) FROM test_datetime GROUP BY 1;
69+
SELECT cast(extract(millisecond from ts) as integer) FROM test_datetime GROUP BY 1;
70+
SELECT cast(extract(microsecond from ts) as integer) FROM test_datetime GROUP BY 1;
71+
SELECT cast(extract(julian from ts) as integer) FROM test_datetime GROUP BY 1;

0 commit comments

Comments
 (0)