Skip to content

Commit 7cc66c3

Browse files
authored
Merge pull request #417 from diffix/edon/shared-agg-state
Support state sharing across aggregates
2 parents bb001e3 + f028769 commit 7cc66c3

File tree

5 files changed

+114
-25
lines changed

5 files changed

+114
-25
lines changed

pg_diffix/aggregation/bucket_scan.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,9 @@ extern Plan *make_bucket_scan(Plan *left_tree, AnonymizationContext *anon_contex
2121
*/
2222
extern bool is_bucket_scan(Plan *plan);
2323

24+
/*
25+
* Returns true if the given Aggref does not own its state but borrows it from another aggregate in the bucket that performs identical transitions.
26+
*/
27+
extern bool aggref_shares_state(Aggref *aggref);
28+
2429
#endif /* PG_DIFFIX_BUCKET_SCAN_H */

pg_diffix/aggregation/common.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
* The `finalize` function derives the final value (of type `final_type`) of the aggregator.
2929
* Temporary and return data should not be allocated in the state's memory context but in
3030
* the current memory context which is shorter lived. See below for information about memory.
31+
* Because state might be borrowed from another aggregator, `finalize` must be idempotent,
32+
* meaning multiple executions against the same state have to return the same result.
3133
*
3234
* The `explain` function returns a human-readable representation of the aggregator state.
3335
* As with `finalize`, the current memory context should be used for temporary and return values.
@@ -76,6 +78,8 @@
7678
*-------------------------------------------------------------------------
7779
*/
7880

81+
#define AGG_STATE_REDIRECTED NULL
82+
7983
/* Describes a single function call argument. */
8084
typedef struct ArgDescriptor
8185
{
@@ -125,9 +129,10 @@ typedef struct BucketAttribute
125129
BucketAttributeTag tag; /* Label or aggregate? */
126130
struct
127131
{
128-
Oid fn_oid; /* Agg function OID */
132+
Aggref *aggref; /* Expr of aggregate */
129133
ArgsDescriptor *args_desc; /* Agg arguments descriptor */
130134
const AnonAggFuncs *funcs; /* Agg funcs if tag=BUCKET_ANON_AGG */
135+
int redirect_to; /* If shared, points to attribute that owns the state */
131136
} agg; /* Populated if tag!=BUCKET_LABEL */
132137
int typ_len; /* Data type length */
133138
bool typ_byval; /* Data type is by value? */

src/aggregation/bucket_scan.c

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
* Aggregates can be found in tlist and qual. We need to export both in Agg's tlist because we
3737
* move the actual projection and qual to BucketScan. TLEs n+1..n+m will be the aggregates.
3838
* When rewriting expressions for proj/qual, we do a simple equality-based deduplication to
39-
* minimize aggregates in tlist. It is not very important to be smart about optimizing at this
40-
* stage because ExecInitAgg will take care of sharing aggregation state during execution.
41-
* Arguments to aggregates are untouched because they do not leave the node.
39+
* minimize aggregates in tlist. During execution, anonymizing aggregators will reuse state
40+
* if conditions in `can_share_agg_state` are met. Arguments to aggregates are untouched
41+
* because they do not leave the node.
4242
*
4343
* Projection/filtering:
4444
*
@@ -89,14 +89,60 @@ static inline bool has_star_bucket(BucketScanState *bucket_state)
8989
return linitial(bucket_state->buckets) != NULL;
9090
}
9191

92-
/* Memory context of currently executing BucketScan node. */
93-
MemoryContext g_current_bucket_context = NULL;
92+
/* State of currently executing bucket scan. */
93+
static BucketScanState *g_current_bucket_scan = NULL;
94+
95+
MemoryContext get_current_bucket_context(void);
96+
bool aggref_shares_state(Aggref *aggref);
97+
98+
/* Used by common.c to locate the bucket memory context. */
99+
MemoryContext get_current_bucket_context(void)
100+
{
101+
return g_current_bucket_scan != NULL
102+
? g_current_bucket_scan->bucket_context
103+
: NULL;
104+
}
105+
106+
/* Used by common.c to check if an agg has redirected state. */
107+
bool aggref_shares_state(Aggref *aggref)
108+
{
109+
if (g_current_bucket_scan == NULL)
110+
return false;
111+
112+
BucketDescriptor *bucket_desc = g_current_bucket_scan->bucket_desc;
113+
int num_atts = bucket_num_atts(bucket_desc);
114+
for (int i = bucket_desc->num_labels; i < num_atts; i++)
115+
{
116+
BucketAttribute *att = &bucket_desc->attrs[i];
117+
/* We use reference comparison to pinpoint exact position of aggregate. */
118+
if (att->agg.aggref == aggref)
119+
return i != att->agg.redirect_to;
120+
}
121+
122+
/* This should not happen, but we can't guarantee that the Aggref was not copied somewhere. */
123+
return false;
124+
}
94125

95126
/*-------------------------------------------------------------------------
96127
* CustomExecMethods
97128
*-------------------------------------------------------------------------
98129
*/
99130

131+
/*
132+
* Returns true if aggregates can share the same agg state.
133+
* This is possible when args, initial state, transition, and merge functions are identical.
134+
*/
135+
static bool can_share_agg_state(BucketAttribute *agg1, BucketAttribute *agg2)
136+
{
137+
const AnonAggFuncs *funcs1 = agg1->agg.funcs;
138+
const AnonAggFuncs *funcs2 = agg2->agg.funcs;
139+
140+
return funcs1->create_state == funcs2->create_state &&
141+
funcs1->transition == funcs2->transition &&
142+
funcs1->merge == funcs2->merge &&
143+
equal(agg1->agg.aggref->args, agg2->agg.aggref->args);
144+
}
145+
100146
/*
101147
* Populates `bucket_desc` field with type metadata.
102148
*/
@@ -131,16 +177,32 @@ static void init_bucket_descriptor(BucketScanState *bucket_state)
131177
{
132178
Aggref *aggref = castNode(Aggref, tle->expr);
133179
agg_funcs = find_agg_funcs(aggref->aggfnoid);
134-
att->agg.fn_oid = aggref->aggfnoid;
180+
att->agg.aggref = aggref;
135181
att->agg.funcs = agg_funcs;
136182
att->agg.args_desc = build_args_desc(aggref);
183+
att->agg.redirect_to = i; /* Pointing to itself means state is not shared. */
137184
att->tag = agg_funcs != NULL ? BUCKET_ANON_AGG : BUCKET_REGULAR_AGG;
138185
}
139186

140187
if (agg_funcs != NULL)
141188
{
142189
/* For anonymizing aggregators we describe finalized type. */
143190
agg_funcs->final_type(att->agg.args_desc, &att->final_type, &att->final_typmod, &att->final_collid);
191+
192+
/* Look back to check if there is a compatible agg which we can share state with. */
193+
for (int j = plan_data->num_labels; j < i; j++)
194+
{
195+
BucketAttribute *other_att = &bucket_desc->attrs[j];
196+
if (other_att->agg.funcs == NULL)
197+
continue;
198+
199+
if (can_share_agg_state(att, other_att))
200+
{
201+
Assert(i != plan_data->low_count_index); /* low_count is always unique. */
202+
att->agg.redirect_to = j;
203+
break;
204+
}
205+
}
144206
}
145207
else
146208
{
@@ -182,13 +244,13 @@ static void bucket_begin_scan(CustomScanState *css, EState *estate, int eflags)
182244

183245
static void fill_bucket_list(BucketScanState *bucket_state)
184246
{
185-
MemoryContext old_bucket_context = g_current_bucket_context;
186-
MemoryContext bucket_context = bucket_state->bucket_context;
247+
BucketScanState *old_bucket_scan = g_current_bucket_scan;
187248

188249
ExprContext *econtext = bucket_state->css.ss.ps.ps_ExprContext;
189250
MemoryContext per_tuple_memory = econtext->ecxt_per_tuple_memory;
190251
PlanState *outer_plan_state = outerPlanState(bucket_state);
191252

253+
MemoryContext bucket_context = bucket_state->bucket_context;
192254
BucketDescriptor *bucket_desc = bucket_state->bucket_desc;
193255
int num_atts = bucket_num_atts(bucket_desc);
194256
int low_count_index = bucket_desc->low_count_index;
@@ -201,7 +263,7 @@ static void fill_bucket_list(BucketScanState *bucket_state)
201263
{
202264
CHECK_FOR_INTERRUPTS();
203265

204-
g_current_bucket_context = bucket_context;
266+
g_current_bucket_scan = bucket_state;
205267
TupleTableSlot *outer_slot = ExecProcNode(outer_plan_state);
206268

207269
if (TupIsNull(outer_slot))
@@ -247,8 +309,8 @@ static void fill_bucket_list(BucketScanState *bucket_state)
247309
bucket_state->buckets = buckets;
248310
bucket_state->input_done = true;
249311

250-
/* Restore previous bucket context. */
251-
g_current_bucket_context = old_bucket_context;
312+
/* Restore previous bucket scan context. */
313+
g_current_bucket_scan = old_bucket_scan;
252314
}
253315

254316
static void run_hooks(BucketScanState *bucket_state)
@@ -289,9 +351,9 @@ static void finalize_bucket(Bucket *bucket, BucketDescriptor *bucket_desc, ExprC
289351
BucketAttribute *att = &bucket_desc->attrs[i];
290352
if (att->tag == BUCKET_ANON_AGG)
291353
{
292-
AnonAggState *agg_state = (AnonAggState *)DatumGetPointer(bucket->values[i]);
354+
int state_source_index = att->agg.redirect_to; /* If shared, points to some other non-NULL state. */
355+
AnonAggState *agg_state = (AnonAggState *)DatumGetPointer(bucket->values[state_source_index]);
293356
Assert(agg_state != NULL);
294-
Assert(agg_state->agg_funcs == att->agg.funcs);
295357
is_null[i] = false;
296358
values[i] = att->agg.funcs->finalize(agg_state, bucket, bucket_desc, &is_null[i]);
297359
}

src/aggregation/common.c

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
#define PG_GET_AGG_STATE(index) ((AnonAggState *)PG_GETARG_POINTER(index))
1616
#define PG_RETURN_AGG_STATE(state) PG_RETURN_POINTER(state)
1717

18-
/* Memory context of currently executing BucketScan node (if any). */
19-
extern MemoryContext g_current_bucket_context;
18+
/* Functions declared in bucket_scan.c. Depend on global state and should not be public API. */
19+
extern MemoryContext get_current_bucket_context(void);
20+
extern bool aggref_shares_state(Aggref *aggref);
2021

2122
PG_FUNCTION_INFO_V1(anon_agg_state_input);
2223
PG_FUNCTION_INFO_V1(anon_agg_state_output);
@@ -95,11 +96,16 @@ void merge_bucket(Bucket *destination, Bucket *source, BucketDescriptor *bucket_
9596
for (int i = bucket_desc->num_labels; i < num_atts; i++)
9697
{
9798
BucketAttribute *att = &bucket_desc->attrs[i];
98-
if (att->tag == BUCKET_ANON_AGG)
99+
if (att->tag == BUCKET_ANON_AGG &&
100+
i == att->agg.redirect_to /* Shared states need to be merged only once. */)
99101
{
100102
Assert(!source->is_null[i]);
101103
Assert(!destination->is_null[i]);
102-
att->agg.funcs->merge((AnonAggState *)destination->values[i], (AnonAggState *)source->values[i]);
104+
AnonAggState *dst_state = (AnonAggState *)destination->values[i];
105+
AnonAggState *src_state = (AnonAggState *)source->values[i];
106+
Assert(dst_state != AGG_STATE_REDIRECTED);
107+
Assert(src_state != AGG_STATE_REDIRECTED);
108+
att->agg.funcs->merge(dst_state, src_state);
103109
}
104110
}
105111
}
@@ -113,10 +119,16 @@ static AnonAggState *get_agg_state(PG_FUNCTION_ARGS)
113119
if (AggCheckCallContext(fcinfo, &bucket_context) != AGG_CONTEXT_AGGREGATE)
114120
FAILWITH("Aggregate called in non-aggregate context");
115121

116-
if (g_current_bucket_context != NULL)
117-
bucket_context = g_current_bucket_context;
118-
119122
Aggref *aggref = AggGetAggref(fcinfo);
123+
124+
if (get_current_bucket_context() != NULL)
125+
{
126+
if (aggref_shares_state(aggref))
127+
return AGG_STATE_REDIRECTED;
128+
129+
bucket_context = get_current_bucket_context();
130+
}
131+
120132
const AnonAggFuncs *agg_funcs = find_agg_funcs(aggref->aggfnoid);
121133

122134
if (unlikely(agg_funcs == NULL))
@@ -134,20 +146,23 @@ Datum anon_agg_state_input(PG_FUNCTION_ARGS)
134146
Datum anon_agg_state_output(PG_FUNCTION_ARGS)
135147
{
136148
AnonAggState *state = PG_GET_AGG_STATE(0);
149+
Assert(state != AGG_STATE_REDIRECTED); /* Won't happen outside of a BucketScan context. */
137150
const char *str = state->agg_funcs->explain(state);
138151
PG_RETURN_CSTRING(str);
139152
}
140153

141154
Datum anon_agg_state_transfn(PG_FUNCTION_ARGS)
142155
{
143156
AnonAggState *state = get_agg_state(fcinfo);
144-
state->agg_funcs->transition(state, PG_NARGS(), fcinfo->args);
157+
/* AGG_STATE_REDIRECTED means the owning aggregator will handle transitions. */
158+
if (state != AGG_STATE_REDIRECTED)
159+
state->agg_funcs->transition(state, PG_NARGS(), fcinfo->args);
145160
PG_RETURN_AGG_STATE(state);
146161
}
147162

148163
/*
149164
* This finalfunc is a dummy version which does nothing.
150-
* It only ensures that state is not null for empty buckets.
165+
* It only ensures that state is initialized for empty buckets.
151166
*/
152167
Datum anon_agg_state_finalfn(PG_FUNCTION_ARGS)
153168
{

src/aggregation/star_bucket.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ Bucket *star_bucket_hook(List *buckets, BucketDescriptor *bucket_desc)
4848
BucketAttribute *att = &bucket_desc->attrs[i];
4949
if (att->tag == BUCKET_ANON_AGG)
5050
/* Create an empty anon agg state and merge buckets into it. */
51-
star_bucket->values[i] = PointerGetDatum(create_anon_agg_state(att->agg.funcs, bucket_context, att->agg.args_desc));
51+
star_bucket->values[i] = PointerGetDatum(i != att->agg.redirect_to
52+
? AGG_STATE_REDIRECTED
53+
: create_anon_agg_state(att->agg.funcs, bucket_context, att->agg.args_desc));
5254
else if (att->tag == BUCKET_LABEL)
5355
set_text_label(star_bucket, i, att->final_type, bucket_context);
54-
else if (att->agg.fn_oid == g_oid_cache.is_suppress_bin)
56+
else if (att->agg.aggref->aggfnoid == g_oid_cache.is_suppress_bin)
5557
star_bucket->values[i] = BoolGetDatum(true);
5658
else
5759
star_bucket->is_null[i] = true;

0 commit comments

Comments
 (0)