@@ -288,104 +288,58 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
288
288
289
289
if strings .EqualFold (name , "count" ) {
290
290
if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
291
- var agg sql.Aggregation
292
- if e .Distinct {
293
- agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
294
- } else {
295
- agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
296
- }
297
- b .qFlags .Set (sql .QFlagCountStar )
298
- aggName := strings .ToLower (agg .String ())
299
- gf := gb .getAggRef (aggName )
300
- if gf != nil {
301
- // if we've already computed use reference here
302
- return gf
303
- }
304
-
305
- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
306
- id := gb .outScope .newColumn (col )
307
- col .id = id
308
-
309
- agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
310
- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
311
- col .scalar = agg
312
-
313
- gb .addAggStr (col )
314
- return col .scalarGf ()
291
+ return b .buildCountStarAggregate (e , gb )
315
292
}
316
293
}
317
294
318
295
if strings .EqualFold (name , "jsonarray" ) {
319
296
// TODO we don't have any tests for this
320
297
if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
321
- var agg sql.Aggregation
322
- agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
323
- b .qFlags .Set (sql .QFlagStar )
324
-
325
- //if e.Distinct {
326
- // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
327
- //}
328
- aggName := strings .ToLower (agg .String ())
329
- gf := gb .getAggRef (aggName )
330
- if gf != nil {
331
- // if we've already computed use reference here
332
- return gf
333
- }
334
-
335
- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
336
- id := gb .outScope .newColumn (col )
337
-
338
- agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
339
- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
340
- col .scalar = agg
341
-
342
- col .id = id
343
- gb .addAggStr (col )
344
- return col .scalarGf ()
298
+ return b .buildJsonArrayStarAggregate (gb )
345
299
}
346
300
}
347
301
348
302
if strings .EqualFold (name , "any_value" ) {
349
303
b .qFlags .Set (sql .QFlagAnyAgg )
350
304
}
351
305
352
- var args []sql.Expression
353
- for _ , arg := range e .Exprs {
354
- e := b .selectExprToExpression (inScope , arg )
355
- switch e := e .(type ) {
356
- case * expression.GetField :
357
- if e .TableId () == 0 {
358
- // TODO: not sure where this came from but it's not true
359
- // aliases are not valid aggregate arguments, the alias must be masking a column
360
- gf := b .selectExprToExpression (inScope .parent , arg )
361
- var ok bool
362
- e , ok = gf .(* expression.GetField )
363
- if ! ok || e .TableId () == 0 {
364
- b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
365
- }
366
- }
367
- args = append (args , e )
368
- col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
369
- gb .addInCol (col )
370
- case * expression.Star :
371
- err := sql .ErrStarUnsupported .New ()
372
- b .handleErr (err )
373
- case * plan.Subquery :
374
- args = append (args , e )
375
- col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
376
- gb .addInCol (col )
377
- default :
378
- args = append (args , e )
379
- col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
380
- gb .addInCol (col )
381
- }
306
+ args := b .buildAggFunctionArgs (inScope , e , gb )
307
+ agg := b .newAggregation (e , name , args )
308
+
309
+ if name == "count" {
310
+ b .qFlags .Set (sql .QFlagCount )
311
+ }
312
+
313
+ aggType := agg .Type ()
314
+ if name == "avg" || name == "sum" {
315
+ aggType = types .Float64
382
316
}
383
317
318
+ aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
319
+ if id , ok := gb .outScope .getExpr (aggName , true ); ok {
320
+ // if we've already computed use reference here
321
+ gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
322
+ return gf
323
+ }
324
+
325
+ col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
326
+ id := gb .outScope .newColumn (col )
327
+
328
+ agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
329
+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
330
+ col .scalar = agg
331
+
332
+ col .id = id
333
+ gb .addAggStr (col )
334
+ return col .scalarGf ()
335
+ }
336
+
337
+ // newAggregation creates a new aggregation function instanc from the arguments given
338
+ func (b * Builder ) newAggregation (e * ast.FuncExpr , name string , args []sql.Expression ) sql.Aggregation {
384
339
var agg sql.Aggregation
385
340
if e .Distinct && name == "count" {
386
341
agg = aggregation .NewCountDistinct (args ... )
387
342
} else {
388
-
389
343
// NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
390
344
// errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
391
345
if e .Distinct {
@@ -415,35 +369,102 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
415
369
b .handleErr (err )
416
370
}
417
371
}
372
+ return agg
373
+ }
418
374
419
- if name == "count" {
420
- b .qFlags .Set (sql .QFlagCount )
375
+ // buildAggFunctionArgs builds the arguments for an aggregate function
376
+ func (b * Builder ) buildAggFunctionArgs (inScope * scope , e * ast.FuncExpr , gb * groupBy ) []sql.Expression {
377
+ var args []sql.Expression
378
+ for _ , arg := range e .Exprs {
379
+ e := b .selectExprToExpression (inScope , arg )
380
+ switch e := e .(type ) {
381
+ case * expression.GetField :
382
+ if e .TableId () == 0 {
383
+ // TODO: not sure where this came from but it's not true
384
+ // aliases are not valid aggregate arguments, the alias must be masking a column
385
+ gf := b .selectExprToExpression (inScope .parent , arg )
386
+ var ok bool
387
+ e , ok = gf .(* expression.GetField )
388
+ if ! ok || e .TableId () == 0 {
389
+ b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
390
+ }
391
+ }
392
+ args = append (args , e )
393
+ col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
394
+ gb .addInCol (col )
395
+ case * expression.Star :
396
+ err := sql .ErrStarUnsupported .New ()
397
+ b .handleErr (err )
398
+ case * plan.Subquery :
399
+ args = append (args , e )
400
+ col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
401
+ gb .addInCol (col )
402
+ default :
403
+ args = append (args , e )
404
+ col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
405
+ gb .addInCol (col )
406
+ }
421
407
}
408
+ return args
409
+ }
422
410
423
- aggType := agg .Type ()
424
- if name == "avg" || name == "sum" {
425
- aggType = types .Float64
411
+ // buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function
412
+ func (b * Builder ) buildJsonArrayStarAggregate (gb * groupBy ) sql.Expression {
413
+ var agg sql.Aggregation
414
+ agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
415
+ b .qFlags .Set (sql .QFlagStar )
416
+
417
+ // if e.Distinct {
418
+ // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
419
+ // }
420
+ aggName := strings .ToLower (agg .String ())
421
+ gf := gb .getAggRef (aggName )
422
+ if gf != nil {
423
+ // if we've already computed use reference here
424
+ return gf
426
425
}
427
426
428
- aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
429
- if id , ok := gb .outScope .getExpr (aggName , true ); ok {
427
+ col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
428
+ id := gb .outScope .newColumn (col )
429
+
430
+ agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
431
+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
432
+ col .scalar = agg
433
+
434
+ col .id = id
435
+ gb .addAggStr (col )
436
+ return col .scalarGf ()
437
+ }
438
+
439
+ // buildCountStarAggregate builds a COUNT(*) aggregate function
440
+ func (b * Builder ) buildCountStarAggregate (e * ast.FuncExpr , gb * groupBy ) sql.Expression {
441
+ var agg sql.Aggregation
442
+ if e .Distinct {
443
+ agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
444
+ } else {
445
+ agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
446
+ }
447
+ b .qFlags .Set (sql .QFlagCountStar )
448
+ aggName := strings .ToLower (agg .String ())
449
+ gf := gb .getAggRef (aggName )
450
+ if gf != nil {
430
451
// if we've already computed use reference here
431
- gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
432
452
return gf
433
453
}
434
454
435
- col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
455
+ col := scopeColumn {col : strings . ToLower ( agg . String ()) , scalar : agg , typ : agg . Type () , nullable : agg .IsNullable ()}
436
456
id := gb .outScope .newColumn (col )
457
+ col .id = id
437
458
438
459
agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
439
460
gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
440
461
col .scalar = agg
441
462
442
- col .id = id
443
463
gb .addAggStr (col )
444
464
return col .scalarGf ()
445
465
}
446
466
467
+ // buildGroupConcat builds a GROUP_CONCAT aggregate function
447
468
func (b * Builder ) buildGroupConcat (inScope * scope , e * ast.GroupConcatExpr ) sql.Expression {
448
469
if inScope .groupBy == nil {
449
470
inScope .initGroupBy ()
0 commit comments