diff --git a/sql/analyzer/assign_update_join.go b/sql/analyzer/assign_update_join.go index 9c7d088560..096e70d964 100644 --- a/sql/analyzer/assign_update_join.go +++ b/sql/analyzer/assign_update_join.go @@ -1,10 +1,7 @@ package analyzer import ( - "strings" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" ) @@ -53,7 +50,7 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope * // getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[string]sql.Node, error) { - namesOfTableToBeUpdated := getTablesToBeUpdated(node) + namesOfTableToBeUpdated := plan.GetTablesToBeUpdated(node) resolvedTables := getTablesByName(ij) updateTargets := make(map[string]sql.Node) @@ -81,21 +78,3 @@ func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[strin return updateTargets, nil } - -// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField. -func getTablesToBeUpdated(node sql.Node) map[string]struct{} { - ret := make(map[string]struct{}) - - transform.InspectExpressions(node, func(e sql.Expression) bool { - switch e := e.(type) { - case *expression.SetField: - gf := e.LeftChild.(*expression.GetField) - ret[strings.ToLower(gf.Table())] = struct{}{} - return false - } - - return true - }) - - return ret -} diff --git a/sql/plan/common.go b/sql/plan/common.go index 9fabb03626..7a94f2d6b6 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -15,7 +15,10 @@ package plan import ( + "strings" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/mysql_db" "github.com/dolthub/go-mysql-server/sql/transform" ) @@ -161,29 +164,56 @@ func getTableName(nodeToSearch sql.Node) string { return "" } -// GetDatabaseName attempts to fetch the database name from the node. If not found directly on the node, searches the -// children. Returns the first database name found, regardless of whether there are more, therefore this is only -// intended to be used in situations where only a single database is expected to be found. Unlike how tables are handled -// in most nodes, databases may be stored as a string field therefore there will be situations where a database name -// exists on a node, but cannot be found through inspection. -func GetDatabaseName(nodeToSearch sql.Node) string { - nodeStack := []sql.Node{nodeToSearch} +// GetTablesToBeUpdated takes a node and looks for the tables to modified by a SetField. +func GetTablesToBeUpdated(node sql.Node) map[string]struct{} { + ret := make(map[string]struct{}) + + transform.InspectExpressions(node, func(e sql.Expression) bool { + switch e := e.(type) { + case *expression.SetField: + gf := e.LeftChild.(*expression.GetField) + ret[strings.ToLower(gf.Table())] = struct{}{} + return false + } + + return true + }) + + return ret +} + +// GetDatabaseName attempts to fetch the database name from the node. +func GetDatabaseName(node sql.Node) string { + database := GetDatabase(node) + if database != nil { + return database.Name() + } + return "" +} + +// GetDatabase attempts to fetch the database from the node. If not found directly on the node, searches the children. +// Returns the first database found, regardless of whether there are more, therefore this is only intended to be used in +// situations where only a single database is expected to be found. Unlike how tables are handled in most nodes, +// databases may be stored as a string field. Therefore, there will be situations where a database exists on a node but +// cannot be found through inspection. +func GetDatabase(node sql.Node) sql.Database { + nodeStack := []sql.Node{node} for len(nodeStack) > 0 { - node := nodeStack[len(nodeStack)-1] + n := nodeStack[len(nodeStack)-1] nodeStack = nodeStack[:len(nodeStack)-1] - switch n := node.(type) { + switch n := n.(type) { case sql.Databaser: - return n.Database().Name() + return n.Database() case *ResolvedTable: - return n.SqlDatabase.Name() + return n.SqlDatabase case *UnresolvedTable: - return n.Database().Name() + return n.Database() case *IndexedTableAccess: - return n.Database().Name() + return n.Database() } - nodeStack = append(nodeStack, node.Children()...) + nodeStack = append(nodeStack, n.Children()...) } - return "" + return nil } // CheckPrivilegeNameForDatabase returns the name of the database to check privileges for, which may not be the result diff --git a/sql/plan/update.go b/sql/plan/update.go index 2aedd1174c..da9a5f7905 100644 --- a/sql/plan/update.go +++ b/sql/plan/update.go @@ -101,24 +101,6 @@ func getUpdatableTable(t sql.Table) (sql.UpdatableTable, error) { } } -// GetDatabase returns the first database found in the node tree given -func GetDatabase(node sql.Node) sql.Database { - switch node := node.(type) { - case *IndexedTableAccess: - return GetDatabase(node.TableNode) - case *ResolvedTable: - return node.Database() - case *UnresolvedTable: - return node.Database() - } - - for _, child := range node.Children() { - return GetDatabase(child) - } - - return nil -} - // Schema implements the sql.Node interface. func (u *Update) Schema() sql.Schema { // Postgres allows the returned values of the update statement to be controlled, so if returning diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index d7ec5053b6..1b5ea656f2 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -593,7 +593,7 @@ func hasJoinNode(node sql.Node) bool { } func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) { - namesOfTablesToBeUpdated := getTablesToBeUpdated(node) + namesOfTablesToBeUpdated := plan.GetTablesToBeUpdated(node) resolvedTablesMap := getTablesByName(ij) for tableToBeUpdated, _ := range namesOfTablesToBeUpdated { @@ -662,24 +662,6 @@ func getResolvedTable(node sql.Node) *plan.ResolvedTable { return table } -// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField. -func getTablesToBeUpdated(node sql.Node) map[string]struct{} { - ret := make(map[string]struct{}) - - transform.InspectExpressions(node, func(e sql.Expression) bool { - switch e := e.(type) { - case *expression.SetField: - gf := e.LeftChild.(*expression.GetField) - ret[gf.Table()] = struct{}{} - return false - } - - return true - }) - - return ret -} - func (b *Builder) buildInto(inScope *scope, into *ast.Into) { if into.Dumpfile != "" { inScope.node = plan.NewInto(inScope.node, nil, "", into.Dumpfile)