Skip to content

Commit 6d94caa

Browse files
authored
Merge pull request #3073 from dolthub/angela/binding
Allow select aliases to be in group by/having
2 parents cf94656 + 580e1ee commit 6d94caa

File tree

5 files changed

+39
-27
lines changed

5 files changed

+39
-27
lines changed

enginetest/queries/queries.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9352,6 +9352,20 @@ from typestable`,
93529352
{"0"}, {"1"}, {"0"}, {"1"},
93539353
},
93549354
},
9355+
// https://github.com/dolthub/dolt/issues/7095
9356+
// References in group by and having should be allowed to match select aliases
9357+
{
9358+
Query: "select y as z from xy group by (y) having AVG(z) > 0",
9359+
Expected: []sql.Row{{1}, {2}, {3}},
9360+
},
9361+
{
9362+
Query: "select y as z from xy group by (z) having AVG(z) > 0",
9363+
Expected: []sql.Row{{1}, {2}, {3}},
9364+
},
9365+
{
9366+
Query: "select y + 1 as z from xy group by (z) having AVG(z) > 1",
9367+
Expected: []sql.Row{{2}, {3}, {4}},
9368+
},
93559369
}
93569370

93579371
var KeylessQueries = []QueryTest{
@@ -9603,20 +9617,6 @@ FROM mytable;`,
96039617
{"DECIMAL"},
96049618
},
96059619
},
9606-
// https://github.com/dolthub/dolt/issues/7095
9607-
// References in group by and having should be allowed to match select aliases
9608-
{
9609-
Query: "select y as z from xy group by (y) having AVG(z) > 0",
9610-
Expected: []sql.Row{{1}, {2}, {3}},
9611-
},
9612-
{
9613-
Query: "select y as z from xy group by (z) having AVG(z) > 0",
9614-
Expected: []sql.Row{{1}, {2}, {3}},
9615-
},
9616-
{
9617-
Query: "select y + 1 as z from xy group by (z) having AVG(z) > 1",
9618-
Expected: []sql.Row{{2}, {3}, {4}},
9619-
},
96209620
}
96219621

96229622
var VersionedQueries = []QueryTest{

sql/planbuilder/aggregates.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,17 +372,14 @@ func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *grou
372372
var args []sql.Expression
373373
for _, arg := range e.Exprs {
374374
e := b.selectExprToExpression(inScope, arg)
375+
// if GetField is an alias, alias must be masking a column
376+
if gf, ok := e.(*expression.GetField); ok && gf.TableId() == 0 {
377+
e = b.selectExprToExpression(inScope.parent, arg)
378+
}
375379
switch e := e.(type) {
376380
case *expression.GetField:
377381
if e.TableId() == 0 {
378-
// TODO: not sure where this came from but it's not true
379-
// aliases are not valid aggregate arguments, the alias must be masking a column
380-
gf := b.selectExprToExpression(inScope.parent, arg)
381-
var ok bool
382-
e, ok = gf.(*expression.GetField)
383-
if !ok || e.TableId() == 0 {
384-
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf))
385-
}
382+
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", e))
386383
}
387384
args = append(args, e)
388385
col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()}
@@ -953,6 +950,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast
953950
havingScope := b.newScope()
954951
if fromScope.parent != nil {
955952
havingScope.parent = fromScope.parent
953+
havingScope.parent.selectAliases = fromScope.selectAliases
956954
}
957955

958956
// add columns from fromScope referenced in the groupBy

sql/planbuilder/project.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se
160160
col.scalar = e
161161
tempScope.addColumn(col)
162162
}
163+
if inScope.selectAliases == nil {
164+
inScope.selectAliases = make(map[string]sql.Expression)
165+
}
166+
inScope.selectAliases[e.Name()] = e
163167
exprs = append(exprs, e)
164168
default:
165169
exprs = append(exprs, pe)

sql/planbuilder/scalar.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,15 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
123123
colName := strings.ToLower(v.Name.String())
124124
c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false)
125125
if !ok {
126+
if aliasedExpr, ok := inScope.selectAliases[colName]; ok {
127+
return aliasedExpr
128+
}
126129
sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None)
127130
if ok {
128131
return sysVar
129132
}
130133
var err error
131-
if scope == ast.SetScope_User {
132-
err = sql.ErrUnknownUserVariable.New(colName)
133-
} else if scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly {
134+
if scope == ast.SetScope_User || scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly {
134135
err = sql.ErrUnknownUserVariable.New(colName)
135136
} else if scope == ast.SetScope_Global || scope == ast.SetScope_Session {
136137
err = sql.ErrUnknownSystemVariable.New(colName)

sql/planbuilder/scope.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ type scope struct {
6161

6262
insertTableAlias string
6363
insertColumnAliases map[string]string
64+
65+
selectAliases map[string]sql.Expression
6466
}
6567

6668
// resolveColumn matches a variable use to a column definition with a unique
@@ -441,6 +443,12 @@ func (s *scope) copy() *scope {
441443
if !s.colset.Empty() {
442444
ret.colset = s.colset.Copy()
443445
}
446+
if s.selectAliases != nil {
447+
ret.selectAliases = make(map[string]sql.Expression, len(s.selectAliases))
448+
for k, v := range s.selectAliases {
449+
ret.selectAliases[k] = v
450+
}
451+
}
444452

445453
return &ret
446454
}
@@ -644,8 +652,9 @@ func (c scopeColumn) withOriginal(origTbl, col string) scopeColumn {
644652
// scalarGf returns a getField reference to this column's expression.
645653
func (c scopeColumn) scalarGf() sql.Expression {
646654
if c.scalar != nil {
647-
if p, ok := c.scalar.(*expression.ProcedureParam); ok {
648-
return p
655+
switch e := c.scalar.(type) {
656+
case *expression.ProcedureParam:
657+
return e
649658
}
650659
}
651660
if c.originalCol != "" {

0 commit comments

Comments
 (0)