@@ -20,12 +20,15 @@ static double finalize_sum_result(const SummableResultAccumulator *accumulator)
2020 *-------------------------------------------------------------------------
2121 */
2222
23+ typedef ContributionTrackerState * SumLeg ;
24+
2325typedef struct SumState
2426{
2527 AnonAggState base ;
2628 int trackers_count ;
2729 Oid summand_type ;
28- ContributionTrackerState * trackers [FLEXIBLE_ARRAY_MEMBER ];
30+ SumLeg * positive ;
31+ SumLeg * negative ;
2932} SumState ;
3033
3134static void sum_final_type (const ArgsDescriptor * args_desc , Oid * type , int32 * typmod , Oid * collid )
@@ -63,9 +66,11 @@ static AnonAggState *sum_create_state(MemoryContext memory_context, ArgsDescript
6366 MemoryContext old_context = MemoryContextSwitchTo (memory_context );
6467
6568 int trackers_count = args_desc -> num_args - SUM_AIDS_OFFSET ;
66- SumState * state = palloc0 (sizeof (SumState ) + trackers_count * sizeof ( ContributionTrackerState * ) );
69+ SumState * state = palloc0 (sizeof (SumState ));
6770 state -> trackers_count = trackers_count ;
6871 state -> summand_type = args_desc -> args [SUM_VALUE_INDEX ].type_oid ;
72+ state -> positive = palloc0 (trackers_count * sizeof (ContributionTrackerState * ));
73+ state -> negative = palloc0 (trackers_count * sizeof (ContributionTrackerState * ));
6974 ContributionDescriptor typed_sum_descriptor = {0 };
7075 switch (state -> summand_type )
7176 {
@@ -87,36 +92,54 @@ static AnonAggState *sum_create_state(MemoryContext memory_context, ArgsDescript
8792 for (int i = 0 ; i < trackers_count ; i ++ )
8893 {
8994 Oid aid_type = args_desc -> args [i + SUM_AIDS_OFFSET ].type_oid ;
90- state -> trackers [i ] = contribution_tracker_new (get_aid_mapper (aid_type ), & typed_sum_descriptor );
95+ state -> positive [i ] = contribution_tracker_new (get_aid_mapper (aid_type ), & typed_sum_descriptor );
96+ state -> negative [i ] = contribution_tracker_new (get_aid_mapper (aid_type ), & typed_sum_descriptor );
9197 }
9298
9399 MemoryContextSwitchTo (old_context );
94100 return & state -> base ;
95101}
96102
97- static SummableResultAccumulator sum_calculate_final (AnonAggState * base_state , Bucket * bucket , BucketDescriptor * bucket_desc )
103+ typedef struct SumResultAccumulators
104+ {
105+ bool not_enough_aid_values ;
106+ SummableResultAccumulator positive ;
107+ SummableResultAccumulator negative ;
108+ } SumResultAccumulators ;
109+
110+ static SumResultAccumulators sum_calculate_final (AnonAggState * base_state , Bucket * bucket , BucketDescriptor * bucket_desc )
98111{
99112 SumState * state = (SumState * )base_state ;
100- SummableResultAccumulator result_accumulator = {0 };
113+ SummableResultAccumulator positive_result_accumulator = {0 };
114+ SummableResultAccumulator negative_result_accumulator = {0 };
101115 seed_t bucket_seed = compute_bucket_seed (bucket , bucket_desc );
102116
103117 for (int i = 0 ; i < state -> trackers_count ; i ++ )
104118 {
105- SummableResult result = calculate_result (bucket_seed , state -> trackers [i ]);
119+ SummableResult positive_result = calculate_result (bucket_seed , state -> positive [i ]);
120+ SummableResult negative_result = calculate_result (bucket_seed , state -> negative [i ]);
106121
107- accumulate_result (& result_accumulator , & result );
108- if (result_accumulator .not_enough_aid_values )
109- break ;
122+ if (positive_result .not_enough_aid_values && negative_result .not_enough_aid_values )
123+ {
124+ return (SumResultAccumulators ){.not_enough_aid_values = true};
125+ }
126+ else
127+ {
128+ /* Unless both legs had `not_enough_aid_values` for given AID instance, we proceed. */
129+ accumulate_result (& positive_result_accumulator , & positive_result );
130+ accumulate_result (& negative_result_accumulator , & negative_result );
131+ }
110132 }
111- return result_accumulator ;
133+ return ( SumResultAccumulators ){. positive = positive_result_accumulator , . negative = negative_result_accumulator } ;
112134}
113135
114136static Datum sum_finalize (AnonAggState * base_state , Bucket * bucket , BucketDescriptor * bucket_desc , bool * is_null )
115137{
116138 SumState * state = (SumState * )base_state ;
117- SummableResultAccumulator result_accumulator = sum_calculate_final (base_state , bucket , bucket_desc );
139+ SumResultAccumulators results = sum_calculate_final (base_state , bucket , bucket_desc );
118140
119- if (result_accumulator .not_enough_aid_values )
141+ /* We deliberately ignore the `not_enough_aid_values` fields in the `results.positive` and `negative`. */
142+ if (results .not_enough_aid_values )
120143 {
121144 * is_null = true;
122145 switch (state -> summand_type )
@@ -138,21 +161,22 @@ static Datum sum_finalize(AnonAggState *base_state, Bucket *bucket, BucketDescri
138161 }
139162 else
140163 {
164+ double combined_result = finalize_sum_result (& results .positive ) - finalize_sum_result (& results .negative );
141165 switch (state -> summand_type )
142166 {
143167 case INT2OID :
144168 case INT4OID :
145- return Int64GetDatum ((int64 )round (finalize_sum_result ( & result_accumulator ) ));
169+ return Int64GetDatum ((int64 )round (combined_result ));
146170 case INT8OID :
147171 case NUMERICOID :
148- return DirectFunctionCall1 (float8_numeric , Float8GetDatum (finalize_sum_result ( & result_accumulator ) ));
172+ return DirectFunctionCall1 (float8_numeric , Float8GetDatum (combined_result ));
149173 case FLOAT4OID :
150- return Float4GetDatum ((float4 )finalize_sum_result ( & result_accumulator ) );
174+ return Float4GetDatum ((float4 )combined_result );
151175 case FLOAT8OID :
152- return Float8GetDatum (finalize_sum_result ( & result_accumulator ) );
176+ return Float8GetDatum (combined_result );
153177 default :
154178 Assert (false);
155- return Float8GetDatum (finalize_sum_result ( & result_accumulator ) );
179+ return Float8GetDatum (combined_result );
156180 }
157181 }
158182}
@@ -163,7 +187,8 @@ static void sum_merge(AnonAggState *dst_base_state, const AnonAggState *src_base
163187 const SumState * src_state = (const SumState * )src_base_state ;
164188
165189 Assert (dst_state -> summand_type == src_state -> summand_type );
166- merge_trackers (dst_state -> trackers_count , src_state -> trackers_count , dst_state -> trackers , src_state -> trackers );
190+ merge_trackers (dst_state -> trackers_count , src_state -> trackers_count , dst_state -> positive , src_state -> positive );
191+ merge_trackers (dst_state -> trackers_count , src_state -> trackers_count , dst_state -> negative , src_state -> negative );
167192}
168193
169194static contribution_t summand_to_contribution (Datum arg , Oid summand_type )
@@ -202,16 +227,26 @@ static void sum_transition(AnonAggState *base_state, int num_args, NullableDatum
202227 for (int i = 0 ; i < state -> trackers_count ; i ++ )
203228 {
204229 int aid_index = i + SUM_AIDS_OFFSET ;
230+ ContributionDescriptor descriptor = state -> positive [i ]-> contribution_descriptor ;
231+ contribution_t abs_contribution = descriptor .contribution_abs (value_contribution );
232+ ContributionCombineFunc combine = descriptor .contribution_combine ;
233+ ContributionGreaterFunc gt = descriptor .contribution_greater ;
234+ ContributionEqualFunc eq = descriptor .contribution_equal ;
205235
206236 if (!args [aid_index ].isnull )
207237 {
208- aid_t aid = state -> trackers [i ]-> aid_mapper (args [aid_index ].value );
209- contribution_tracker_update_contribution (state -> trackers [i ], aid , value_contribution );
238+ aid_t aid = state -> positive [i ]-> aid_mapper (args [aid_index ].value );
239+ if (gt (value_contribution , descriptor .contribution_initial ) || eq (value_contribution , descriptor .contribution_initial ))
240+ contribution_tracker_update_contribution (state -> positive [i ], aid , abs_contribution );
241+ if (gt (descriptor .contribution_initial , value_contribution ) || eq (value_contribution , descriptor .contribution_initial ))
242+ contribution_tracker_update_contribution (state -> negative [i ], aid , abs_contribution );
210243 }
211244 else
212245 {
213- ContributionCombineFunc combine = state -> trackers [i ]-> contribution_descriptor .contribution_combine ;
214- state -> trackers [i ]-> unaccounted_for = combine (state -> trackers [i ]-> unaccounted_for , value_contribution );
246+ if (gt (value_contribution , descriptor .contribution_initial ))
247+ state -> positive [i ]-> unaccounted_for = combine (state -> positive [i ]-> unaccounted_for , abs_contribution );
248+ if (gt (descriptor .contribution_initial , value_contribution ))
249+ state -> negative [i ]-> unaccounted_for = combine (state -> negative [i ]-> unaccounted_for , abs_contribution );
215250 }
216251 }
217252 }
@@ -240,15 +275,17 @@ static void sum_noise_final_type(const ArgsDescriptor *args_desc, Oid *type, int
240275
241276static Datum sum_noise_finalize (AnonAggState * base_state , Bucket * bucket , BucketDescriptor * bucket_desc , bool * is_null )
242277{
243- SummableResultAccumulator result_accumulator = sum_calculate_final (base_state , bucket , bucket_desc );
244- if (result_accumulator .not_enough_aid_values )
278+ SumResultAccumulators results = sum_calculate_final (base_state , bucket , bucket_desc );
279+
280+ /* We deliberately ignore the `not_enough_aid_values` fields in the `results.positive` and `negative`. */
281+ if (results .not_enough_aid_values )
245282 {
246283 * is_null = true;
247284 return Float8GetDatum (0.0 );
248285 }
249286 else
250287 {
251- return Float8GetDatum (finalize_noise_result (& result_accumulator ));
288+ return Float8GetDatum (sqrt ( pow ( finalize_noise_result (& results . positive ), 2 ) + pow ( finalize_noise_result ( & results . negative ), 2 ) ));
252289 }
253290}
254291
0 commit comments