@@ -35,7 +35,12 @@ import (
35
35
// It contains the opcode and input column number.
36
36
type AggregateParams struct {
37
37
Opcode opcode.AggregateOpcode
38
- Col int
38
+
39
+ // Input source specification - exactly one of these should be set:
40
+ // Col: Column index for simple column references (e.g., SUM(column_name))
41
+ // EExpr: Evaluated expression for literals, parameters
42
+ Col int
43
+ EExpr evalengine.Expr
39
44
40
45
// These are used only for distinct opcodes.
41
46
KeyCol int
@@ -53,15 +58,26 @@ type AggregateParams struct {
53
58
CollationEnv * collations.Environment
54
59
}
55
60
56
- func NewAggregateParam (opcode opcode.AggregateOpcode , col int , alias string , collationEnv * collations.Environment ) * AggregateParams {
61
+ // NewAggregateParam creates a new aggregate param
62
+ func NewAggregateParam (
63
+ oc opcode.AggregateOpcode ,
64
+ col int ,
65
+ expr evalengine.Expr ,
66
+ alias string ,
67
+ collationEnv * collations.Environment ,
68
+ ) * AggregateParams {
69
+ if expr != nil && oc != opcode .AggregateConstant {
70
+ panic (vterrors .VT13001 ("expr should be nil" ))
71
+ }
57
72
out := & AggregateParams {
58
- Opcode : opcode ,
73
+ Opcode : oc ,
59
74
Col : col ,
75
+ EExpr : expr ,
60
76
Alias : alias ,
61
77
WCol : - 1 ,
62
78
CollationEnv : collationEnv ,
63
79
}
64
- if opcode .NeedsComparableValues () {
80
+ if oc .NeedsComparableValues () {
65
81
out .KeyCol = col
66
82
}
67
83
return out
@@ -73,6 +89,9 @@ func (ap *AggregateParams) WAssigned() bool {
73
89
74
90
func (ap * AggregateParams ) String () string {
75
91
keyCol := strconv .Itoa (ap .Col )
92
+ if ap .EExpr != nil {
93
+ keyCol = sqlparser .String (ap .EExpr )
94
+ }
76
95
if ap .WAssigned () {
77
96
keyCol = fmt .Sprintf ("%s|%d" , keyCol , ap .WCol )
78
97
}
@@ -89,7 +108,14 @@ func (ap *AggregateParams) String() string {
89
108
return fmt .Sprintf ("%s%s(%s)" , ap .Opcode .String (), dispOrigOp , keyCol )
90
109
}
91
110
92
- func (ap * AggregateParams ) typ (inputType querypb.Type ) querypb.Type {
111
+ func (ap * AggregateParams ) typ (inputType querypb.Type , env * evalengine.ExpressionEnv , collID collations.ID ) querypb.Type {
112
+ if ap .EExpr != nil {
113
+ value , err := eval (env , ap .EExpr , collID )
114
+ if err != nil {
115
+ return sqltypes .Unknown
116
+ }
117
+ return value .Type ()
118
+ }
93
119
if ap .OrigOpcode != opcode .AggregateUnassigned {
94
120
return ap .OrigOpcode .SQLType (inputType )
95
121
}
@@ -98,7 +124,7 @@ func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
98
124
99
125
type aggregator interface {
100
126
add (row []sqltypes.Value ) error
101
- finish () sqltypes.Value
127
+ finish (env * evalengine. ExpressionEnv , coll collations. ID ) ( sqltypes.Value , error )
102
128
reset ()
103
129
}
104
130
@@ -151,8 +177,8 @@ func (a *aggregatorCount) add(row []sqltypes.Value) error {
151
177
return nil
152
178
}
153
179
154
- func (a * aggregatorCount ) finish () sqltypes.Value {
155
- return sqltypes .NewInt64 (a .n )
180
+ func (a * aggregatorCount ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
181
+ return sqltypes .NewInt64 (a .n ), nil
156
182
}
157
183
158
184
func (a * aggregatorCount ) reset () {
@@ -164,13 +190,13 @@ type aggregatorCountStar struct {
164
190
n int64
165
191
}
166
192
167
- func (a * aggregatorCountStar ) add (_ []sqltypes.Value ) error {
193
+ func (a * aggregatorCountStar ) add ([]sqltypes.Value ) error {
168
194
a .n ++
169
195
return nil
170
196
}
171
197
172
- func (a * aggregatorCountStar ) finish () sqltypes.Value {
173
- return sqltypes .NewInt64 (a .n )
198
+ func (a * aggregatorCountStar ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
199
+ return sqltypes .NewInt64 (a .n ), nil
174
200
}
175
201
176
202
func (a * aggregatorCountStar ) reset () {
@@ -198,8 +224,8 @@ func (a *aggregatorMax) add(row []sqltypes.Value) (err error) {
198
224
return a .minmax .Max (row [a .from ])
199
225
}
200
226
201
- func (a * aggregatorMinMax ) finish () sqltypes.Value {
202
- return a .minmax .Result ()
227
+ func (a * aggregatorMinMax ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
228
+ return a .minmax .Result (), nil
203
229
}
204
230
205
231
func (a * aggregatorMinMax ) reset () {
@@ -222,8 +248,8 @@ func (a *aggregatorSum) add(row []sqltypes.Value) error {
222
248
return a .sum .Add (row [a .from ])
223
249
}
224
250
225
- func (a * aggregatorSum ) finish () sqltypes.Value {
226
- return a .sum .Result ()
251
+ func (a * aggregatorSum ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
252
+ return a .sum .Result (), nil
227
253
}
228
254
229
255
func (a * aggregatorSum ) reset () {
@@ -232,28 +258,51 @@ func (a *aggregatorSum) reset() {
232
258
}
233
259
234
260
type aggregatorScalar struct {
235
- from int
236
- current sqltypes.Value
237
- init bool
261
+ from int
262
+ current sqltypes.Value
263
+ hasValue bool
238
264
}
239
265
240
266
func (a * aggregatorScalar ) add (row []sqltypes.Value ) error {
241
- if ! a .init {
267
+ if ! a .hasValue {
242
268
a .current = row [a .from ]
243
- a .init = true
269
+ a .hasValue = true
244
270
}
245
271
return nil
246
272
}
247
273
248
- func (a * aggregatorScalar ) finish () sqltypes.Value {
249
- return a .current
274
+ func (a * aggregatorScalar ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
275
+ return a .current , nil
250
276
}
251
277
252
278
func (a * aggregatorScalar ) reset () {
253
279
a .current = sqltypes .NULL
254
- a .init = false
280
+ a .hasValue = false
281
+ }
282
+
283
+ type aggregatorConstant struct {
284
+ expr evalengine.Expr
285
+ }
286
+
287
+ func (* aggregatorConstant ) add ([]sqltypes.Value ) error {
288
+ return nil
255
289
}
256
290
291
+ func (a * aggregatorConstant ) finish (env * evalengine.ExpressionEnv , coll collations.ID ) (sqltypes.Value , error ) {
292
+ return eval (env , a .expr , coll )
293
+ }
294
+
295
+ func eval (env * evalengine.ExpressionEnv , eexpr evalengine.Expr , coll collations.ID ) (sqltypes.Value , error ) {
296
+ v , err := env .Evaluate (eexpr )
297
+ if err != nil {
298
+ return sqltypes.Value {}, err
299
+ }
300
+
301
+ return v .Value (coll ), nil
302
+ }
303
+
304
+ func (* aggregatorConstant ) reset () {}
305
+
257
306
type aggregatorGroupConcat struct {
258
307
from int
259
308
type_ sqltypes.Type
@@ -275,11 +324,11 @@ func (a *aggregatorGroupConcat) add(row []sqltypes.Value) error {
275
324
return nil
276
325
}
277
326
278
- func (a * aggregatorGroupConcat ) finish () sqltypes.Value {
327
+ func (a * aggregatorGroupConcat ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
279
328
if a .n == 0 {
280
- return sqltypes .NULL
329
+ return sqltypes .NULL , nil
281
330
}
282
- return sqltypes .MakeTrusted (a .type_ , a .concat )
331
+ return sqltypes .MakeTrusted (a .type_ , a .concat ), nil
283
332
}
284
333
285
334
func (a * aggregatorGroupConcat ) reset () {
@@ -301,36 +350,44 @@ func (a *aggregatorGtid) add(row []sqltypes.Value) error {
301
350
return nil
302
351
}
303
352
304
- func (a * aggregatorGtid ) finish () sqltypes.Value {
353
+ func (a * aggregatorGtid ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
305
354
gtid := binlogdatapb.VGtid {ShardGtids : a .shards }
306
- return sqltypes .NewVarChar (gtid .String ())
355
+ return sqltypes .NewVarChar (gtid .String ()), nil
307
356
}
308
357
309
358
func (a * aggregatorGtid ) reset () {
310
359
a .shards = a .shards [:0 ] // safe to reuse because only the serialized form of a.shards is returned
311
360
}
312
361
313
- type aggregationState []aggregator
362
+ type aggregationState struct {
363
+ env * evalengine.ExpressionEnv
364
+ aggregators []aggregator
365
+ coll collations.ID
366
+ }
314
367
315
- func (a aggregationState ) add (row []sqltypes.Value ) error {
316
- for _ , st := range a {
368
+ func (a * aggregationState ) add (row []sqltypes.Value ) error {
369
+ for _ , st := range a . aggregators {
317
370
if err := st .add (row ); err != nil {
318
371
return err
319
372
}
320
373
}
321
374
return nil
322
375
}
323
376
324
- func (a aggregationState ) finish () (row []sqltypes.Value ) {
325
- row = make ([]sqltypes.Value , 0 , len (a ))
326
- for _ , st := range a {
327
- row = append (row , st .finish ())
377
+ func (a * aggregationState ) finish () ([]sqltypes.Value , error ) {
378
+ row := make ([]sqltypes.Value , 0 , len (a .aggregators ))
379
+ for _ , st := range a .aggregators {
380
+ v , err := st .finish (a .env , a .coll )
381
+ if err != nil {
382
+ return nil , err
383
+ }
384
+ row = append (row , v )
328
385
}
329
- return
386
+ return row , nil
330
387
}
331
388
332
- func (a aggregationState ) reset () {
333
- for _ , st := range a {
389
+ func (a * aggregationState ) reset () {
390
+ for _ , st := range a . aggregators {
334
391
st .reset ()
335
392
}
336
393
}
@@ -354,13 +411,16 @@ func isComparable(typ sqltypes.Type) bool {
354
411
return false
355
412
}
356
413
357
- func newAggregation (fields []* querypb.Field , aggregates []* AggregateParams ) (aggregationState , []* querypb.Field , error ) {
414
+ func newAggregation (fields []* querypb.Field , aggregates []* AggregateParams , env * evalengine. ExpressionEnv , collation collations. ID ) (* aggregationState , []* querypb.Field , error ) {
358
415
fields = slice .Map (fields , func (from * querypb.Field ) * querypb.Field { return from .CloneVT () })
359
416
360
- agstate := make ([]aggregator , len (fields ))
417
+ aggregators := make ([]aggregator , len (fields ))
361
418
for _ , aggr := range aggregates {
362
- sourceType := fields [aggr .Col ].Type
363
- targetType := aggr .typ (sourceType )
419
+ var sourceType querypb.Type
420
+ if aggr .Col < len (fields ) {
421
+ sourceType = fields [aggr .Col ].Type
422
+ }
423
+ targetType := aggr .typ (sourceType , env , collation )
364
424
365
425
var ag aggregator
366
426
var distinct = - 1
@@ -444,22 +504,25 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
444
504
separator : separator ,
445
505
}
446
506
507
+ case opcode .AggregateConstant :
508
+ ag = & aggregatorConstant {expr : aggr .EExpr }
509
+
447
510
default :
448
511
panic ("BUG: unexpected Aggregation opcode" )
449
512
}
450
513
451
- agstate [aggr .Col ] = ag
514
+ aggregators [aggr .Col ] = ag
452
515
fields [aggr .Col ].Type = targetType
453
516
if aggr .Alias != "" {
454
517
fields [aggr .Col ].Name = aggr .Alias
455
518
}
456
519
}
457
520
458
- for i , a := range agstate {
521
+ for i , a := range aggregators {
459
522
if a == nil {
460
- agstate [i ] = & aggregatorScalar {from : i }
523
+ aggregators [i ] = & aggregatorScalar {from : i }
461
524
}
462
525
}
463
526
464
- return agstate , fields , nil
527
+ return & aggregationState { aggregators : aggregators , env : env , coll : collation } , fields , nil
465
528
}
0 commit comments