Skip to content

Commit c9ae582

Browse files
committed
Refactoring: split up aggregation building
1 parent f063d44 commit c9ae582

File tree

1 file changed

+110
-89
lines changed

1 file changed

+110
-89
lines changed

sql/planbuilder/aggregates.go

Lines changed: 110 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -288,104 +288,58 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
288288

289289
if strings.EqualFold(name, "count") {
290290
if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
291-
var agg sql.Aggregation
292-
if e.Distinct {
293-
agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64))
294-
} else {
295-
agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
296-
}
297-
b.qFlags.Set(sql.QFlagCountStar)
298-
aggName := strings.ToLower(agg.String())
299-
gf := gb.getAggRef(aggName)
300-
if gf != nil {
301-
// if we've already computed use reference here
302-
return gf
303-
}
304-
305-
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
306-
id := gb.outScope.newColumn(col)
307-
col.id = id
308-
309-
agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
310-
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
311-
col.scalar = agg
312-
313-
gb.addAggStr(col)
314-
return col.scalarGf()
291+
return b.buildCountStarAggregate(e, gb)
315292
}
316293
}
317294

318295
if strings.EqualFold(name, "jsonarray") {
319296
// TODO we don't have any tests for this
320297
if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
321-
var agg sql.Aggregation
322-
agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64))
323-
b.qFlags.Set(sql.QFlagStar)
324-
325-
//if e.Distinct {
326-
// agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
327-
//}
328-
aggName := strings.ToLower(agg.String())
329-
gf := gb.getAggRef(aggName)
330-
if gf != nil {
331-
// if we've already computed use reference here
332-
return gf
333-
}
334-
335-
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
336-
id := gb.outScope.newColumn(col)
337-
338-
agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray)
339-
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
340-
col.scalar = agg
341-
342-
col.id = id
343-
gb.addAggStr(col)
344-
return col.scalarGf()
298+
return b.buildJsonArrayStarAggregate(gb)
345299
}
346300
}
347301

348302
if strings.EqualFold(name, "any_value") {
349303
b.qFlags.Set(sql.QFlagAnyAgg)
350304
}
351305

352-
var args []sql.Expression
353-
for _, arg := range e.Exprs {
354-
e := b.selectExprToExpression(inScope, arg)
355-
switch e := e.(type) {
356-
case *expression.GetField:
357-
if e.TableId() == 0 {
358-
// TODO: not sure where this came from but it's not true
359-
// aliases are not valid aggregate arguments, the alias must be masking a column
360-
gf := b.selectExprToExpression(inScope.parent, arg)
361-
var ok bool
362-
e, ok = gf.(*expression.GetField)
363-
if !ok || e.TableId() == 0 {
364-
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf))
365-
}
366-
}
367-
args = append(args, e)
368-
col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()}
369-
gb.addInCol(col)
370-
case *expression.Star:
371-
err := sql.ErrStarUnsupported.New()
372-
b.handleErr(err)
373-
case *plan.Subquery:
374-
args = append(args, e)
375-
col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()}
376-
gb.addInCol(col)
377-
default:
378-
args = append(args, e)
379-
col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()}
380-
gb.addInCol(col)
381-
}
306+
args := b.buildAggFunctionArgs(inScope, e, gb)
307+
agg := b.newAggregation(e, name, args)
308+
309+
if name == "count" {
310+
b.qFlags.Set(sql.QFlagCount)
311+
}
312+
313+
aggType := agg.Type()
314+
if name == "avg" || name == "sum" {
315+
aggType = types.Float64
382316
}
383317

318+
aggName := strings.ToLower(plan.AliasSubqueryString(agg))
319+
if id, ok := gb.outScope.getExpr(aggName, true); ok {
320+
// if we've already computed use reference here
321+
gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable())
322+
return gf
323+
}
324+
325+
col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()}
326+
id := gb.outScope.newColumn(col)
327+
328+
agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
329+
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
330+
col.scalar = agg
331+
332+
col.id = id
333+
gb.addAggStr(col)
334+
return col.scalarGf()
335+
}
336+
337+
// newAggregation creates a new aggregation function instanc from the arguments given
338+
func (b *Builder) newAggregation(e *ast.FuncExpr, name string, args []sql.Expression) sql.Aggregation {
384339
var agg sql.Aggregation
385340
if e.Distinct && name == "count" {
386341
agg = aggregation.NewCountDistinct(args...)
387342
} else {
388-
389343
// NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw
390344
// errors for when DISTINCT is used on aggregate functions that don't support DISTINCT.
391345
if e.Distinct {
@@ -415,35 +369,102 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
415369
b.handleErr(err)
416370
}
417371
}
372+
return agg
373+
}
418374

419-
if name == "count" {
420-
b.qFlags.Set(sql.QFlagCount)
375+
// buildAggFunctionArgs builds the arguments for an aggregate function
376+
func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *groupBy) []sql.Expression {
377+
var args []sql.Expression
378+
for _, arg := range e.Exprs {
379+
e := b.selectExprToExpression(inScope, arg)
380+
switch e := e.(type) {
381+
case *expression.GetField:
382+
if e.TableId() == 0 {
383+
// TODO: not sure where this came from but it's not true
384+
// aliases are not valid aggregate arguments, the alias must be masking a column
385+
gf := b.selectExprToExpression(inScope.parent, arg)
386+
var ok bool
387+
e, ok = gf.(*expression.GetField)
388+
if !ok || e.TableId() == 0 {
389+
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf))
390+
}
391+
}
392+
args = append(args, e)
393+
col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()}
394+
gb.addInCol(col)
395+
case *expression.Star:
396+
err := sql.ErrStarUnsupported.New()
397+
b.handleErr(err)
398+
case *plan.Subquery:
399+
args = append(args, e)
400+
col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()}
401+
gb.addInCol(col)
402+
default:
403+
args = append(args, e)
404+
col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()}
405+
gb.addInCol(col)
406+
}
421407
}
408+
return args
409+
}
422410

423-
aggType := agg.Type()
424-
if name == "avg" || name == "sum" {
425-
aggType = types.Float64
411+
// buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function
412+
func (b *Builder) buildJsonArrayStarAggregate(gb *groupBy) sql.Expression {
413+
var agg sql.Aggregation
414+
agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64))
415+
b.qFlags.Set(sql.QFlagStar)
416+
417+
// if e.Distinct {
418+
// agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
419+
// }
420+
aggName := strings.ToLower(agg.String())
421+
gf := gb.getAggRef(aggName)
422+
if gf != nil {
423+
// if we've already computed use reference here
424+
return gf
426425
}
427426

428-
aggName := strings.ToLower(plan.AliasSubqueryString(agg))
429-
if id, ok := gb.outScope.getExpr(aggName, true); ok {
427+
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
428+
id := gb.outScope.newColumn(col)
429+
430+
agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray)
431+
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
432+
col.scalar = agg
433+
434+
col.id = id
435+
gb.addAggStr(col)
436+
return col.scalarGf()
437+
}
438+
439+
// buildCountStarAggregate builds a COUNT(*) aggregate function
440+
func (b *Builder) buildCountStarAggregate(e *ast.FuncExpr, gb *groupBy) sql.Expression {
441+
var agg sql.Aggregation
442+
if e.Distinct {
443+
agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64))
444+
} else {
445+
agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
446+
}
447+
b.qFlags.Set(sql.QFlagCountStar)
448+
aggName := strings.ToLower(agg.String())
449+
gf := gb.getAggRef(aggName)
450+
if gf != nil {
430451
// if we've already computed use reference here
431-
gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable())
432452
return gf
433453
}
434454

435-
col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()}
455+
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
436456
id := gb.outScope.newColumn(col)
457+
col.id = id
437458

438459
agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
439460
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
440461
col.scalar = agg
441462

442-
col.id = id
443463
gb.addAggStr(col)
444464
return col.scalarGf()
445465
}
446466

467+
// buildGroupConcat builds a GROUP_CONCAT aggregate function
447468
func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.Expression {
448469
if inScope.groupBy == nil {
449470
inScope.initGroupBy()

0 commit comments

Comments
 (0)