@@ -35,7 +35,11 @@ import (
35
35
// It contains the opcode and input column number.
36
36
type AggregateParams struct {
37
37
Opcode AggregateOpcode
38
- Col int
38
+ // Input source specification - exactly one of these should be set:
39
+ // Col: Column index for simple column references (e.g., SUM(column_name))
40
+ // EExpr: Evaluated expression for literals, parameters
41
+ Col int
42
+ EExpr evalengine.Expr
39
43
40
44
// These are used only for distinct opcodes.
41
45
KeyCol int
@@ -53,15 +57,26 @@ type AggregateParams struct {
53
57
CollationEnv * collations.Environment
54
58
}
55
59
56
- func NewAggregateParam (opcode AggregateOpcode , col int , alias string , collationEnv * collations.Environment ) * AggregateParams {
60
+ // NewAggregateParam creates a new aggregate param
61
+ func NewAggregateParam (
62
+ oc AggregateOpcode ,
63
+ col int ,
64
+ expr evalengine.Expr ,
65
+ alias string ,
66
+ collationEnv * collations.Environment ,
67
+ ) * AggregateParams {
68
+ if expr != nil && oc != AggregateConstant {
69
+ panic (vterrors .VT13001 ("expr should be nil" ))
70
+ }
57
71
out := & AggregateParams {
58
- Opcode : opcode ,
72
+ Opcode : oc ,
59
73
Col : col ,
74
+ EExpr : expr ,
60
75
Alias : alias ,
61
76
WCol : - 1 ,
62
77
CollationEnv : collationEnv ,
63
78
}
64
- if opcode .NeedsComparableValues () {
79
+ if oc .NeedsComparableValues () {
65
80
out .KeyCol = col
66
81
}
67
82
return out
@@ -73,6 +88,9 @@ func (ap *AggregateParams) WAssigned() bool {
73
88
74
89
func (ap * AggregateParams ) String () string {
75
90
keyCol := strconv .Itoa (ap .Col )
91
+ if ap .EExpr != nil {
92
+ keyCol = sqlparser .String (ap .EExpr )
93
+ }
76
94
if ap .WAssigned () {
77
95
keyCol = fmt .Sprintf ("%s|%d" , keyCol , ap .WCol )
78
96
}
@@ -89,7 +107,14 @@ func (ap *AggregateParams) String() string {
89
107
return fmt .Sprintf ("%s%s(%s)" , ap .Opcode .String (), dispOrigOp , keyCol )
90
108
}
91
109
92
- func (ap * AggregateParams ) typ (inputType querypb.Type ) querypb.Type {
110
+ func (ap * AggregateParams ) typ (inputType querypb.Type , env * evalengine.ExpressionEnv , collID collations.ID ) querypb.Type {
111
+ if ap .EExpr != nil {
112
+ value , err := eval (env , ap .EExpr , collID )
113
+ if err != nil {
114
+ return sqltypes .Unknown
115
+ }
116
+ return value .Type ()
117
+ }
93
118
if ap .OrigOpcode != AggregateUnassigned {
94
119
return ap .OrigOpcode .SQLType (inputType )
95
120
}
@@ -98,7 +123,7 @@ func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
98
123
99
124
type aggregator interface {
100
125
add (row []sqltypes.Value ) error
101
- finish () sqltypes.Value
126
+ finish (env * evalengine. ExpressionEnv , coll collations. ID ) ( sqltypes.Value , error )
102
127
reset ()
103
128
}
104
129
@@ -151,8 +176,8 @@ func (a *aggregatorCount) add(row []sqltypes.Value) error {
151
176
return nil
152
177
}
153
178
154
- func (a * aggregatorCount ) finish () sqltypes.Value {
155
- return sqltypes .NewInt64 (a .n )
179
+ func (a * aggregatorCount ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
180
+ return sqltypes .NewInt64 (a .n ), nil
156
181
}
157
182
158
183
func (a * aggregatorCount ) reset () {
@@ -164,13 +189,13 @@ type aggregatorCountStar struct {
164
189
n int64
165
190
}
166
191
167
- func (a * aggregatorCountStar ) add (_ []sqltypes.Value ) error {
192
+ func (a * aggregatorCountStar ) add ([]sqltypes.Value ) error {
168
193
a .n ++
169
194
return nil
170
195
}
171
196
172
- func (a * aggregatorCountStar ) finish () sqltypes.Value {
173
- return sqltypes .NewInt64 (a .n )
197
+ func (a * aggregatorCountStar ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
198
+ return sqltypes .NewInt64 (a .n ), nil
174
199
}
175
200
176
201
func (a * aggregatorCountStar ) reset () {
@@ -198,8 +223,8 @@ func (a *aggregatorMax) add(row []sqltypes.Value) (err error) {
198
223
return a .minmax .Max (row [a .from ])
199
224
}
200
225
201
- func (a * aggregatorMinMax ) finish () sqltypes.Value {
202
- return a .minmax .Result ()
226
+ func (a * aggregatorMinMax ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
227
+ return a .minmax .Result (), nil
203
228
}
204
229
205
230
func (a * aggregatorMinMax ) reset () {
@@ -222,8 +247,8 @@ func (a *aggregatorSum) add(row []sqltypes.Value) error {
222
247
return a .sum .Add (row [a .from ])
223
248
}
224
249
225
- func (a * aggregatorSum ) finish () sqltypes.Value {
226
- return a .sum .Result ()
250
+ func (a * aggregatorSum ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
251
+ return a .sum .Result (), nil
227
252
}
228
253
229
254
func (a * aggregatorSum ) reset () {
@@ -232,28 +257,51 @@ func (a *aggregatorSum) reset() {
232
257
}
233
258
234
259
type aggregatorScalar struct {
235
- from int
236
- current sqltypes.Value
237
- init bool
260
+ from int
261
+ current sqltypes.Value
262
+ hasValue bool
238
263
}
239
264
240
265
func (a * aggregatorScalar ) add (row []sqltypes.Value ) error {
241
- if ! a .init {
266
+ if ! a .hasValue {
242
267
a .current = row [a .from ]
243
- a .init = true
268
+ a .hasValue = true
244
269
}
245
270
return nil
246
271
}
247
272
248
- func (a * aggregatorScalar ) finish () sqltypes.Value {
249
- return a .current
273
+ func (a * aggregatorScalar ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
274
+ return a .current , nil
250
275
}
251
276
252
277
func (a * aggregatorScalar ) reset () {
253
278
a .current = sqltypes .NULL
254
- a .init = false
279
+ a .hasValue = false
255
280
}
256
281
282
+ type aggregatorConstant struct {
283
+ expr evalengine.Expr
284
+ }
285
+
286
+ func (* aggregatorConstant ) add ([]sqltypes.Value ) error {
287
+ return nil
288
+ }
289
+
290
+ func (a * aggregatorConstant ) finish (env * evalengine.ExpressionEnv , coll collations.ID ) (sqltypes.Value , error ) {
291
+ return eval (env , a .expr , coll )
292
+ }
293
+
294
+ func eval (env * evalengine.ExpressionEnv , eexpr evalengine.Expr , coll collations.ID ) (sqltypes.Value , error ) {
295
+ v , err := env .Evaluate (eexpr )
296
+ if err != nil {
297
+ return sqltypes.Value {}, err
298
+ }
299
+
300
+ return v .Value (coll ), nil
301
+ }
302
+
303
+ func (* aggregatorConstant ) reset () {}
304
+
257
305
type aggregatorGroupConcat struct {
258
306
from int
259
307
type_ sqltypes.Type
@@ -275,11 +323,11 @@ func (a *aggregatorGroupConcat) add(row []sqltypes.Value) error {
275
323
return nil
276
324
}
277
325
278
- func (a * aggregatorGroupConcat ) finish () sqltypes.Value {
326
+ func (a * aggregatorGroupConcat ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
279
327
if a .n == 0 {
280
- return sqltypes .NULL
328
+ return sqltypes .NULL , nil
281
329
}
282
- return sqltypes .MakeTrusted (a .type_ , a .concat )
330
+ return sqltypes .MakeTrusted (a .type_ , a .concat ), nil
283
331
}
284
332
285
333
func (a * aggregatorGroupConcat ) reset () {
@@ -301,36 +349,44 @@ func (a *aggregatorGtid) add(row []sqltypes.Value) error {
301
349
return nil
302
350
}
303
351
304
- func (a * aggregatorGtid ) finish () sqltypes.Value {
352
+ func (a * aggregatorGtid ) finish (* evalengine. ExpressionEnv , collations. ID ) ( sqltypes.Value , error ) {
305
353
gtid := binlogdatapb.VGtid {ShardGtids : a .shards }
306
- return sqltypes .NewVarChar (gtid .String ())
354
+ return sqltypes .NewVarChar (gtid .String ()), nil
307
355
}
308
356
309
357
func (a * aggregatorGtid ) reset () {
310
358
a .shards = a .shards [:0 ] // safe to reuse because only the serialized form of a.shards is returned
311
359
}
312
360
313
- type aggregationState []aggregator
361
+ type aggregationState struct {
362
+ env * evalengine.ExpressionEnv
363
+ aggregators []aggregator
364
+ coll collations.ID
365
+ }
314
366
315
- func (a aggregationState ) add (row []sqltypes.Value ) error {
316
- for _ , st := range a {
367
+ func (a * aggregationState ) add (row []sqltypes.Value ) error {
368
+ for _ , st := range a . aggregators {
317
369
if err := st .add (row ); err != nil {
318
370
return err
319
371
}
320
372
}
321
373
return nil
322
374
}
323
375
324
- func (a aggregationState ) finish () (row []sqltypes.Value ) {
325
- row = make ([]sqltypes.Value , 0 , len (a ))
326
- for _ , st := range a {
327
- row = append (row , st .finish ())
376
+ func (a * aggregationState ) finish () ([]sqltypes.Value , error ) {
377
+ row := make ([]sqltypes.Value , 0 , len (a .aggregators ))
378
+ for _ , st := range a .aggregators {
379
+ v , err := st .finish (a .env , a .coll )
380
+ if err != nil {
381
+ return nil , err
382
+ }
383
+ row = append (row , v )
328
384
}
329
- return
385
+ return row , nil
330
386
}
331
387
332
- func (a aggregationState ) reset () {
333
- for _ , st := range a {
388
+ func (a * aggregationState ) reset () {
389
+ for _ , st := range a . aggregators {
334
390
st .reset ()
335
391
}
336
392
}
@@ -354,13 +410,16 @@ func isComparable(typ sqltypes.Type) bool {
354
410
return false
355
411
}
356
412
357
- func newAggregation (fields []* querypb.Field , aggregates []* AggregateParams ) (aggregationState , []* querypb.Field , error ) {
413
+ func newAggregation (fields []* querypb.Field , aggregates []* AggregateParams , env * evalengine. ExpressionEnv , collation collations. ID ) (* aggregationState , []* querypb.Field , error ) {
358
414
fields = slice .Map (fields , func (from * querypb.Field ) * querypb.Field { return from .CloneVT () })
359
415
360
- agstate := make ([]aggregator , len (fields ))
416
+ aggregators := make ([]aggregator , len (fields ))
361
417
for _ , aggr := range aggregates {
362
- sourceType := fields [aggr .Col ].Type
363
- targetType := aggr .typ (sourceType )
418
+ var sourceType querypb.Type
419
+ if aggr .Col < len (fields ) {
420
+ sourceType = fields [aggr .Col ].Type
421
+ }
422
+ targetType := aggr .typ (sourceType , env , collation )
364
423
365
424
var ag aggregator
366
425
var distinct = - 1
@@ -444,22 +503,25 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
444
503
separator : separator ,
445
504
}
446
505
506
+ case AggregateConstant :
507
+ ag = & aggregatorConstant {expr : aggr .EExpr }
508
+
447
509
default :
448
510
panic ("BUG: unexpected Aggregation opcode" )
449
511
}
450
512
451
- agstate [aggr .Col ] = ag
513
+ aggregators [aggr .Col ] = ag
452
514
fields [aggr .Col ].Type = targetType
453
515
if aggr .Alias != "" {
454
516
fields [aggr .Col ].Name = aggr .Alias
455
517
}
456
518
}
457
519
458
- for i , a := range agstate {
520
+ for i , a := range aggregators {
459
521
if a == nil {
460
- agstate [i ] = & aggregatorScalar {from : i }
522
+ aggregators [i ] = & aggregatorScalar {from : i }
461
523
}
462
524
}
463
525
464
- return agstate , fields , nil
526
+ return & aggregationState { aggregators : aggregators , env : env , coll : collation } , fields , nil
465
527
}
0 commit comments