Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 1 addition & 22 deletions sql/analyzer/assign_update_join.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package analyzer

import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
)
Expand Down Expand Up @@ -53,7 +50,7 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *

// getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node
func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[string]sql.Node, error) {
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
namesOfTableToBeUpdated := plan.GetTablesToBeUpdated(node)
resolvedTables := getTablesByName(ij)

updateTargets := make(map[string]sql.Node)
Expand Down Expand Up @@ -81,21 +78,3 @@ func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[strin

return updateTargets, nil
}

// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
ret := make(map[string]struct{})

transform.InspectExpressions(node, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.SetField:
gf := e.LeftChild.(*expression.GetField)
ret[strings.ToLower(gf.Table())] = struct{}{}
return false
}

return true
})

return ret
}
60 changes: 45 additions & 15 deletions sql/plan/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package plan

import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
"github.com/dolthub/go-mysql-server/sql/transform"
)
Expand Down Expand Up @@ -161,29 +164,56 @@ func getTableName(nodeToSearch sql.Node) string {
return ""
}

// GetDatabaseName attempts to fetch the database name from the node. If not found directly on the node, searches the
// children. Returns the first database name found, regardless of whether there are more, therefore this is only
// intended to be used in situations where only a single database is expected to be found. Unlike how tables are handled
// in most nodes, databases may be stored as a string field therefore there will be situations where a database name
// exists on a node, but cannot be found through inspection.
func GetDatabaseName(nodeToSearch sql.Node) string {
nodeStack := []sql.Node{nodeToSearch}
// GetTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
func GetTablesToBeUpdated(node sql.Node) map[string]struct{} {
ret := make(map[string]struct{})

transform.InspectExpressions(node, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.SetField:
gf := e.LeftChild.(*expression.GetField)
ret[strings.ToLower(gf.Table())] = struct{}{}
return false
}

return true
})

return ret
}

// GetDatabaseName attempts to fetch the database name from the node.
func GetDatabaseName(node sql.Node) string {
database := GetDatabase(node)
if database != nil {
return database.Name()
}
return ""
}

// GetDatabase attempts to fetch the database from the node. If not found directly on the node, searches the children.
// Returns the first database found, regardless of whether there are more, therefore this is only intended to be used in
// situations where only a single database is expected to be found. Unlike how tables are handled in most nodes,
// databases may be stored as a string field. Therefore, there will be situations where a database exists on a node but
// cannot be found through inspection.
func GetDatabase(node sql.Node) sql.Database {
nodeStack := []sql.Node{node}
for len(nodeStack) > 0 {
node := nodeStack[len(nodeStack)-1]
n := nodeStack[len(nodeStack)-1]
nodeStack = nodeStack[:len(nodeStack)-1]
switch n := node.(type) {
switch n := n.(type) {
case sql.Databaser:
return n.Database().Name()
return n.Database()
case *ResolvedTable:
return n.SqlDatabase.Name()
return n.SqlDatabase
case *UnresolvedTable:
return n.Database().Name()
return n.Database()
case *IndexedTableAccess:
return n.Database().Name()
return n.Database()
}
nodeStack = append(nodeStack, node.Children()...)
nodeStack = append(nodeStack, n.Children()...)
}
return ""
return nil
}

// CheckPrivilegeNameForDatabase returns the name of the database to check privileges for, which may not be the result
Expand Down
18 changes: 0 additions & 18 deletions sql/plan/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,6 @@ func getUpdatableTable(t sql.Table) (sql.UpdatableTable, error) {
}
}

// GetDatabase returns the first database found in the node tree given
func GetDatabase(node sql.Node) sql.Database {
switch node := node.(type) {
case *IndexedTableAccess:
return GetDatabase(node.TableNode)
case *ResolvedTable:
return node.Database()
case *UnresolvedTable:
return node.Database()
}

for _, child := range node.Children() {
return GetDatabase(child)
}

return nil
}

// Schema implements the sql.Node interface.
func (u *Update) Schema() sql.Schema {
// Postgres allows the returned values of the update statement to be controlled, so if returning
Expand Down
20 changes: 1 addition & 19 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ func hasJoinNode(node sql.Node) bool {
}

func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
namesOfTablesToBeUpdated := getTablesToBeUpdated(node)
namesOfTablesToBeUpdated := plan.GetTablesToBeUpdated(node)
resolvedTablesMap := getTablesByName(ij)

for tableToBeUpdated, _ := range namesOfTablesToBeUpdated {
Expand Down Expand Up @@ -662,24 +662,6 @@ func getResolvedTable(node sql.Node) *plan.ResolvedTable {
return table
}

// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
ret := make(map[string]struct{})

transform.InspectExpressions(node, func(e sql.Expression) bool {
switch e := e.(type) {
case *expression.SetField:
gf := e.LeftChild.(*expression.GetField)
ret[gf.Table()] = struct{}{}
return false
}

return true
})

return ret
}

func (b *Builder) buildInto(inScope *scope, into *ast.Into) {
if into.Dumpfile != "" {
inScope.node = plan.NewInto(inScope.node, nil, "", into.Dumpfile)
Expand Down