Skip to content

Commit 7c1a82b

Browse files
committed
Track scope for CTEs to fix ambiguous column bug
1 parent b807fe9 commit 7c1a82b

File tree

13 files changed

+370
-57
lines changed

13 files changed

+370
-57
lines changed

internal/compiler/analyze.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
153153
if err := check(validate.In(c.catalog, raw)); err != nil {
154154
return nil, err
155155
}
156-
rvs := rangeVars(raw.Stmt)
156+
scopedRVs := rangeVarsWithScope(raw.Stmt)
157157
refs, errs := findParameters(raw.Stmt)
158158
if len(errs) > 0 {
159159
if failfast {
@@ -173,7 +173,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
173173
return nil, err
174174
}
175175

176-
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
176+
params, err := c.resolveCatalogRefs(qc, scopedRVs, refs, namedParams, embeds)
177177
if err := check(err); err != nil {
178178
return nil, err
179179
}

internal/compiler/find_params.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type paramRef struct {
2424
rv *ast.RangeVar
2525
ref *ast.ParamRef
2626
name string // Named parameter support
27+
28+
cteName *string // Current CTE name, nil if not inside a CTE.
2729
}
2830

2931
type paramSearch struct {
@@ -36,6 +38,8 @@ type paramSearch struct {
3638
// XXX: Gross state hack for limit
3739
limitCount ast.Node
3840
limitOffset ast.Node
41+
42+
cteName *string // Current CTE name, nil if not inside a CTE.
3943
}
4044

4145
type limitCount struct {
@@ -55,6 +59,10 @@ func (l *limitOffset) Pos() int {
5559
func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
5660
switch n := node.(type) {
5761

62+
case *ast.CommonTableExpr:
63+
p.cteName = n.Ctename
64+
return p
65+
5866
case *ast.A_Expr:
5967
p.parent = node
6068

@@ -87,7 +95,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
8795
*p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
8896
return p
8997
}
90-
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
98+
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, cteName: p.cteName})
9199
p.seen[ref.Location] = struct{}{}
92100
}
93101
for _, item := range s.ValuesLists.Items {
@@ -104,7 +112,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
104112
*p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
105113
return p
106114
}
107-
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
115+
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, cteName: p.cteName})
108116
p.seen[ref.Location] = struct{}{}
109117
}
110118
}
@@ -125,7 +133,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
125133
if !ok {
126134
continue
127135
}
128-
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
136+
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, cteName: p.cteName})
129137
}
130138
p.seen[ref.Location] = struct{}{}
131139
}
@@ -186,7 +194,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
186194
}
187195

188196
if set {
189-
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
197+
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar, cteName: p.cteName})
190198
p.seen[n.Location] = struct{}{}
191199
}
192200
return nil

internal/compiler/output_columns.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
2222
return nil, err
2323
}
2424

25+
return convertColumnsToCatalog(cols), nil
26+
}
27+
28+
func convertColumnsToCatalog(cols []*Column) []*catalog.Column {
2529
catCols := make([]*catalog.Column, 0, len(cols))
2630
for _, col := range cols {
2731
catCols = append(catCols, &catalog.Column{
@@ -35,7 +39,7 @@ func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
3539
Length: col.Length,
3640
})
3741
}
38-
return catCols, nil
42+
return catCols
3943
}
4044

4145
func hasStarRef(cf *ast.ColumnRef) bool {

internal/compiler/parse.go

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,36 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
132132
}, nil
133133
}
134134

135-
func rangeVars(root ast.Node) []*ast.RangeVar {
136-
var vars []*ast.RangeVar
137-
find := astutils.VisitorFunc(func(node ast.Node) {
138-
switch n := node.(type) {
139-
case *ast.RangeVar:
140-
vars = append(vars, n)
141-
}
142-
})
143-
astutils.Walk(find, root)
144-
return vars
135+
// scopedRangeVar associates a RangeVar with a scope.
136+
type scopedRangeVar struct {
137+
rv *ast.RangeVar
138+
139+
cteName *string // Current CTE name, nil if not inside a CTE.
140+
}
141+
142+
// rangeVarsWithScope collects all RangeVars with their scope.
143+
func rangeVarsWithScope(root ast.Node) []scopedRangeVar {
144+
var rvs []scopedRangeVar
145+
visitor := &rvSearch{rvs: &rvs, cteName: nil}
146+
astutils.Walk(visitor, root)
147+
return rvs
148+
}
149+
150+
// rvSearch finds all RangeVars and tracks their scope.
151+
type rvSearch struct {
152+
rvs *[]scopedRangeVar
153+
154+
cteName *string // Current CTE name, nil if not inside a CTE.
155+
}
156+
157+
func (v *rvSearch) Visit(node ast.Node) astutils.Visitor {
158+
switch n := node.(type) {
159+
case *ast.CommonTableExpr:
160+
return &rvSearch{rvs: v.rvs, cteName: n.Ctename}
161+
case *ast.RangeVar:
162+
*v.rvs = append(*v.rvs, scopedRangeVar{rv: n, cteName: v.cteName})
163+
}
164+
return v
145165
}
146166

147167
func uniqueParamRefs(in []paramRef, dollar bool) []paramRef {

internal/compiler/resolve.go

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,63 @@ func dataType(n *ast.TypeName) string {
2121
}
2222
}
2323

24-
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
24+
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, scopedRVs []scopedRangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
2525
c := comp.catalog
2626

27+
scopeMap := make(map[*string][]*ast.TableName)
28+
outerAliasMap := map[string]*ast.TableName{}
2729
aliasMap := map[string]*ast.TableName{}
30+
tableNameMap := map[string]*ast.TableName{}
2831
// TODO: Deprecate defaultTable
2932
var defaultTable *ast.TableName
3033
var tables []*ast.TableName
31-
3234
typeMap := map[string]map[string]map[string]*catalog.Column{}
33-
indexTable := func(table catalog.Table) error {
34-
tables = append(tables, table.Rel)
35+
36+
indexTableWithColumns := func(rel *ast.TableName, cols []*catalog.Column) error {
37+
tables = append(tables, rel)
38+
tableNameMap[rel.Name] = rel
3539
if defaultTable == nil {
36-
defaultTable = table.Rel
40+
defaultTable = rel
3741
}
38-
schema := table.Rel.Schema
42+
schema := rel.Schema
3943
if schema == "" {
4044
schema = c.DefaultSchema
4145
}
4246
if _, exists := typeMap[schema]; !exists {
4347
typeMap[schema] = map[string]map[string]*catalog.Column{}
4448
}
45-
typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{}
46-
for _, c := range table.Columns {
47-
cc := c
48-
typeMap[schema][table.Rel.Name][c.Name] = cc
49+
typeMap[schema][rel.Name] = map[string]*catalog.Column{}
50+
for _, c := range cols {
51+
typeMap[schema][rel.Name][c.Name] = c
4952
}
5053
return nil
5154
}
5255

53-
for _, rv := range rvs {
56+
indexTable := func(table catalog.Table) error {
57+
return indexTableWithColumns(table.Rel, table.Columns)
58+
}
59+
60+
indexCTE := func(cte *Table) error {
61+
catalogCols := convertColumnsToCatalog(cte.Columns)
62+
return indexTableWithColumns(cte.Rel, catalogCols)
63+
}
64+
65+
for _, srv := range scopedRVs {
66+
rv := srv.rv
67+
scope := srv.cteName
5468
if rv.Relname == nil {
5569
continue
5670
}
5771
fqn, err := ParseTableName(rv)
5872
if err != nil {
5973
return nil, err
6074
}
75+
76+
scopeMap[scope] = append(scopeMap[scope], fqn)
77+
if scope == nil && rv.Alias != nil {
78+
outerAliasMap[*rv.Alias.Aliasname] = fqn
79+
}
80+
6181
if _, found := aliasMap[fqn.Name]; found {
6282
continue
6383
}
@@ -67,9 +87,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
6787
continue
6888
}
6989
// If the table name doesn't exist, first check if it's a CTE
70-
if _, qcerr := qc.GetTable(fqn); qcerr != nil {
90+
cte, qcerr := qc.GetTable(fqn)
91+
if qcerr != nil {
7192
return nil, err
7293
}
94+
if err := indexCTE(cte); err != nil {
95+
return nil, err
96+
}
97+
if rv.Alias != nil {
98+
aliasMap[*rv.Alias.Aliasname] = fqn
99+
}
73100
continue
74101
}
75102
err = indexTable(table)
@@ -89,7 +116,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
89116
continue
90117
}
91118

92-
if alias, ok := aliasMap[embed.Table.Name]; ok {
119+
if alias, ok := outerAliasMap[embed.Table.Name]; ok {
93120
embed.Table = alias
94121
continue
95122
}
@@ -195,24 +222,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
195222
panic("too many field items: " + strconv.Itoa(len(items)))
196223
}
197224

198-
search := tables
225+
search := scopeMap[ref.cteName]
199226
if alias != "" {
200227
if original, ok := aliasMap[alias]; ok {
201228
search = []*ast.TableName{original}
229+
} else if tableName, ok := tableNameMap[alias]; ok {
230+
search = []*ast.TableName{tableName}
202231
} else {
203-
var located bool
204-
for _, fqn := range tables {
205-
if fqn.Name == alias {
206-
located = true
207-
search = []*ast.TableName{fqn}
208-
}
209-
}
210-
if !located {
211-
return nil, &sqlerr.Error{
212-
Code: "42703",
213-
Message: fmt.Sprintf("table alias %q does not exist", alias),
214-
Location: node.Location,
215-
}
232+
return nil, &sqlerr.Error{
233+
Code: "42703",
234+
Message: fmt.Sprintf("table alias %q does not exist", alias),
235+
Location: node.Location,
216236
}
217237
}
218238
}
@@ -573,12 +593,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
573593
if alias != "" {
574594
if original, ok := aliasMap[alias]; ok {
575595
search = []*ast.TableName{original}
576-
} else {
577-
for _, fqn := range tables {
578-
if fqn.Name == alias {
579-
search = []*ast.TableName{fqn}
580-
}
581-
}
596+
} else if tableName, ok := tableNameMap[alias]; ok {
597+
search = []*ast.TableName{tableName}
582598
}
583599
}
584600

internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/db.go

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/cte_ambiguous_column/postgresql/stdlib/go/models.go

Lines changed: 22 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)