Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
if err := check(validate.In(c.catalog, raw)); err != nil {
return nil, err
}
rvs := rangeVars(raw.Stmt)
scopedRVs := rangeVarsWithScope(raw.Stmt)
refs, errs := findParameters(raw.Stmt)
if len(errs) > 0 {
if failfast {
Expand All @@ -173,7 +173,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool)
return nil, err
}

params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
params, err := c.resolveCatalogRefs(qc, scopedRVs, refs, namedParams, embeds)
if err := check(err); err != nil {
return nil, err
}
Expand Down
16 changes: 12 additions & 4 deletions internal/compiler/find_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ type paramRef struct {
rv *ast.RangeVar
ref *ast.ParamRef
name string // Named parameter support

cteName *string // Current CTE name, nil if not inside a CTE.
}

type paramSearch struct {
Expand All @@ -36,6 +38,8 @@ type paramSearch struct {
// XXX: Gross state hack for limit
limitCount ast.Node
limitOffset ast.Node

cteName *string // Current CTE name, nil if not inside a CTE.
}

type limitCount struct {
Expand All @@ -55,6 +59,10 @@ func (l *limitOffset) Pos() int {
func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
switch n := node.(type) {

case *ast.CommonTableExpr:
p.cteName = n.Ctename
return p

case *ast.A_Expr:
p.parent = node

Expand Down Expand Up @@ -87,7 +95,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
*p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
return p
}
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, cteName: p.cteName})
p.seen[ref.Location] = struct{}{}
}
for _, item := range s.ValuesLists.Items {
Expand All @@ -104,7 +112,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
*p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
return p
}
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
*p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, cteName: p.cteName})
p.seen[ref.Location] = struct{}{}
}
}
Expand All @@ -125,7 +133,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
if !ok {
continue
}
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
*p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, cteName: p.cteName})
}
p.seen[ref.Location] = struct{}{}
}
Expand Down Expand Up @@ -186,7 +194,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
}

if set {
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
*p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar, cteName: p.cteName})
p.seen[n.Location] = struct{}{}
}
return nil
Expand Down
6 changes: 5 additions & 1 deletion internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
return nil, err
}

return convertColumnsToCatalog(cols), nil
}

func convertColumnsToCatalog(cols []*Column) []*catalog.Column {
catCols := make([]*catalog.Column, 0, len(cols))
for _, col := range cols {
catCols = append(catCols, &catalog.Column{
Expand All @@ -35,7 +39,7 @@ func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
Length: col.Length,
})
}
return catCols, nil
return catCols
}

func hasStarRef(cf *ast.ColumnRef) bool {
Expand Down
40 changes: 30 additions & 10 deletions internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,36 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
}, nil
}

func rangeVars(root ast.Node) []*ast.RangeVar {
var vars []*ast.RangeVar
find := astutils.VisitorFunc(func(node ast.Node) {
switch n := node.(type) {
case *ast.RangeVar:
vars = append(vars, n)
}
})
astutils.Walk(find, root)
return vars
// scopedRangeVar associates a RangeVar with a scope.
type scopedRangeVar struct {
rv *ast.RangeVar

cteName *string // Current CTE name, nil if not inside a CTE.
}

// rangeVarsWithScope collects all RangeVars with their scope.
func rangeVarsWithScope(root ast.Node) []scopedRangeVar {
var rvs []scopedRangeVar
visitor := &rvSearch{rvs: &rvs, cteName: nil}
astutils.Walk(visitor, root)
return rvs
}

// rvSearch finds all RangeVars and tracks their scope.
type rvSearch struct {
rvs *[]scopedRangeVar

cteName *string // Current CTE name, nil if not inside a CTE.
}

func (v *rvSearch) Visit(node ast.Node) astutils.Visitor {
switch n := node.(type) {
case *ast.CommonTableExpr:
return &rvSearch{rvs: v.rvs, cteName: n.Ctename}
case *ast.RangeVar:
*v.rvs = append(*v.rvs, scopedRangeVar{rv: n, cteName: v.cteName})
}
return v
}

func uniqueParamRefs(in []paramRef, dollar bool) []paramRef {
Expand Down
82 changes: 49 additions & 33 deletions internal/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,63 @@ func dataType(n *ast.TypeName) string {
}
}

func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, scopedRVs []scopedRangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
c := comp.catalog

scopeMap := make(map[*string][]*ast.TableName)
outerAliasMap := map[string]*ast.TableName{}
aliasMap := map[string]*ast.TableName{}
tableNameMap := map[string]*ast.TableName{}
// TODO: Deprecate defaultTable
var defaultTable *ast.TableName
var tables []*ast.TableName

typeMap := map[string]map[string]map[string]*catalog.Column{}
indexTable := func(table catalog.Table) error {
tables = append(tables, table.Rel)

indexTableWithColumns := func(rel *ast.TableName, cols []*catalog.Column) error {
tables = append(tables, rel)
tableNameMap[rel.Name] = rel
if defaultTable == nil {
defaultTable = table.Rel
defaultTable = rel
}
schema := table.Rel.Schema
schema := rel.Schema
if schema == "" {
schema = c.DefaultSchema
}
if _, exists := typeMap[schema]; !exists {
typeMap[schema] = map[string]map[string]*catalog.Column{}
}
typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{}
for _, c := range table.Columns {
cc := c
typeMap[schema][table.Rel.Name][c.Name] = cc
typeMap[schema][rel.Name] = map[string]*catalog.Column{}
for _, c := range cols {
typeMap[schema][rel.Name][c.Name] = c
}
return nil
}

for _, rv := range rvs {
indexTable := func(table catalog.Table) error {
return indexTableWithColumns(table.Rel, table.Columns)
}

indexCTE := func(cte *Table) error {
catalogCols := convertColumnsToCatalog(cte.Columns)
return indexTableWithColumns(cte.Rel, catalogCols)
}

for _, srv := range scopedRVs {
rv := srv.rv
scope := srv.cteName
if rv.Relname == nil {
continue
}
fqn, err := ParseTableName(rv)
if err != nil {
return nil, err
}

scopeMap[scope] = append(scopeMap[scope], fqn)
if scope == nil && rv.Alias != nil {
outerAliasMap[*rv.Alias.Aliasname] = fqn
}

if _, found := aliasMap[fqn.Name]; found {
continue
}
Expand All @@ -67,9 +87,16 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
continue
}
// If the table name doesn't exist, first check if it's a CTE
if _, qcerr := qc.GetTable(fqn); qcerr != nil {
cte, qcerr := qc.GetTable(fqn)
if qcerr != nil {
return nil, err
}
if err := indexCTE(cte); err != nil {
return nil, err
}
if rv.Alias != nil {
aliasMap[*rv.Alias.Aliasname] = fqn
}
continue
}
err = indexTable(table)
Expand All @@ -89,7 +116,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
continue
}

if alias, ok := aliasMap[embed.Table.Name]; ok {
if alias, ok := outerAliasMap[embed.Table.Name]; ok {
embed.Table = alias
continue
}
Expand Down Expand Up @@ -195,24 +222,17 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
panic("too many field items: " + strconv.Itoa(len(items)))
}

search := tables
search := scopeMap[ref.cteName]
if alias != "" {
if original, ok := aliasMap[alias]; ok {
search = []*ast.TableName{original}
} else if tableName, ok := tableNameMap[alias]; ok {
search = []*ast.TableName{tableName}
} else {
var located bool
for _, fqn := range tables {
if fqn.Name == alias {
located = true
search = []*ast.TableName{fqn}
}
}
if !located {
return nil, &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("table alias %q does not exist", alias),
Location: node.Location,
}
return nil, &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("table alias %q does not exist", alias),
Location: node.Location,
}
}
}
Expand Down Expand Up @@ -573,12 +593,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
if alias != "" {
if original, ok := aliasMap[alias]; ok {
search = []*ast.TableName{original}
} else {
for _, fqn := range tables {
if fqn.Name == alias {
search = []*ast.TableName{fqn}
}
}
} else if tableName, ok := tableNameMap[alias]; ok {
search = []*ast.TableName{tableName}
}
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading