Skip to content

Commit 1d50de4

Browse files
committed
renamed SelectedExprs to projectedDeps (SelectedExprs weren't actually the selected exprs)
1 parent d3e7ae8 commit 1d50de4

File tree

7 files changed

+52
-36
lines changed

7 files changed

+52
-36
lines changed

sql/analyzer/replace_count_star.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ func replaceCountStar(ctx *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope,
4040
qFlags.Set(sql.QFlagMax1Row)
4141
}
4242

43-
if len(agg.SelectedExprs) == 1 && len(agg.GroupByExprs) == 0 {
44-
child := agg.SelectedExprs[0]
43+
if len(agg.ProjectedExprs()) == 1 && len(agg.GroupByExprs) == 0 {
44+
child := agg.ProjectedExprs()[0]
4545
var cnt *aggregation.Count
4646
name := ""
4747
if alias, ok := child.(*expression.Alias); ok {

sql/analyzer/replace_sort.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func replaceAgg(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope,
201201
return n, transform.SameTree, nil
202202
}
203203
// TODO: optimize when there are multiple aggregations; use LATERAL JOINS
204-
if len(gb.SelectedExprs) != 1 || len(gb.GroupByExprs) != 0 {
204+
if len(gb.ProjectedExprs()) != 1 || len(gb.GroupByExprs) != 0 {
205205
return n, transform.SameTree, nil
206206
}
207207

@@ -237,7 +237,7 @@ func replaceAgg(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope,
237237

238238
// generate sort fields from aggregations
239239
var sf sql.SortField
240-
switch agg := gb.SelectedExprs[0].(type) {
240+
switch agg := gb.ProjectedExprs()[0].(type) {
241241
case *aggregation.Max:
242242
gf, ok := agg.UnaryExpression.Child.(*expression.GetField)
243243
if !ok {
@@ -268,7 +268,7 @@ func replaceAgg(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope,
268268
}
269269

270270
// replace all aggs in proj.Projections with GetField
271-
name := gb.SelectedExprs[0].String()
271+
name := gb.ProjectedExprs()[0].String()
272272
newProjs, _, err := transform.Exprs(proj.Projections, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
273273
if strings.EqualFold(e.String(), name) {
274274
return sf.Column, transform.NewTree, nil

sql/analyzer/resolve_ctes.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func schemaLength(node sql.Node) int {
3535
schemaLen = len(node.Projections)
3636
return false
3737
case *plan.GroupBy:
38-
schemaLen = len(node.SelectedExprs)
38+
schemaLen = len(node.ProjectedExprs())
3939
return false
4040
case *plan.Window:
4141
schemaLen = len(node.SelectExprs)

sql/analyzer/unnest_insubqueries.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ func getHighestProjection(n sql.Node) (sql.Expression, bool, error) {
306306
// todo(max): could make better effort to get column ids from these,
307307
// but real fix is also giving synthesized projection column ids
308308
// in binder
309-
proj = nn.SelectedExprs
309+
proj = nn.ProjectedExprs()
310310
case *plan.Window:
311311
proj = nn.SelectExprs
312312
case *plan.SetOp:

sql/analyzer/validation_rules.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -251,34 +251,50 @@ func validateGroupBy(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
251251
var err error
252252
//var parent sql.Node
253253
transform.Inspect(n, func(n sql.Node) bool {
254-
//defer func() {
255-
// parent = n
256-
//}()
254+
/*defer func() {
255+
parent = n
256+
}()*/
257257

258258
gb, ok := n.(*plan.GroupBy)
259259
if !ok {
260260
return true
261261
}
262262

263-
//switch parent.(type) {
264-
//case *plan.Having, *plan.Project, *plan.Sort:
265-
// // TODO: these shouldn't be skipped; you can group by primary key without problem b/c only one value
266-
// // https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html#:~:text=The%20query%20is%20valid%20if%20name%20is%20a%20primary%20key
267-
// return true
268-
//}
263+
/*switch parent.(type) {
264+
case *plan.Having, *plan.Project, *plan.Sort:
265+
// TODO: these shouldn't be skipped; you can group by primary key without problem b/c only one value
266+
// https://dev.mysql.com/doc/refman/8.0/en/group-by-handling.html#:~:text=The%20query%20is%20valid%20if%20name%20is%20a%20primary%20key
267+
return true
268+
} */
269269

270270
// Allow the parser use the GroupBy node to eval the aggregation functions
271271
// for sql statements that don't make use of the GROUP BY expression.
272272
if len(gb.GroupByExprs) == 0 {
273273
return true
274274
}
275275

276+
primaryKeys := make(map[string]bool)
277+
for _, col := range gb.Child.Schema() {
278+
if col.PrimaryKey {
279+
primaryKeys[strings.ToLower(col.Name)] = true
280+
}
281+
}
282+
276283
var groupBys []string
284+
groupByPrimaryKeys := 0
277285
for _, expr := range gb.GroupByExprs {
278286
groupBys = append(groupBys, expr.String())
287+
if primaryKeys[strings.ToLower(expr.String())] {
288+
groupByPrimaryKeys++
289+
}
290+
291+
}
292+
293+
if len(primaryKeys) != 0 && groupByPrimaryKeys == len(primaryKeys) {
294+
return true
279295
}
280296

281-
for _, expr := range gb.SelectedExprs {
297+
for _, expr := range gb.ProjectedExprs() {
282298
if _, ok := expr.(sql.Aggregation); !ok {
283299
if !expressionReferencesOnlyGroupBys(groupBys, expr) {
284300
err = analyzererrors.ErrValidationGroupBy.New(expr.String())

sql/plan/group_by.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ var ErrGroupBy = errors.NewKind("group by aggregation '%v' not supported")
3030
// GroupBy groups the rows by some expressions.
3131
type GroupBy struct {
3232
UnaryNode
33-
SelectedExprs []sql.Expression
33+
projectedDeps []sql.Expression
3434
GroupByExprs []sql.Expression
3535
}
3636

@@ -46,15 +46,15 @@ var _ sql.CollationCoercible = (*GroupBy)(nil)
4646
func NewGroupBy(selectedExprs, groupByExprs []sql.Expression, child sql.Node) *GroupBy {
4747
return &GroupBy{
4848
UnaryNode: UnaryNode{Child: child},
49-
SelectedExprs: selectedExprs,
49+
projectedDeps: selectedExprs,
5050
GroupByExprs: groupByExprs,
5151
}
5252
}
5353

5454
// Resolved implements the Resolvable interface.
5555
func (g *GroupBy) Resolved() bool {
5656
return g.UnaryNode.Child.Resolved() &&
57-
expression.ExpressionsResolved(g.SelectedExprs...) &&
57+
expression.ExpressionsResolved(g.projectedDeps...) &&
5858
expression.ExpressionsResolved(g.GroupByExprs...)
5959
}
6060

@@ -64,8 +64,8 @@ func (g *GroupBy) IsReadOnly() bool {
6464

6565
// Schema implements the Node interface.
6666
func (g *GroupBy) Schema() sql.Schema {
67-
var s = make(sql.Schema, len(g.SelectedExprs))
68-
for i, e := range g.SelectedExprs {
67+
var s = make(sql.Schema, len(g.projectedDeps))
68+
for i, e := range g.projectedDeps {
6969
var name string
7070
if n, ok := e.(sql.Nameable); ok {
7171
name = n.Name()
@@ -101,7 +101,7 @@ func (g *GroupBy) WithChildren(children ...sql.Node) (sql.Node, error) {
101101
return nil, sql.ErrInvalidChildrenNumber.New(g, len(children), 1)
102102
}
103103

104-
return NewGroupBy(g.SelectedExprs, g.GroupByExprs, children[0]), nil
104+
return NewGroupBy(g.projectedDeps, g.GroupByExprs, children[0]), nil
105105
}
106106

107107
// CollationCoercibility implements the interface sql.CollationCoercible.
@@ -111,16 +111,16 @@ func (g *GroupBy) CollationCoercibility(ctx *sql.Context) (collation sql.Collati
111111

112112
// WithExpressions implements the Node interface.
113113
func (g *GroupBy) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
114-
expected := len(g.SelectedExprs) + len(g.GroupByExprs)
114+
expected := len(g.projectedDeps) + len(g.GroupByExprs)
115115
if len(exprs) != expected {
116116
return nil, sql.ErrInvalidChildrenNumber.New(g, len(exprs), expected)
117117
}
118118

119-
agg := make([]sql.Expression, len(g.SelectedExprs))
120-
copy(agg, exprs[:len(g.SelectedExprs)])
119+
agg := make([]sql.Expression, len(g.projectedDeps))
120+
copy(agg, exprs[:len(g.projectedDeps)])
121121

122122
grouping := make([]sql.Expression, len(g.GroupByExprs))
123-
copy(grouping, exprs[len(g.SelectedExprs):])
123+
copy(grouping, exprs[len(g.projectedDeps):])
124124

125125
return NewGroupBy(agg, grouping, g.Child), nil
126126
}
@@ -129,8 +129,8 @@ func (g *GroupBy) String() string {
129129
pr := sql.NewTreePrinter()
130130
_ = pr.WriteNode("GroupBy")
131131

132-
var selectedExprs = make([]string, len(g.SelectedExprs))
133-
for i, e := range g.SelectedExprs {
132+
var selectedExprs = make([]string, len(g.projectedDeps))
133+
for i, e := range g.projectedDeps {
134134
selectedExprs[i] = e.String()
135135
}
136136

@@ -151,8 +151,8 @@ func (g *GroupBy) DebugString() string {
151151
pr := sql.NewTreePrinter()
152152
_ = pr.WriteNode("GroupBy")
153153

154-
var selectedExprs = make([]string, len(g.SelectedExprs))
155-
for i, e := range g.SelectedExprs {
154+
var selectedExprs = make([]string, len(g.projectedDeps))
155+
for i, e := range g.projectedDeps {
156156
selectedExprs[i] = sql.DebugString(e)
157157
}
158158

@@ -172,12 +172,12 @@ func (g *GroupBy) DebugString() string {
172172
// Expressions implements the Expressioner interface.
173173
func (g *GroupBy) Expressions() []sql.Expression {
174174
var exprs []sql.Expression
175-
exprs = append(exprs, g.SelectedExprs...)
175+
exprs = append(exprs, g.projectedDeps...)
176176
exprs = append(exprs, g.GroupByExprs...)
177177
return exprs
178178
}
179179

180180
// ProjectedExprs implements the sql.Projector interface
181181
func (g *GroupBy) ProjectedExprs() []sql.Expression {
182-
return g.SelectedExprs
182+
return g.projectedDeps
183183
}

sql/rowexec/rel.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ func (b *BaseBuilder) buildSet(ctx *sql.Context, n *plan.Set, row sql.Row) (sql.
394394
func (b *BaseBuilder) buildGroupBy(ctx *sql.Context, n *plan.GroupBy, row sql.Row) (sql.RowIter, error) {
395395
span, ctx := ctx.Span("plan.GroupBy", trace.WithAttributes(
396396
attribute.Int("groupings", len(n.GroupByExprs)),
397-
attribute.Int("aggregates", len(n.SelectedExprs)),
397+
attribute.Int("aggregates", len(n.ProjectedExprs())),
398398
))
399399

400400
i, err := b.buildNodeExec(ctx, n.Child, row)
@@ -405,9 +405,9 @@ func (b *BaseBuilder) buildGroupBy(ctx *sql.Context, n *plan.GroupBy, row sql.Ro
405405

406406
var iter sql.RowIter
407407
if len(n.GroupByExprs) == 0 {
408-
iter = newGroupByIter(n.SelectedExprs, i)
408+
iter = newGroupByIter(n.ProjectedExprs(), i)
409409
} else {
410-
iter = newGroupByGroupingIter(ctx, n.SelectedExprs, n.GroupByExprs, i)
410+
iter = newGroupByGroupingIter(ctx, n.ProjectedExprs(), n.GroupByExprs, i)
411411
}
412412

413413
return sql.NewSpanIter(span, iter), nil

0 commit comments

Comments
 (0)