@@ -95,9 +95,8 @@ func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.Gro
9595 // 3) an index into selects
9696 // 4) a simple non-aggregate expression
9797 groupings := make ([]sql.Expression , 0 )
98- if fromScope .groupBy == nil {
99- fromScope .initGroupBy ()
100- }
98+ fromScope .initGroupBy ()
99+
101100 g := fromScope .groupBy
102101 for _ , e := range groupby {
103102 var col scopeColumn
@@ -194,9 +193,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
194193 // - grouping cols projection
195194 // - aggregate expressions
196195 // - output projection
197- if fromScope .groupBy == nil {
198- fromScope .initGroupBy ()
199- }
196+ fromScope .initGroupBy ()
200197
201198 group := fromScope .groupBy
202199 outScope := group .outScope
@@ -257,7 +254,10 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
257254 return outScope
258255}
259256
260- func isAggregateFunc (name string ) bool {
257+ // IsAggregateFunc is a hacky "extension point" to allow for other dialects to declare additional aggregate functions
258+ var IsAggregateFunc = IsMySQLAggregateFuncName
259+
260+ func IsMySQLAggregateFuncName (name string ) bool {
261261 switch name {
262262 case "avg" , "bit_and" , "bit_or" , "bit_xor" , "count" ,
263263 "group_concat" , "json_arrayagg" , "json_objectagg" ,
@@ -278,111 +278,63 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
278278 b .handleErr (err )
279279 }
280280
281- if inScope .groupBy == nil {
282- inScope .initGroupBy ()
283- }
281+ inScope .initGroupBy ()
284282 gb := inScope .groupBy
285283
286284 if strings .EqualFold (name , "count" ) {
287285 if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
288- var agg sql.Aggregation
289- if e .Distinct {
290- agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
291- } else {
292- agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
293- }
294- b .qFlags .Set (sql .QFlagCountStar )
295- aggName := strings .ToLower (agg .String ())
296- gf := gb .getAggRef (aggName )
297- if gf != nil {
298- // if we've already computed use reference here
299- return gf
300- }
301-
302- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
303- id := gb .outScope .newColumn (col )
304- col .id = id
305-
306- agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
307- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
308- col .scalar = agg
309-
310- gb .addAggStr (col )
311- return col .scalarGf ()
286+ return b .buildCountStarAggregate (e , gb )
312287 }
313288 }
314289
315290 if strings .EqualFold (name , "jsonarray" ) {
316291 // TODO we don't have any tests for this
317292 if _ , ok := e .Exprs [0 ].(* ast.StarExpr ); ok {
318- var agg sql.Aggregation
319- agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
320- b .qFlags .Set (sql .QFlagStar )
321-
322- //if e.Distinct {
323- // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
324- //}
325- aggName := strings .ToLower (agg .String ())
326- gf := gb .getAggRef (aggName )
327- if gf != nil {
328- // if we've already computed use reference here
329- return gf
330- }
331-
332- col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
333- id := gb .outScope .newColumn (col )
334-
335- agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
336- gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
337- col .scalar = agg
338-
339- col .id = id
340- gb .addAggStr (col )
341- return col .scalarGf ()
293+ return b .buildJsonArrayStarAggregate (gb )
342294 }
343295 }
344296
345297 if strings .EqualFold (name , "any_value" ) {
346298 b .qFlags .Set (sql .QFlagAnyAgg )
347299 }
348300
349- var args []sql.Expression
350- for _ , arg := range e .Exprs {
351- e := b .selectExprToExpression (inScope , arg )
352- switch e := e .(type ) {
353- case * expression.GetField :
354- if e .TableId () == 0 {
355- // TODO: not sure where this came from but it's not true
356- // aliases are not valid aggregate arguments, the alias must be masking a column
357- gf := b .selectExprToExpression (inScope .parent , arg )
358- var ok bool
359- e , ok = gf .(* expression.GetField )
360- if ! ok || e .TableId () == 0 {
361- b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
362- }
363- }
364- args = append (args , e )
365- col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
366- gb .addInCol (col )
367- case * expression.Star :
368- err := sql .ErrStarUnsupported .New ()
369- b .handleErr (err )
370- case * plan.Subquery :
371- args = append (args , e )
372- col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
373- gb .addInCol (col )
374- default :
375- args = append (args , e )
376- col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
377- gb .addInCol (col )
378- }
301+ args := b .buildAggFunctionArgs (inScope , e , gb )
302+ agg := b .newAggregation (e , name , args )
303+
304+ if name == "count" {
305+ b .qFlags .Set (sql .QFlagCount )
379306 }
380307
308+ aggType := agg .Type ()
309+ if name == "avg" || name == "sum" {
310+ aggType = types .Float64
311+ }
312+
313+ aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
314+ if id , ok := gb .outScope .getExpr (aggName , true ); ok {
315+ // if we've already computed use reference here
316+ gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
317+ return gf
318+ }
319+
320+ col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
321+ id := gb .outScope .newColumn (col )
322+
323+ agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
324+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
325+ col .scalar = agg
326+
327+ col .id = id
328+ gb .addAggStr (col )
329+ return col .scalarGf ()
330+ }
331+
332+ // newAggregation creates a new aggregation function instanc from the arguments given
333+ func (b * Builder ) newAggregation (e * ast.FuncExpr , name string , args []sql.Expression ) sql.Aggregation {
381334 var agg sql.Aggregation
382335 if e .Distinct && name == "count" {
383336 agg = aggregation .NewCountDistinct (args ... )
384337 } else {
385-
386338 // NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
387339 // errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
388340 if e .Distinct {
@@ -412,39 +364,104 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
412364 b .handleErr (err )
413365 }
414366 }
367+ return agg
368+ }
415369
416- if name == "count" {
417- b .qFlags .Set (sql .QFlagCount )
370+ // buildAggFunctionArgs builds the arguments for an aggregate function
371+ func (b * Builder ) buildAggFunctionArgs (inScope * scope , e * ast.FuncExpr , gb * groupBy ) []sql.Expression {
372+ var args []sql.Expression
373+ for _ , arg := range e .Exprs {
374+ e := b .selectExprToExpression (inScope , arg )
375+ switch e := e .(type ) {
376+ case * expression.GetField :
377+ if e .TableId () == 0 {
378+ // TODO: not sure where this came from but it's not true
379+ // aliases are not valid aggregate arguments, the alias must be masking a column
380+ gf := b .selectExprToExpression (inScope .parent , arg )
381+ var ok bool
382+ e , ok = gf .(* expression.GetField )
383+ if ! ok || e .TableId () == 0 {
384+ b .handleErr (fmt .Errorf ("failed to resolve aggregate column argument: %s" , gf ))
385+ }
386+ }
387+ args = append (args , e )
388+ col := scopeColumn {tableId : e .TableID (), db : e .Database (), table : e .Table (), col : e .Name (), scalar : e , typ : e .Type (), nullable : e .IsNullable ()}
389+ gb .addInCol (col )
390+ case * expression.Star :
391+ err := sql .ErrStarUnsupported .New ()
392+ b .handleErr (err )
393+ case * plan.Subquery :
394+ args = append (args , e )
395+ col := scopeColumn {col : e .QueryString , scalar : e , typ : e .Type ()}
396+ gb .addInCol (col )
397+ default :
398+ args = append (args , e )
399+ col := scopeColumn {col : e .String (), scalar : e , typ : e .Type ()}
400+ gb .addInCol (col )
401+ }
418402 }
403+ return args
404+ }
419405
420- aggType := agg .Type ()
421- if name == "avg" || name == "sum" {
422- aggType = types .Float64
406+ // buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function
407+ func (b * Builder ) buildJsonArrayStarAggregate (gb * groupBy ) sql.Expression {
408+ var agg sql.Aggregation
409+ agg = aggregation .NewJsonArray (expression .NewLiteral (expression .NewStar (), types .Int64 ))
410+ b .qFlags .Set (sql .QFlagStar )
411+
412+ // if e.Distinct {
413+ // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
414+ // }
415+ aggName := strings .ToLower (agg .String ())
416+ gf := gb .getAggRef (aggName )
417+ if gf != nil {
418+ // if we've already computed use reference here
419+ return gf
423420 }
424421
425- aggName := strings .ToLower (plan .AliasSubqueryString (agg ))
426- if id , ok := gb .outScope .getExpr (aggName , true ); ok {
422+ col := scopeColumn {col : strings .ToLower (agg .String ()), scalar : agg , typ : agg .Type (), nullable : agg .IsNullable ()}
423+ id := gb .outScope .newColumn (col )
424+
425+ agg = agg .WithId (sql .ColumnId (id )).(* aggregation.JsonArray )
426+ gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
427+ col .scalar = agg
428+
429+ col .id = id
430+ gb .addAggStr (col )
431+ return col .scalarGf ()
432+ }
433+
434+ // buildCountStarAggregate builds a COUNT(*) aggregate function
435+ func (b * Builder ) buildCountStarAggregate (e * ast.FuncExpr , gb * groupBy ) sql.Expression {
436+ var agg sql.Aggregation
437+ if e .Distinct {
438+ agg = aggregation .NewCountDistinct (expression .NewLiteral (1 , types .Int64 ))
439+ } else {
440+ agg = aggregation .NewCount (expression .NewLiteral (1 , types .Int64 ))
441+ }
442+ b .qFlags .Set (sql .QFlagCountStar )
443+ aggName := strings .ToLower (agg .String ())
444+ gf := gb .getAggRef (aggName )
445+ if gf != nil {
427446 // if we've already computed use reference here
428- gf := expression .NewGetFieldWithTable (int (id ), 0 , aggType , "" , "" , aggName , agg .IsNullable ())
429447 return gf
430448 }
431449
432- col := scopeColumn {col : aggName , scalar : agg , typ : aggType , nullable : agg .IsNullable ()}
450+ col := scopeColumn {col : strings . ToLower ( agg . String ()) , scalar : agg , typ : agg . Type () , nullable : agg .IsNullable ()}
433451 id := gb .outScope .newColumn (col )
452+ col .id = id
434453
435454 agg = agg .WithId (sql .ColumnId (id )).(sql.Aggregation )
436455 gb .outScope .cols [len (gb .outScope .cols )- 1 ].scalar = agg
437456 col .scalar = agg
438457
439- col .id = id
440458 gb .addAggStr (col )
441459 return col .scalarGf ()
442460}
443461
462+ // buildGroupConcat builds a GROUP_CONCAT aggregate function
444463func (b * Builder ) buildGroupConcat (inScope * scope , e * ast.GroupConcatExpr ) sql.Expression {
445- if inScope .groupBy == nil {
446- inScope .initGroupBy ()
447- }
464+ inScope .initGroupBy ()
448465 gb := inScope .groupBy
449466
450467 args := make ([]sql.Expression , len (e .Exprs ))
@@ -794,7 +811,7 @@ func (b *Builder) analyzeHaving(fromScope, projScope *scope, having *ast.Where)
794811 return false , nil
795812 case * ast.FuncExpr :
796813 name := n .Name .Lowered ()
797- if isAggregateFunc (name ) {
814+ if IsAggregateFunc (name ) {
798815 // record aggregate
799816 // TODO: this should get projScope as well
800817 _ = b .buildAggregateFunc (fromScope , name , n )
@@ -874,9 +891,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast
874891 if having == nil {
875892 return
876893 }
877- if fromScope .groupBy == nil {
878- fromScope .initGroupBy ()
879- }
894+ fromScope .initGroupBy ()
880895
881896 havingScope := b .newScope ()
882897 if fromScope .parent != nil {
0 commit comments