@@ -288,104 +288,58 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
288288
289289 if strings .EqualFold (name , "count" ) {
290290 if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
291- var agg sql.Aggregation
292- if e .Distinct {
293- agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
294- } else {
295- agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
296- }
297- b .qFlags .Set (sql .QFlagCountStar )
298- aggName := strings .ToLower (agg .String ())
299- gf := gb .getAggRef (aggName )
300- if gf != nil {
301- // if we've already computed use reference here
302- return gf
303- }
304-
305- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
306- id := gb .outScope .newColumn (col )
307- col .id = id
308-
309- agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
310- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
311- col .scalar = agg
312-
313- gb .addAggStr (col )
314- return col .scalarGf ()
291+ return b .buildCountStarAggregate (e , gb )
315292 }
316293 }
317294
318295 if strings .EqualFold (name , "jsonarray" ) {
319296 // TODO we don't have any tests for this
320297 if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
321- var agg sql.Aggregation
322- agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
323- b .qFlags .Set (sql .QFlagStar )
324-
325- //if e.Distinct {
326- // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
327- //}
328- aggName := strings .ToLower (agg .String ())
329- gf := gb .getAggRef (aggName )
330- if gf != nil {
331- // if we've already computed use reference here
332- return gf
333- }
334-
335- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
336- id := gb .outScope .newColumn (col )
337-
338- agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
339- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
340- col .scalar = agg
341-
342- col .id = id
343- gb .addAggStr (col )
344- return col .scalarGf ()
298+ return b .buildJsonArrayStarAggregate (gb )
345299 }
346300 }
347301
348302 if strings .EqualFold (name , "any_value" ) {
349303 b .qFlags .Set (sql .QFlagAnyAgg )
350304 }
351305
352- var args []sql.Expression
353- for _ , arg := range e .Exprs {
354- e := b .selectExprToExpression (inScope , arg )
355- switch e := e .(type ) {
356- case * expression.GetField :
357- if e .TableId () == 0 {
358- // TODO: not sure where this came from but it's not true
359- // aliases are not valid aggregate arguments, the alias must be masking a column
360- gf := b .selectExprToExpression (inScope .parent , arg )
361- var ok bool
362- e , ok = gf .(* expression.GetField )
363- if ! ok || e .TableId () == 0 {
364- b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
365- }
366- }
367- args = append (args , e )
368- col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
369- gb .addInCol (col )
370- case * expression.Star :
371- err := sql .ErrStarUnsupported .New ()
372- b .handleErr (err )
373- case * plan.Subquery :
374- args = append (args , e )
375- col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
376- gb .addInCol (col )
377- default :
378- args = append (args , e )
379- col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
380- gb .addInCol (col )
381- }
306+ args := b .buildAggFunctionArgs (inScope , e , gb )
307+ agg := b .newAggregation (e , name , args )
308+
309+ if name == "count" {
310+ b .qFlags .Set (sql .QFlagCount )
311+ }
312+
313+ aggType := agg .Type ()
314+ if name == "avg" || name == "sum" {
315+ aggType = types .Float64
382316 }
383317
318+ aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
319+ if id , ok := gb .outScope .getExpr (aggName , true ); ok {
320+ // if we've already computed use reference here
321+ gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
322+ return gf
323+ }
324+
325+ col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
326+ id := gb .outScope .newColumn (col )
327+
328+ agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
329+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
330+ col .scalar = agg
331+
332+ col .id = id
333+ gb .addAggStr (col )
334+ return col .scalarGf ()
335+ }
336+
337+ // newAggregation creates a new aggregation function instanc from the arguments given
338+ func (b * Builder ) newAggregation (e * ast.FuncExpr , name string , args []sql.Expression ) sql.Aggregation {
384339 var agg sql.Aggregation
385340 if e .Distinct && name == "count" {
386341 agg = aggregation .NewCountDistinct (args ... )
387342 } else {
388-
389343 // NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
390344 // errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
391345 if e .Distinct {
@@ -415,35 +369,102 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
415369 b .handleErr (err )
416370 }
417371 }
372+ return agg
373+ }
418374
419- if name == "count" {
420- b .qFlags .Set (sql .QFlagCount )
375+ // buildAggFunctionArgs builds the arguments for an aggregate function
376+ func (b * Builder ) buildAggFunctionArgs (inScope * scope , e * ast.FuncExpr , gb * groupBy ) []sql.Expression {
377+ var args []sql.Expression
378+ for _ , arg := range e .Exprs {
379+ e := b .selectExprToExpression (inScope , arg )
380+ switch e := e .(type ) {
381+ case * expression.GetField :
382+ if e .TableId () == 0 {
383+ // TODO: not sure where this came from but it's not true
384+ // aliases are not valid aggregate arguments, the alias must be masking a column
385+ gf := b .selectExprToExpression (inScope .parent , arg )
386+ var ok bool
387+ e , ok = gf .(* expression.GetField )
388+ if ! ok || e .TableId () == 0 {
389+ b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
390+ }
391+ }
392+ args = append (args , e )
393+ col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
394+ gb .addInCol (col )
395+ case * expression.Star :
396+ err := sql .ErrStarUnsupported .New ()
397+ b .handleErr (err )
398+ case * plan.Subquery :
399+ args = append (args , e )
400+ col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
401+ gb .addInCol (col )
402+ default :
403+ args = append (args , e )
404+ col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
405+ gb .addInCol (col )
406+ }
421407 }
408+ return args
409+ }
422410
423- aggType := agg .Type ()
424- if name == "avg" || name == "sum" {
425- aggType = types .Float64
411+ // buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function
412+ func (b * Builder ) buildJsonArrayStarAggregate (gb * groupBy ) sql.Expression {
413+ var agg sql.Aggregation
414+ agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
415+ b .qFlags .Set (sql .QFlagStar )
416+
417+ // if e.Distinct {
418+ // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
419+ // }
420+ aggName := strings .ToLower (agg .String ())
421+ gf := gb .getAggRef (aggName )
422+ if gf != nil {
423+ // if we've already computed use reference here
424+ return gf
426425 }
427426
428- aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
429- if id , ok := gb .outScope .getExpr (aggName , true ); ok {
427+ col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
428+ id := gb .outScope .newColumn (col )
429+
430+ agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
431+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
432+ col .scalar = agg
433+
434+ col .id = id
435+ gb .addAggStr (col )
436+ return col .scalarGf ()
437+ }
438+
439+ // buildCountStarAggregate builds a COUNT(*) aggregate function
440+ func (b * Builder ) buildCountStarAggregate (e * ast.FuncExpr , gb * groupBy ) sql.Expression {
441+ var agg sql.Aggregation
442+ if e .Distinct {
443+ agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
444+ } else {
445+ agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
446+ }
447+ b .qFlags .Set (sql .QFlagCountStar )
448+ aggName := strings .ToLower (agg .String ())
449+ gf := gb .getAggRef (aggName )
450+ if gf != nil {
430451 // if we've already computed use reference here
431- gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
432452 return gf
433453 }
434454
435- col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
455+ col := scopeColumn {col : strings . ToLower ( agg . String ()) , scalar : agg , typ : agg . Type () , nullable : agg .IsNullable ()}
436456 id := gb .outScope .newColumn (col )
457+ col .id = id
437458
438459 agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
439460 gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
440461 col .scalar = agg
441462
442- col .id = id
443463 gb .addAggStr (col )
444464 return col .scalarGf ()
445465}
446466
467+ // buildGroupConcat builds a GROUP_CONCAT aggregate function
447468func (b * Builder ) buildGroupConcat (inScope * scope , e * ast.GroupConcatExpr ) sql.Expression {
448469 if inScope .groupBy == nil {
449470 inScope .initGroupBy ()
0 commit comments