Skip to content

Commit 25ef336

Browse files
authored
Merge pull request #436 from diffix/piotr/negative-summands
Support negative summands
2 parents fdcb0ff + 6cf58cc commit 25ef336

File tree

7 files changed

+110
-31
lines changed

7 files changed

+110
-31
lines changed

pg_diffix/aggregation/contribution_tracker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,16 @@ typedef contribution_t (*ContributionCombineFunc)(contribution_t x, contribution
2525
/* Casts x to double. */
2626
typedef double (*ContributionToDoubleFunc)(contribution_t x);
2727

28+
/* Computes absolute value. */
29+
typedef contribution_t (*ContributionAbsFunc)(contribution_t x);
30+
2831
typedef struct ContributionDescriptor
2932
{
3033
ContributionGreaterFunc contribution_greater;
3134
ContributionEqualFunc contribution_equal;
3235
ContributionCombineFunc contribution_combine;
3336
ContributionToDoubleFunc contribution_to_double;
37+
ContributionAbsFunc contribution_abs;
3438
contribution_t contribution_initial; /* Initial or "zero" value for a contribution */
3539
} ContributionDescriptor;
3640

src/aggregation/sum.c

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@ static double finalize_sum_result(const SummableResultAccumulator *accumulator)
2020
*-------------------------------------------------------------------------
2121
*/
2222

23+
typedef ContributionTrackerState *SumLeg;
24+
2325
typedef struct SumState
2426
{
2527
AnonAggState base;
2628
int trackers_count;
2729
Oid summand_type;
28-
ContributionTrackerState *trackers[FLEXIBLE_ARRAY_MEMBER];
30+
SumLeg *positive;
31+
SumLeg *negative;
2932
} SumState;
3033

3134
static void sum_final_type(const ArgsDescriptor *args_desc, Oid *type, int32 *typmod, Oid *collid)
@@ -63,9 +66,11 @@ static AnonAggState *sum_create_state(MemoryContext memory_context, ArgsDescript
6366
MemoryContext old_context = MemoryContextSwitchTo(memory_context);
6467

6568
int trackers_count = args_desc->num_args - SUM_AIDS_OFFSET;
66-
SumState *state = palloc0(sizeof(SumState) + trackers_count * sizeof(ContributionTrackerState *));
69+
SumState *state = palloc0(sizeof(SumState));
6770
state->trackers_count = trackers_count;
6871
state->summand_type = args_desc->args[SUM_VALUE_INDEX].type_oid;
72+
state->positive = palloc0(trackers_count * sizeof(SumLeg));
73+
state->negative = palloc0(trackers_count * sizeof(SumLeg));
6974
ContributionDescriptor typed_sum_descriptor = {0};
7075
switch (state->summand_type)
7176
{
@@ -87,36 +92,54 @@ static AnonAggState *sum_create_state(MemoryContext memory_context, ArgsDescript
8792
for (int i = 0; i < trackers_count; i++)
8893
{
8994
Oid aid_type = args_desc->args[i + SUM_AIDS_OFFSET].type_oid;
90-
state->trackers[i] = contribution_tracker_new(get_aid_mapper(aid_type), &typed_sum_descriptor);
95+
state->positive[i] = contribution_tracker_new(get_aid_mapper(aid_type), &typed_sum_descriptor);
96+
state->negative[i] = contribution_tracker_new(get_aid_mapper(aid_type), &typed_sum_descriptor);
9197
}
9298

9399
MemoryContextSwitchTo(old_context);
94100
return &state->base;
95101
}
96102

97-
static SummableResultAccumulator sum_calculate_final(AnonAggState *base_state, Bucket *bucket, BucketDescriptor *bucket_desc)
103+
typedef struct SumResult
104+
{
105+
bool not_enough_aid_values;
106+
SummableResultAccumulator positive;
107+
SummableResultAccumulator negative;
108+
} SumResult;
109+
110+
static SumResult sum_calculate_final(AnonAggState *base_state, Bucket *bucket, BucketDescriptor *bucket_desc)
98111
{
99112
SumState *state = (SumState *)base_state;
100-
SummableResultAccumulator result_accumulator = {0};
113+
SummableResultAccumulator positive_result_accumulator = {0};
114+
SummableResultAccumulator negative_result_accumulator = {0};
101115
seed_t bucket_seed = compute_bucket_seed(bucket, bucket_desc);
102116

103117
for (int i = 0; i < state->trackers_count; i++)
104118
{
105-
SummableResult result = calculate_result(bucket_seed, state->trackers[i]);
119+
SummableResult positive_result = calculate_result(bucket_seed, state->positive[i]);
120+
SummableResult negative_result = calculate_result(bucket_seed, state->negative[i]);
106121

107-
accumulate_result(&result_accumulator, &result);
108-
if (result_accumulator.not_enough_aid_values)
109-
break;
122+
if (positive_result.not_enough_aid_values && negative_result.not_enough_aid_values)
123+
{
124+
return (SumResult){.not_enough_aid_values = true};
125+
}
126+
else
127+
{
128+
/* Unless both legs had `not_enough_aid_values` for given AID instance, we proceed. */
129+
accumulate_result(&positive_result_accumulator, &positive_result);
130+
accumulate_result(&negative_result_accumulator, &negative_result);
131+
}
110132
}
111-
return result_accumulator;
133+
return (SumResult){.positive = positive_result_accumulator, .negative = negative_result_accumulator};
112134
}
113135

114136
static Datum sum_finalize(AnonAggState *base_state, Bucket *bucket, BucketDescriptor *bucket_desc, bool *is_null)
115137
{
116138
SumState *state = (SumState *)base_state;
117-
SummableResultAccumulator result_accumulator = sum_calculate_final(base_state, bucket, bucket_desc);
139+
SumResult result = sum_calculate_final(base_state, bucket, bucket_desc);
118140

119-
if (result_accumulator.not_enough_aid_values)
141+
/* We deliberately ignore the `not_enough_aid_values` fields in the `result.positive` and `negative`. */
142+
if (result.not_enough_aid_values)
120143
{
121144
*is_null = true;
122145
switch (state->summand_type)
@@ -138,21 +161,22 @@ static Datum sum_finalize(AnonAggState *base_state, Bucket *bucket, BucketDescri
138161
}
139162
else
140163
{
164+
double combined_result = finalize_sum_result(&result.positive) - finalize_sum_result(&result.negative);
141165
switch (state->summand_type)
142166
{
143167
case INT2OID:
144168
case INT4OID:
145-
return Int64GetDatum((int64)round(finalize_sum_result(&result_accumulator)));
169+
return Int64GetDatum((int64)round(combined_result));
146170
case INT8OID:
147171
case NUMERICOID:
148-
return DirectFunctionCall1(float8_numeric, Float8GetDatum(finalize_sum_result(&result_accumulator)));
172+
return DirectFunctionCall1(float8_numeric, Float8GetDatum(combined_result));
149173
case FLOAT4OID:
150-
return Float4GetDatum((float4)finalize_sum_result(&result_accumulator));
174+
return Float4GetDatum((float4)combined_result);
151175
case FLOAT8OID:
152-
return Float8GetDatum(finalize_sum_result(&result_accumulator));
176+
return Float8GetDatum(combined_result);
153177
default:
154178
Assert(false);
155-
return Float8GetDatum(finalize_sum_result(&result_accumulator));
179+
return Float8GetDatum(combined_result);
156180
}
157181
}
158182
}
@@ -163,7 +187,8 @@ static void sum_merge(AnonAggState *dst_base_state, const AnonAggState *src_base
163187
const SumState *src_state = (const SumState *)src_base_state;
164188

165189
Assert(dst_state->summand_type == src_state->summand_type);
166-
merge_trackers(dst_state->trackers_count, src_state->trackers_count, dst_state->trackers, src_state->trackers);
190+
merge_trackers(dst_state->trackers_count, src_state->trackers_count, dst_state->positive, src_state->positive);
191+
merge_trackers(dst_state->trackers_count, src_state->trackers_count, dst_state->negative, src_state->negative);
167192
}
168193

169194
static contribution_t summand_to_contribution(Datum arg, Oid summand_type)
@@ -202,16 +227,26 @@ static void sum_transition(AnonAggState *base_state, int num_args, NullableDatum
202227
for (int i = 0; i < state->trackers_count; i++)
203228
{
204229
int aid_index = i + SUM_AIDS_OFFSET;
230+
ContributionDescriptor descriptor = state->positive[i]->contribution_descriptor;
231+
contribution_t abs_contribution = descriptor.contribution_abs(value_contribution);
232+
ContributionCombineFunc combine = descriptor.contribution_combine;
233+
ContributionGreaterFunc gt = descriptor.contribution_greater;
234+
ContributionEqualFunc eq = descriptor.contribution_equal;
205235

206236
if (!args[aid_index].isnull)
207237
{
208-
aid_t aid = state->trackers[i]->aid_mapper(args[aid_index].value);
209-
contribution_tracker_update_contribution(state->trackers[i], aid, value_contribution);
238+
aid_t aid = state->positive[i]->aid_mapper(args[aid_index].value);
239+
if (gt(value_contribution, descriptor.contribution_initial) || eq(value_contribution, descriptor.contribution_initial))
240+
contribution_tracker_update_contribution(state->positive[i], aid, abs_contribution);
241+
if (gt(descriptor.contribution_initial, value_contribution) || eq(value_contribution, descriptor.contribution_initial))
242+
contribution_tracker_update_contribution(state->negative[i], aid, abs_contribution);
210243
}
211244
else
212245
{
213-
ContributionCombineFunc combine = state->trackers[i]->contribution_descriptor.contribution_combine;
214-
state->trackers[i]->unaccounted_for = combine(state->trackers[i]->unaccounted_for, value_contribution);
246+
if (gt(value_contribution, descriptor.contribution_initial))
247+
state->positive[i]->unaccounted_for = combine(state->positive[i]->unaccounted_for, abs_contribution);
248+
if (gt(descriptor.contribution_initial, value_contribution))
249+
state->negative[i]->unaccounted_for = combine(state->negative[i]->unaccounted_for, abs_contribution);
215250
}
216251
}
217252
}
@@ -240,15 +275,17 @@ static void sum_noise_final_type(const ArgsDescriptor *args_desc, Oid *type, int
240275

241276
static Datum sum_noise_finalize(AnonAggState *base_state, Bucket *bucket, BucketDescriptor *bucket_desc, bool *is_null)
242277
{
243-
SummableResultAccumulator result_accumulator = sum_calculate_final(base_state, bucket, bucket_desc);
244-
if (result_accumulator.not_enough_aid_values)
278+
SumResult result = sum_calculate_final(base_state, bucket, bucket_desc);
279+
280+
/* We deliberately ignore the `not_enough_aid_values` fields in the `result.positive` and `negative`. */
281+
if (result.not_enough_aid_values)
245282
{
246283
*is_null = true;
247284
return Float8GetDatum(0.0);
248285
}
249286
else
250287
{
251-
return Float8GetDatum(finalize_noise_result(&result_accumulator));
288+
return Float8GetDatum(sqrt(pow(finalize_noise_result(&result.positive), 2) + pow(finalize_noise_result(&result.negative), 2)));
252289
}
253290
}
254291

src/aggregation/summable.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,17 @@ static double integer_contribution_to_double(contribution_t x)
2727
return (double)x.integer;
2828
}
2929

30+
static contribution_t integer_contribution_abs(contribution_t x)
31+
{
32+
return (contribution_t){.integer = labs(x.integer)};
33+
}
34+
3035
const ContributionDescriptor integer_descriptor = {
3136
.contribution_greater = integer_contribution_greater,
3237
.contribution_equal = integer_contribution_equal,
3338
.contribution_combine = integer_contribution_combine,
3439
.contribution_to_double = integer_contribution_to_double,
40+
.contribution_abs = integer_contribution_abs,
3541
.contribution_initial = {.integer = 0},
3642
};
3743
static bool real_contribution_greater(contribution_t x, contribution_t y)
@@ -54,11 +60,17 @@ static double real_contribution_to_double(contribution_t x)
5460
return (double)x.real;
5561
}
5662

63+
static contribution_t real_contribution_abs(contribution_t x)
64+
{
65+
return (contribution_t){.real = fabs(x.real)};
66+
}
67+
5768
const ContributionDescriptor real_descriptor = {
5869
.contribution_greater = real_contribution_greater,
5970
.contribution_equal = real_contribution_equal,
6071
.contribution_combine = real_contribution_combine,
6172
.contribution_to_double = real_contribution_to_double,
73+
.contribution_abs = real_contribution_abs,
6274
.contribution_initial = {.real = 0.0},
6375
};
6476

test/expected/noiseless.out

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ SET pg_diffix.outlier_count_min = 1;
66
SET pg_diffix.outlier_count_max = 1;
77
SET pg_diffix.top_count_min = 3;
88
SET pg_diffix.top_count_max = 3;
9+
-- Additional tables for SUM testing
10+
CREATE TABLE test_customers_negative AS SELECT id, city, -discount as discount, planet FROM test_customers;
11+
CREATE TABLE test_customers_mixed AS SELECT id, city, discount - 1.0 as discount, planet FROM test_customers;
12+
CALL diffix.mark_personal('public.test_customers_negative', 'id');
13+
CALL diffix.mark_personal('public.test_customers_mixed', 'id');
914
SET ROLE diffix_test;
1015
SET pg_diffix.session_access_level = 'anonymized_trusted';
1116
----------------------------------------------------------------
@@ -89,6 +94,18 @@ SELECT city, SUM(discount), diffix.sum_noise(discount) FROM test_customers GROUP
8994
Berlin | 9 | 0
9095
(3 rows)
9196

97+
SELECT SUM(discount), diffix.sum_noise(discount) FROM test_customers_negative;
98+
sum | sum_noise
99+
-----+-----------
100+
-19 | 0
101+
(1 row)
102+
103+
SELECT SUM(discount), diffix.sum_noise(discount) FROM test_customers_mixed;
104+
sum | sum_noise
105+
-----+-----------
106+
2.5 | 0
107+
(1 row)
108+
92109
-- sum supports numeric type
93110
SELECT city, SUM(discount::numeric), pg_typeof(SUM(discount::numeric)), diffix.sum_noise(discount::numeric)
94111
FROM test_customers

test/expected/validation.out

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,11 @@ SELECT * FROM diffix.show_settings() LIMIT 2;
294294
pg_diffix.default_access_level | direct | Access level for unlabeled users.
295295
(2 rows)
296296

297-
SELECT * FROM diffix.show_labels() WHERE objname LIKE 'public.test_customers%';
298-
objtype | objname | label
299-
---------+--------------------------+----------
300-
table | public.test_customers | personal
301-
column | public.test_customers.id | aid
297+
SELECT * FROM diffix.show_labels() WHERE objname LIKE 'public.empty_test_customers%';
298+
objtype | objname | label
299+
---------+--------------------------------+----------
300+
table | public.empty_test_customers | personal
301+
column | public.empty_test_customers.id | aid
302302
(2 rows)
303303

304304
-- Allow prepared statements

test/sql/noiseless.sql

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ SET pg_diffix.outlier_count_max = 1;
88
SET pg_diffix.top_count_min = 3;
99
SET pg_diffix.top_count_max = 3;
1010

11+
-- Additional tables for SUM testing
12+
CREATE TABLE test_customers_negative AS SELECT id, city, -discount as discount, planet FROM test_customers;
13+
CREATE TABLE test_customers_mixed AS SELECT id, city, discount - 1.0 as discount, planet FROM test_customers;
14+
CALL diffix.mark_personal('public.test_customers_negative', 'id');
15+
CALL diffix.mark_personal('public.test_customers_mixed', 'id');
16+
1117
SET ROLE diffix_test;
1218
SET pg_diffix.session_access_level = 'anonymized_trusted';
1319

@@ -41,6 +47,9 @@ SELECT SUM(discount), diffix.sum_noise(discount) FROM test_customers;
4147
SELECT city, SUM(id), diffix.sum_noise(id) FROM test_customers GROUP BY 1;
4248
SELECT city, SUM(discount), diffix.sum_noise(discount) FROM test_customers GROUP BY 1;
4349

50+
SELECT SUM(discount), diffix.sum_noise(discount) FROM test_customers_negative;
51+
SELECT SUM(discount), diffix.sum_noise(discount) FROM test_customers_mixed;
52+
4453
-- sum supports numeric type
4554
SELECT city, SUM(discount::numeric), pg_typeof(SUM(discount::numeric)), diffix.sum_noise(discount::numeric)
4655
FROM test_customers

test/sql/validation.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ SELECT EXISTS (SELECT FROM Information_Schema.tables WHERE table_schema='public'
150150

151151
-- Settings and labels UDFs work
152152
SELECT * FROM diffix.show_settings() LIMIT 2;
153-
SELECT * FROM diffix.show_labels() WHERE objname LIKE 'public.test_customers%';
153+
SELECT * FROM diffix.show_labels() WHERE objname LIKE 'public.empty_test_customers%';
154154

155155
-- Allow prepared statements
156156
PREPARE prepared(float) AS SELECT discount, count(*) FROM empty_test_customers WHERE discount = $1 GROUP BY 1;

0 commit comments

Comments
 (0)