From 2ae3d4763ccef887367a1616d2e47948387135b0 Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Mon, 7 Jul 2025 17:31:12 -0700 Subject: [PATCH 1/6] allow select aliases to be in group by/having --- sql/planbuilder/project.go | 6 +++++- sql/planbuilder/scalar.go | 8 +++++--- sql/planbuilder/scope.go | 9 +++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index 898273d714..b833d51799 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -135,7 +135,7 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se err := sql.ErrColumnNotFound.New(gf.String()) b.handleErr(err) } - col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), scalar: e, typ: gf.Type(), nullable: gf.IsNullable()} + col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), typ: gf.Type(), nullable: gf.IsNullable(), originalCol: gf.Name()} } else if sq, ok := e.Child.(*plan.Subquery); ok { col = scopeColumn{col: e.Name(), scalar: e, typ: sq.Type(), nullable: sq.IsNullable()} } else { @@ -151,6 +151,10 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se col.scalar = e tempScope.addColumn(col) } + if inScope.selectColumnAliases == nil { + inScope.selectColumnAliases = make(map[string]scopeColumn) + } + inScope.selectColumnAliases[e.Name()] = col exprs = append(exprs, e) default: exprs = append(exprs, pe) diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 60a94d94b8..889d3f9410 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -123,14 +123,16 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { colName := strings.ToLower(v.Name.String()) c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false) if !ok { + alias, ok := inScope.selectColumnAliases[colName] + if ok { + return alias.scalarGf() + } sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None) if ok { return sysVar } var err error - if scope == ast.SetScope_User { - err = sql.ErrUnknownUserVariable.New(colName) - } else if scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly { + if scope == ast.SetScope_User || scope == ast.SetScope_Persist || scope == ast.SetScope_PersistOnly { err = sql.ErrUnknownUserVariable.New(colName) } else if scope == ast.SetScope_Global || scope == ast.SetScope_Session { err = sql.ErrUnknownSystemVariable.New(colName) diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 5e941f1dda..3b14b5118c 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -61,6 +61,8 @@ type scope struct { insertTableAlias string insertColumnAliases map[string]string + + selectColumnAliases map[string]scopeColumn } // resolveColumn matches a variable use to a column definition with a unique @@ -644,8 +646,11 @@ func (c scopeColumn) withOriginal(origTbl, col string) scopeColumn { // scalarGf returns a getField reference to this column's expression. func (c scopeColumn) scalarGf() sql.Expression { if c.scalar != nil { - if p, ok := c.scalar.(*expression.ProcedureParam); ok { - return p + switch e := c.scalar.(type) { + case *expression.ProcedureParam: + return e + case *expression.Alias: + return e.Child } } if c.originalCol != "" { From 0a910f3aa9016879a2924aa587c79b96b861c167 Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Mon, 7 Jul 2025 17:45:39 -0700 Subject: [PATCH 2/6] no longer panics but fails a bunch of tests --- sql/planbuilder/project.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index b833d51799..42f5378286 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -135,7 +135,7 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se err := sql.ErrColumnNotFound.New(gf.String()) b.handleErr(err) } - col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), typ: gf.Type(), nullable: gf.IsNullable(), originalCol: gf.Name()} + col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), typ: gf.Type(), scalar: e, nullable: gf.IsNullable(), originalCol: gf.Name()} } else if sq, ok := e.Child.(*plan.Subquery); ok { col = scopeColumn{col: e.Name(), scalar: e, typ: sq.Type(), nullable: sq.IsNullable()} } else { From 7ea932a326fd3934500179fc3796340154d317e2 Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Tue, 8 Jul 2025 13:17:31 -0700 Subject: [PATCH 3/6] passes most tests --- sql/planbuilder/aggregates.go | 20 ++++++++++---------- sql/planbuilder/project.go | 2 +- sql/planbuilder/scalar.go | 2 +- sql/planbuilder/scope.go | 8 +++++++- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 110abb8a11..4e2aa5c47f 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -374,16 +374,16 @@ func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *grou e := b.selectExprToExpression(inScope, arg) switch e := e.(type) { case *expression.GetField: - if e.TableId() == 0 { - // TODO: not sure where this came from but it's not true - // aliases are not valid aggregate arguments, the alias must be masking a column - gf := b.selectExprToExpression(inScope.parent, arg) - var ok bool - e, ok = gf.(*expression.GetField) - if !ok || e.TableId() == 0 { - b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf)) - } - } + //if e.TableId() == 0 { + // // TODO: not sure where this came from but it's not true + // // aliases are not valid aggregate arguments, the alias must be masking a column + // gf := b.selectExprToExpression(inScope.parent, arg) + // // var ok bool + // e, ok := gf.(*expression.GetField) + // if !ok || e.TableId() == 0 { + // b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf)) + // } + //} args = append(args, e) col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()} gb.addInCol(col) diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index 42f5378286..6596b65710 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -135,7 +135,7 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se err := sql.ErrColumnNotFound.New(gf.String()) b.handleErr(err) } - col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), typ: gf.Type(), scalar: e, nullable: gf.IsNullable(), originalCol: gf.Name()} + col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), typ: gf.Type(), scalar: e, nullable: gf.IsNullable()} } else if sq, ok := e.Child.(*plan.Subquery); ok { col = scopeColumn{col: e.Name(), scalar: e, typ: sq.Type(), nullable: sq.IsNullable()} } else { diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 889d3f9410..43d180b0ae 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -125,7 +125,7 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { if !ok { alias, ok := inScope.selectColumnAliases[colName] if ok { - return alias.scalarGf() + return alias.scalar } sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None) if ok { diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 3b14b5118c..5dddb765e9 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -443,6 +443,12 @@ func (s *scope) copy() *scope { if !s.colset.Empty() { ret.colset = s.colset.Copy() } + if s.selectColumnAliases != nil { + ret.selectColumnAliases = make(map[string]scopeColumn, len(s.selectColumnAliases)) + for k, v := range s.selectColumnAliases { + ret.selectColumnAliases[k] = v + } + } return &ret } @@ -650,7 +656,7 @@ func (c scopeColumn) scalarGf() sql.Expression { case *expression.ProcedureParam: return e case *expression.Alias: - return e.Child + return e } } if c.originalCol != "" { From dbe6a427863159607888606282f0a7cc30c61645 Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Wed, 13 Aug 2025 15:40:12 -0700 Subject: [PATCH 4/6] pass all tests --- sql/planbuilder/aggregates.go | 17 +++++++---------- sql/planbuilder/project.go | 2 +- sql/planbuilder/scalar.go | 3 +-- sql/planbuilder/scope.go | 2 -- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 4e2aa5c47f..001fa4a729 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -372,18 +372,14 @@ func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *grou var args []sql.Expression for _, arg := range e.Exprs { e := b.selectExprToExpression(inScope, arg) + if gf, ok := e.(*expression.GetField); ok && gf.TableId() == 0 { + e = b.selectExprToExpression(inScope.parent, arg) + } switch e := e.(type) { case *expression.GetField: - //if e.TableId() == 0 { - // // TODO: not sure where this came from but it's not true - // // aliases are not valid aggregate arguments, the alias must be masking a column - // gf := b.selectExprToExpression(inScope.parent, arg) - // // var ok bool - // e, ok := gf.(*expression.GetField) - // if !ok || e.TableId() == 0 { - // b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf)) - // } - //} + if e.TableId() == 0 { + b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", e)) + } args = append(args, e) col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()} gb.addInCol(col) @@ -953,6 +949,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast havingScope := b.newScope() if fromScope.parent != nil { havingScope.parent = fromScope.parent + havingScope.parent.selectColumnAliases = fromScope.selectColumnAliases } // add columns from fromScope referenced in the groupBy diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index a491f3282d..9b9753bf96 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -144,7 +144,7 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se err := sql.ErrColumnNotFound.New(gf.String()) b.handleErr(err) } - col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), typ: gf.Type(), scalar: e, nullable: gf.IsNullable()} + col = scopeColumn{id: id, tableId: gf.TableId(), col: e.Name(), db: gf.Database(), table: gf.Table(), scalar: e, typ: gf.Type(), nullable: gf.IsNullable()} } else if sq, ok := e.Child.(*plan.Subquery); ok { col = scopeColumn{col: e.Name(), scalar: e, typ: sq.Type(), nullable: sq.IsNullable()} } else { diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 47944d54f2..2954e2f5ab 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -123,8 +123,7 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { colName := strings.ToLower(v.Name.String()) c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false) if !ok { - alias, ok := inScope.selectColumnAliases[colName] - if ok { + if alias, ok := inScope.selectColumnAliases[colName]; ok { return alias.scalar } sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None) diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 5dddb765e9..33bcdfb66f 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -655,8 +655,6 @@ func (c scopeColumn) scalarGf() sql.Expression { switch e := c.scalar.(type) { case *expression.ProcedureParam: return e - case *expression.Alias: - return e } } if c.originalCol != "" { From 1832a13c91d958ba5c0b7bd89481fb3d55b993da Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Wed, 13 Aug 2025 15:52:04 -0700 Subject: [PATCH 5/6] unskip tests --- enginetest/queries/queries.go | 28 ++++++++++++++-------------- sql/planbuilder/aggregates.go | 1 + 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index f7a772fb86..986b4b5579 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -9352,6 +9352,20 @@ from typestable`, {"0"}, {"1"}, {"0"}, {"1"}, }, }, + // https://github.com/dolthub/dolt/issues/7095 + // References in group by and having should be allowed to match select aliases + { + Query: "select y as z from xy group by (y) having AVG(z) > 0", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + Query: "select y as z from xy group by (z) having AVG(z) > 0", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + Query: "select y + 1 as z from xy group by (z) having AVG(z) > 1", + Expected: []sql.Row{{2}, {3}, {4}}, + }, } var KeylessQueries = []QueryTest{ @@ -9603,20 +9617,6 @@ FROM mytable;`, {"DECIMAL"}, }, }, - // https://github.com/dolthub/dolt/issues/7095 - // References in group by and having should be allowed to match select aliases - { - Query: "select y as z from xy group by (y) having AVG(z) > 0", - Expected: []sql.Row{{1}, {2}, {3}}, - }, - { - Query: "select y as z from xy group by (z) having AVG(z) > 0", - Expected: []sql.Row{{1}, {2}, {3}}, - }, - { - Query: "select y + 1 as z from xy group by (z) having AVG(z) > 1", - Expected: []sql.Row{{2}, {3}, {4}}, - }, } var VersionedQueries = []QueryTest{ diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 001fa4a729..47ebbbf2bc 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -372,6 +372,7 @@ func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *grou var args []sql.Expression for _, arg := range e.Exprs { e := b.selectExprToExpression(inScope, arg) + // if GetField is an alias, alias must be masking a column if gf, ok := e.(*expression.GetField); ok && gf.TableId() == 0 { e = b.selectExprToExpression(inScope.parent, arg) } From 580e1ee00f7438fe643d355cf6f7a70ab14f9df3 Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Wed, 13 Aug 2025 15:58:32 -0700 Subject: [PATCH 6/6] only store scalar expression instead of entire column --- sql/planbuilder/aggregates.go | 2 +- sql/planbuilder/project.go | 6 +++--- sql/planbuilder/scalar.go | 4 ++-- sql/planbuilder/scope.go | 10 +++++----- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/planbuilder/aggregates.go b/sql/planbuilder/aggregates.go index 47ebbbf2bc..2f7395dfa0 100644 --- a/sql/planbuilder/aggregates.go +++ b/sql/planbuilder/aggregates.go @@ -950,7 +950,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast havingScope := b.newScope() if fromScope.parent != nil { havingScope.parent = fromScope.parent - havingScope.parent.selectColumnAliases = fromScope.selectColumnAliases + havingScope.parent.selectAliases = fromScope.selectAliases } // add columns from fromScope referenced in the groupBy diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index 9b9753bf96..b9af043ae7 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -160,10 +160,10 @@ func (b *Builder) analyzeSelectList(inScope, outScope *scope, selectExprs ast.Se col.scalar = e tempScope.addColumn(col) } - if inScope.selectColumnAliases == nil { - inScope.selectColumnAliases = make(map[string]scopeColumn) + if inScope.selectAliases == nil { + inScope.selectAliases = make(map[string]sql.Expression) } - inScope.selectColumnAliases[e.Name()] = col + inScope.selectAliases[e.Name()] = e exprs = append(exprs, e) default: exprs = append(exprs, pe) diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 2954e2f5ab..f62eed18a5 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -123,8 +123,8 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { colName := strings.ToLower(v.Name.String()) c, ok := inScope.resolveColumn(dbName, tblName, colName, true, false) if !ok { - if alias, ok := inScope.selectColumnAliases[colName]; ok { - return alias.scalar + if aliasedExpr, ok := inScope.selectAliases[colName]; ok { + return aliasedExpr } sysVar, scope, ok := b.buildSysVar(v, ast.SetScope_None) if ok { diff --git a/sql/planbuilder/scope.go b/sql/planbuilder/scope.go index 33bcdfb66f..f9ad2360fa 100644 --- a/sql/planbuilder/scope.go +++ b/sql/planbuilder/scope.go @@ -62,7 +62,7 @@ type scope struct { insertTableAlias string insertColumnAliases map[string]string - selectColumnAliases map[string]scopeColumn + selectAliases map[string]sql.Expression } // resolveColumn matches a variable use to a column definition with a unique @@ -443,10 +443,10 @@ func (s *scope) copy() *scope { if !s.colset.Empty() { ret.colset = s.colset.Copy() } - if s.selectColumnAliases != nil { - ret.selectColumnAliases = make(map[string]scopeColumn, len(s.selectColumnAliases)) - for k, v := range s.selectColumnAliases { - ret.selectColumnAliases[k] = v + if s.selectAliases != nil { + ret.selectAliases = make(map[string]sql.Expression, len(s.selectAliases)) + for k, v := range s.selectAliases { + ret.selectAliases[k] = v } }