Skip to content

Commit 6d51ee2

Browse files
authored
Merge pull request #452 from diffix/piotr/normalize-for-metabase-date_trunc
Allow `cast(extract ...) as integer`
2 parents a229059 + 829bc40 commit 6d51ee2

File tree

7 files changed

+72
-20
lines changed

7 files changed

+72
-20
lines changed

pg_diffix/query/allowed_objects.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define PG_DIFFIX_ALLOWED_OBJECTS_H
33

44
#include "nodes/bitmapset.h"
5+
#include "nodes/primnodes.h"
56

67
/*
78
* Returns whether the OID points to a function (or operator) allowed in defining buckets.
@@ -17,7 +18,7 @@ extern int primary_arg_index(Oid funcoid);
1718
/*
1819
* Returns whether the OID points to a cast allowed in defining buckets.
1920
*/
20-
extern bool is_allowed_cast(Oid funcoid);
21+
extern bool is_allowed_cast(const FuncExpr *func_expr);
2122

2223
/*
2324
* Returns whether the OID points to a UDF being a implicit_range function, e.g. `ceil_by(x, 2.0)`,

src/node_funcs.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Node *unwrap_cast(Node *node)
2020
if (IsA(node, FuncExpr))
2121
{
2222
FuncExpr *func_expr = (FuncExpr *)node;
23-
if (is_allowed_cast(func_expr->funcid))
23+
if (is_allowed_cast(func_expr))
2424
{
2525
Assert(list_length(func_expr->args) == 1); /* All allowed casts require exactly one argument. */
2626
return unwrap_cast(linitial(func_expr->args));

src/query/allowed_objects.c

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "postgres.h"
22

33
#include "access/sysattr.h"
4+
#include "catalog/pg_type.h"
5+
#include "utils/builtins.h"
46
#include "utils/fmgroids.h"
57
#include "utils/fmgrtab.h"
68
#include "utils/lsyscache.h"
@@ -13,7 +15,7 @@
1315
static const char *const g_allowed_casts[] = {
1416
"i2tod", "i2tof", "i2toi4", "i4toi2", "i4tod", "i4tof", "i8tod", "i8tof", "int48", "int84",
1517
"ftod", "dtof",
16-
"int4_numeric", "float4_numeric", "float8_numeric",
18+
"int2_numeric", "int4_numeric", "int8_numeric", "float4_numeric", "float8_numeric",
1719
"numeric_float4", "numeric_float8",
1820
"date_timestamptz",
1921
/**/
@@ -39,7 +41,9 @@ static const FunctionByName g_allowed_builtins[] = {
3941
(FunctionByName){.name = "dtoi2", .primary_arg = 0},
4042
(FunctionByName){.name = "dtoi4", .primary_arg = 0},
4143
(FunctionByName){.name = "dtoi8", .primary_arg = 0},
44+
(FunctionByName){.name = "numeric_int2", .primary_arg = 0},
4245
(FunctionByName){.name = "numeric_int4", .primary_arg = 0},
46+
(FunctionByName){.name = "numeric_int8", .primary_arg = 0},
4347
/* substring */
4448
(FunctionByName){.name = "text_substr", .primary_arg = 0},
4549
(FunctionByName){.name = "text_substr_no_len", .primary_arg = 0},
@@ -124,6 +128,16 @@ static AllowedCols g_pg_catalog_allowed_cols[] = {
124128
/**/
125129
};
126130

131+
static const char *const g_decimal_integer_casts[] = {
132+
"numeric_int2", "numeric_int4", "numeric_int8", "dtoi2", "dtoi4", "dtoi8", "ftoi2", "ftoi4", "ftoi8",
133+
/**/
134+
};
135+
136+
static const char *const g_extract_functions[] = {
137+
"extract_date", "extract_timestamp", "extract_timestamptz", "timestamp_part", "timestamptz_part",
138+
/**/
139+
};
140+
127141
static void prepare_pg_catalog_allowed(Oid relation_oid, AllowedCols *allowed_cols)
128142
{
129143
MemoryContext old_context = MemoryContextSwitchTo(TopMemoryContext);
@@ -166,6 +180,16 @@ static const FmgrBuiltin *fmgr_isbuiltin(Oid id)
166180
return &fmgr_builtins[index];
167181
}
168182

183+
static bool is_member_of(const char *s, const char *const array[], int length)
184+
{
185+
for (int i = 0; i < length; i++)
186+
{
187+
if (strcmp(array[i], s) == 0)
188+
return true;
189+
}
190+
return false;
191+
}
192+
169193
static bool is_func_member_of(Oid funcoid, const FunctionByName func_array[], int length)
170194
{
171195
const FmgrBuiltin *fmgr_builtin = fmgr_isbuiltin(funcoid);
@@ -185,13 +209,7 @@ static bool is_funcname_member_of(Oid funcoid, const char *const name_array[], i
185209
{
186210
const FmgrBuiltin *fmgr_builtin = fmgr_isbuiltin(funcoid);
187211
if (fmgr_builtin != NULL)
188-
{
189-
for (int i = 0; i < length; i++)
190-
{
191-
if (strcmp(name_array[i], fmgr_builtin->funcName) == 0)
192-
return true;
193-
}
194-
}
212+
return is_member_of(fmgr_builtin->funcName, name_array, length);
195213

196214
return false;
197215
}
@@ -224,9 +242,25 @@ int primary_arg_index(Oid funcoid)
224242
FAILWITH("Cannot identify the primary argument position for funcid %u.", funcoid);
225243
}
226244

227-
bool is_allowed_cast(Oid funcoid)
245+
bool is_allowed_cast(const FuncExpr *func_expr)
228246
{
229-
return is_funcname_member_of(funcoid, g_allowed_casts, ARRAY_LENGTH(g_allowed_casts));
247+
if (is_funcname_member_of(func_expr->funcid, g_allowed_casts, ARRAY_LENGTH(g_allowed_casts)))
248+
{
249+
return true;
250+
}
251+
else if (is_funcname_member_of(func_expr->funcid, g_decimal_integer_casts, ARRAY_LENGTH(g_decimal_integer_casts)))
252+
{
253+
/* Handle cases like `cast(extract(minute from ...) as integer)`. */
254+
Node *cast_arg = linitial(func_expr->args);
255+
if (IsA(cast_arg, FuncExpr))
256+
{
257+
FuncExpr *cast_arg_expr = (FuncExpr *)cast_arg;
258+
if (is_funcname_member_of(cast_arg_expr->funcid, g_extract_functions, ARRAY_LENGTH(g_extract_functions)) ||
259+
cast_arg_expr->funcid == F_DATE_PART_TEXT_DATE)
260+
return true;
261+
}
262+
}
263+
return false;
230264
}
231265

232266
bool is_implicit_range_udf_untrusted(Oid funcoid)
@@ -279,13 +313,10 @@ bool is_allowed_pg_catalog_rte(Oid relation_oid, const Bitmapset *selected_cols)
279313
char *rel_name = get_rel_name(relation_oid);
280314

281315
/* Then check if the entire relation is allowed. */
282-
for (int i = 0; i < ARRAY_LENGTH(g_pg_catalog_allowed_rels); i++)
316+
if (is_member_of(rel_name, g_pg_catalog_allowed_rels, ARRAY_LENGTH(g_pg_catalog_allowed_rels)))
283317
{
284-
if (strcmp(g_pg_catalog_allowed_rels[i], rel_name) == 0)
285-
{
286-
pfree(rel_name);
287-
return true;
288-
}
318+
pfree(rel_name);
319+
return true;
289320
}
290321

291322
/* Otherwise specific selected columns must be checked against the allow-list. */

src/query/anonymization.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ static bool collect_seed_material(Node *node, CollectMaterialContext *context)
591591
if (IsA(node, FuncExpr))
592592
{
593593
FuncExpr *func_expr = (FuncExpr *)node;
594-
if (!is_allowed_cast(func_expr->funcid))
594+
if (!is_allowed_cast(func_expr))
595595
{
596596
char *func_name = get_func_name(func_expr->funcid);
597597
if (func_name)

src/query/validation.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,10 +271,11 @@ static void verify_bucket_expression(Node *node)
271271
if (IsA(node, FuncExpr))
272272
{
273273
FuncExpr *func_expr = (FuncExpr *)node;
274-
if (is_allowed_cast(func_expr->funcid))
274+
if (is_allowed_cast(func_expr))
275275
{
276276
Assert(list_length(func_expr->args) == 1); /* All allowed casts require exactly one argument. */
277277
verify_bucket_expression(linitial(func_expr->args));
278+
return;
278279
}
279280

280281
if (!is_allowed_function(func_expr->funcid))

test/expected/datetime.out

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,17 @@ SELECT count(*) FROM test_datetime WHERE date_part('century', ts) = 21;
9797
9
9898
(1 row)
9999

100+
-- Datetime extract cast to integer
101+
SELECT cast(extract(minute from ts) as integer) as extract FROM test_datetime GROUP BY 1;
102+
extract
103+
---------
104+
0
105+
(1 row)
106+
107+
-- Allowed despite the cast being rounding
108+
SELECT cast(extract(second from ts) as integer) as extract FROM test_datetime GROUP BY 1;
109+
extract
110+
---------
111+
0
112+
(1 row)
113+

test/sql/datetime.sql

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,8 @@ SELECT tz, count(*) FROM test_datetime GROUP BY 1;
5858
SELECT count(*) FROM test_datetime WHERE date_trunc('year', ts) = '2012-01-01'::timestamp;
5959
SELECT count(*) FROM test_datetime WHERE extract(century from ts) = 21;
6060
SELECT count(*) FROM test_datetime WHERE date_part('century', ts) = 21;
61+
62+
-- Datetime extract cast to integer
63+
SELECT cast(extract(minute from ts) as integer) as extract FROM test_datetime GROUP BY 1;
64+
-- Allowed despite the cast being rounding
65+
SELECT cast(extract(second from ts) as integer) as extract FROM test_datetime GROUP BY 1;

0 commit comments

Comments
 (0)