diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 662b1f6027..3278015d38 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -95,9 +95,8 @@ func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.Gro // 3) an index into selects // 4) a simple non-aggregate expression groupings := make([]sql.Expression, 0) - if fromScope.groupBy == nil { - fromScope.initGroupBy() - } + fromScope.initGroupBy() + g := fromScope.groupBy for _, e := range groupby { var col scopeColumn @@ -194,9 +193,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s // - grouping cols projection // - aggregate expressions // - output projection - if fromScope.groupBy == nil { - fromScope.initGroupBy() - } + fromScope.initGroupBy() group := fromScope.groupBy outScope := group.outScope @@ -257,7 +254,10 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s return outScope } -func isAggregateFunc(name string) bool { +// IsAggregateFunc is a hacky "extension point" to allow for other dialects to declare additional aggregate functions +var IsAggregateFunc = IsMySQLAggregateFuncName + +func IsMySQLAggregateFuncName(name string) bool { switch name { case "avg", "bit_and", "bit_or", "bit_xor", "count", "group_concat", "json_arrayagg", "json_objectagg", @@ -278,67 +278,19 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp b.handleErr(err) } - if inScope.groupBy == nil { - inScope.initGroupBy() - } + inScope.initGroupBy() gb := inScope.groupBy if strings.EqualFold(name, "count") { if _, ok := e.Exprs[0].(*ast.StarExpr); ok { - var agg sql.Aggregation - if e.Distinct { - agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64)) - } else { - agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64)) - } - b.qFlags.Set(sql.QFlagCountStar) - aggName := strings.ToLower(agg.String()) - gf := gb.getAggRef(aggName) - if 
gf != nil { - // if we've already computed use reference here - return gf - } - - col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()} - id := gb.outScope.newColumn(col) - col.id = id - - agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation) - gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg - col.scalar = agg - - gb.addAggStr(col) - return col.scalarGf() + return b.buildCountStarAggregate(e, gb) } } if strings.EqualFold(name, "jsonarray") { // TODO we don't have any tests for this if _, ok := e.Exprs[0].(*ast.StarExpr); ok { - var agg sql.Aggregation - agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64)) - b.qFlags.Set(sql.QFlagStar) - - //if e.Distinct { - // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64)) - //} - aggName := strings.ToLower(agg.String()) - gf := gb.getAggRef(aggName) - if gf != nil { - // if we've already computed use reference here - return gf - } - - col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()} - id := gb.outScope.newColumn(col) - - agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray) - gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg - col.scalar = agg - - col.id = id - gb.addAggStr(col) - return col.scalarGf() + return b.buildJsonArrayStarAggregate(gb) } } @@ -346,43 +298,43 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp b.qFlags.Set(sql.QFlagAnyAgg) } - var args []sql.Expression - for _, arg := range e.Exprs { - e := b.selectExprToExpression(inScope, arg) - switch e := e.(type) { - case *expression.GetField: - if e.TableId() == 0 { - // TODO: not sure where this came from but it's not true - // aliases are not valid aggregate arguments, the alias must be masking a column - gf := b.selectExprToExpression(inScope.parent, arg) - var ok bool - e, ok = gf.(*expression.GetField) - if !ok || e.TableId() == 0 { - 
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf)) - } - } - args = append(args, e) - col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()} - gb.addInCol(col) - case *expression.Star: - err := sql.ErrStarUnsupported.New() - b.handleErr(err) - case *plan.Subquery: - args = append(args, e) - col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()} - gb.addInCol(col) - default: - args = append(args, e) - col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()} - gb.addInCol(col) - } + args := b.buildAggFunctionArgs(inScope, e, gb) + agg := b.newAggregation(e, name, args) + + if name == "count" { + b.qFlags.Set(sql.QFlagCount) } + aggType := agg.Type() + if name == "avg" || name == "sum" { + aggType = types.Float64 + } + + aggName := strings.ToLower(plan.AliasSubqueryString(agg)) + if id, ok := gb.outScope.getExpr(aggName, true); ok { + // if we've already computed use reference here + gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable()) + return gf + } + + col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()} + id := gb.outScope.newColumn(col) + + agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation) + gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg + col.scalar = agg + + col.id = id + gb.addAggStr(col) + return col.scalarGf() +} + +// newAggregation creates a new aggregation function instance from the arguments given +func (b *Builder) newAggregation(e *ast.FuncExpr, name string, args []sql.Expression) sql.Aggregation { var agg sql.Aggregation if e.Distinct && name == "count" { agg = aggregation.NewCountDistinct(args...) } else { - // NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw // errors for when DISTINCT is used on aggregate functions that don't support DISTINCT. 
if e.Distinct { @@ -412,39 +364,104 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp b.handleErr(err) } } + return agg +} - if name == "count" { - b.qFlags.Set(sql.QFlagCount) +// buildAggFunctionArgs builds the arguments for an aggregate function +func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *groupBy) []sql.Expression { + var args []sql.Expression + for _, arg := range e.Exprs { + e := b.selectExprToExpression(inScope, arg) + switch e := e.(type) { + case *expression.GetField: + if e.TableId() == 0 { + // TODO: not sure where this came from but it's not true + // aliases are not valid aggregate arguments, the alias must be masking a column + gf := b.selectExprToExpression(inScope.parent, arg) + var ok bool + e, ok = gf.(*expression.GetField) + if !ok || e.TableId() == 0 { + b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf)) + } + } + args = append(args, e) + col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()} + gb.addInCol(col) + case *expression.Star: + err := sql.ErrStarUnsupported.New() + b.handleErr(err) + case *plan.Subquery: + args = append(args, e) + col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()} + gb.addInCol(col) + default: + args = append(args, e) + col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()} + gb.addInCol(col) + } } + return args +} - aggType := agg.Type() - if name == "avg" || name == "sum" { - aggType = types.Float64 +// buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function +func (b *Builder) buildJsonArrayStarAggregate(gb *groupBy) sql.Expression { + var agg sql.Aggregation + agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64)) + b.qFlags.Set(sql.QFlagStar) + + // if e.Distinct { + // agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64)) + // } + aggName := 
strings.ToLower(agg.String()) + gf := gb.getAggRef(aggName) + if gf != nil { + // if we've already computed use reference here + return gf } - aggName := strings.ToLower(plan.AliasSubqueryString(agg)) - if id, ok := gb.outScope.getExpr(aggName, true); ok { + col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()} + id := gb.outScope.newColumn(col) + + agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray) + gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg + col.scalar = agg + + col.id = id + gb.addAggStr(col) + return col.scalarGf() +} + +// buildCountStarAggregate builds a COUNT(*) aggregate function +func (b *Builder) buildCountStarAggregate(e *ast.FuncExpr, gb *groupBy) sql.Expression { + var agg sql.Aggregation + if e.Distinct { + agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64)) + } else { + agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64)) + } + b.qFlags.Set(sql.QFlagCountStar) + aggName := strings.ToLower(agg.String()) + gf := gb.getAggRef(aggName) + if gf != nil { // if we've already computed use reference here - gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable()) return gf } - col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()} + col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()} id := gb.outScope.newColumn(col) + col.id = id agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation) gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg col.scalar = agg - col.id = id gb.addAggStr(col) return col.scalarGf() } +// buildGroupConcat builds a GROUP_CONCAT aggregate function func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.Expression { - if inScope.groupBy == nil { - inScope.initGroupBy() - } + inScope.initGroupBy() gb := inScope.groupBy args := make([]sql.Expression, len(e.Exprs)) @@ -794,7 
+811,7 @@ func (b *Builder) analyzeHaving(fromScope, projScope *scope, having *ast.Where) return false, nil case *ast.FuncExpr: name := n.Name.Lowered() - if isAggregateFunc(name) { + if IsAggregateFunc(name) { // record aggregate // TODO: this should get projScope as well _ = b.buildAggregateFunc(fromScope, name, n) @@ -874,9 +891,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast if having == nil { return } - if fromScope.groupBy == nil { - fromScope.initGroupBy() - } + fromScope.initGroupBy() havingScope := b.newScope() if fromScope.parent != nil { diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 5706d54ee1..f4a2e10a29 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -153,7 +153,7 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { return b.buildNameConst(inScope, v) } else if name == "icu_version" { return expression.NewLiteral(icuVersion, types.MustCreateString(query.Type_VARCHAR, int64(len(icuVersion)), sql.Collation_Default)) - } else if isAggregateFunc(name) && v.Over == nil { + } else if IsAggregateFunc(name) && v.Over == nil { // TODO this assumes aggregate is in the same scope // also need to avoid nested aggregates return b.buildAggregateFunc(inScope, name, v) diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index d64868be11..5e941f1dda 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -224,7 +224,9 @@ func (s *scope) initProc() { // initGroupBy creates a container scope for aggregation // functions and function inputs. func (s *scope) initGroupBy() { - s.groupBy = &groupBy{outScope: s.replace()} + if s.groupBy == nil { + s.groupBy = &groupBy{outScope: s.replace()} + } } // pushSubquery creates a new scope with the subquery already initialized. 
diff --git a/sql/planbuilder/show.go b/sql/planbuilder/show.go index 33b8d1bfdc..3669658d00 100644 --- a/sql/planbuilder/show.go +++ b/sql/planbuilder/show.go @@ -614,7 +614,7 @@ func (b *Builder) buildAsOfExpr(inScope *scope, time ast.Expr) sql.Expression { return expression.NewLiteral(v.String(), types.LongText) case *ast.FuncExpr: // todo(max): more specific validation for nested ASOF functions - if isWindowFunc(v.Name.Lowered()) || isAggregateFunc(v.Name.Lowered()) { + if isWindowFunc(v.Name.Lowered()) || IsAggregateFunc(v.Name.Lowered()) { err := sql.ErrInvalidAsOfExpression.New(v) b.handleErr(err) }