diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index f7a772fb86..986b4b5579 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -9352,6 +9352,20 @@ from typestable`, {"0"}, {"1"}, {"0"}, {"1"}, }, }, + // https://github.com/dolthub/dolt/issues/7095 + // References in group by and having should be allowed to match select aliases + { + Query: "select y as z from xy group by (y) having AVG(z) > 0", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + Query: "select y as z from xy group by (z) having AVG(z) > 0", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + Query: "select y + 1 as z from xy group by (z) having AVG(z) > 1", + Expected: []sql.Row{{2}, {3}, {4}}, + }, } var KeylessQueries = []QueryTest{ @@ -9603,20 +9617,6 @@ FROM mytable;`, {"DECIMAL"}, }, }, - // https://github.com/dolthub/dolt/issues/7095 - // References in group by and having should be allowed to match select aliases - { - Query: "select y as z from xy group by (y) having AVG(z) > 0", - Expected: []sql.Row{{1}, {2}, {3}}, - }, - { - Query: "select y as z from xy group by (z) having AVG(z) > 0", - Expected: []sql.Row{{1}, {2}, {3}}, - }, - { - Query: "select y + 1 as z from xy group by (z) having AVG(z) > 1", - Expected: []sql.Row{{2}, {3}, {4}}, - }, } var VersionedQueries = []QueryTest{ diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 110abb8a11..2f7395dfa0 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -372,17 +372,14 @@ func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *grou var args []sql.Expression for _, arg := range e.Exprs { e := b.selectExprToExpression(inScope, arg) + // if GetField is an alias, alias must be masking a column + if gf, ok := e.(*expression.GetField); ok && gf.TableId() == 0 { + e = b.selectExprToExpression(inScope.parent, arg) + } switch e := e.(type) { case *expression.GetField: if e.TableId() == 0 { - // TODO: not sure where this came from but it's not true - // aliases are not valid aggregate arguments, the alias must be masking a column - gf := b.selectExprToExpression(inScope.parent, arg) - var ok bool - e, ok = gf.(*expression.GetField) - if !ok || e.TableId() == 0 { - b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf)) - } + b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", e)) } args = append(args, e) col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()} @@ -953,6 +950,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast havingScope := b.newScope() if fromScope.parent != nil { havingScope.parent = fromScope.parent + havingScope.parent.selectAliases = fromScope.selectAliases } // add columns from fromScope referenced in the groupBy diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index e80234ce89..b9af043ae7 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -160,6 +160,10 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se col.scalar = e tempScope.addColumn(col) } + if inScope.selectAliases == nil { + inScope.selectAliases = make(map[string]sql.Expression) + } + inScope.selectAliases[e.Name()] = e exprs = append(exprs, e) default: exprs = append(exprs, pe) diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 88b3715f29..f62eed18a5 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -123,14 +123,15 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { colName := strings.ToLower(v.Name.String()) c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false) if !ok { + if aliasedExpr, ok := inScope.selectAliases[colName]; ok { + return aliasedExpr + } sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None) if ok { return sysVar } var err error - if scope == ast.SetScope_User { - err = sql.ErrUnknownUserVariable.New(colName) - } else if scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly { + if scope == ast.SetScope_User || scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly { err = sql.ErrUnknownUserVariable.New(colName) } else if scope == ast.SetScope_Global || scope == ast.SetScope_Session { err = sql.ErrUnknownSystemVariable.New(colName) diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 5e941f1dda..f9ad2360fa 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -61,6 +61,8 @@ type scope struct { insertTableAlias string insertColumnAliases map[string]string + + selectAliases map[string]sql.Expression } // resolveColumn matches a variable use to a column definition with a unique @@ -441,6 +443,12 @@ func (s *scope) copy() *scope { if !s.colset.Empty() { ret.colset = s.colset.Copy() } + if s.selectAliases != nil { + ret.selectAliases = make(map[string]sql.Expression, len(s.selectAliases)) + for k, v := range s.selectAliases { + ret.selectAliases[k] = v + } + } return &ret } @@ -644,8 +652,9 @@ func (c scopeColumn) withOriginal(origTbl, col string) scopeColumn { // scalarGf returns a getField reference to this column's expression. func (c scopeColumn) scalarGf() sql.Expression { if c.scalar != nil { - if p, ok := c.scalar.(*expression.ProcedureParam); ok { - return p + switch e := c.scalar.(type) { + case *expression.ProcedureParam: + return e } } if c.originalCol != "" {