Skip to content

Commit 6f33017

Browse files
committed
Refactored some shared functions into common.go
1 parent 82ed525 commit 6f33017

File tree

4 files changed

+46
-74
lines changed

4 files changed

+46
-74
lines changed

sql/analyzer/assign_update_join.go

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
package analyzer
22

33
import (
4-
"strings"
5-
64
"github.com/dolthub/go-mysql-server/sql"
7-
"github.com/dolthub/go-mysql-server/sql/expression"
85
"github.com/dolthub/go-mysql-server/sql/plan"
96
"github.com/dolthub/go-mysql-server/sql/transform"
107
)
@@ -53,7 +50,7 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
5350

5451
// getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node
5552
func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[string]sql.Node, error) {
56-
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
53+
namesOfTableToBeUpdated := plan.GetTablesToBeUpdated(node)
5754
resolvedTables := getTablesByName(ij)
5855

5956
updateTargets := make(map[string]sql.Node)
@@ -81,21 +78,3 @@ func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[strin
8178

8279
return updateTargets, nil
8380
}
84-
85-
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
86-
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
87-
ret := make(map[string]struct{})
88-
89-
transform.InspectExpressions(node, func(e sql.Expression) bool {
90-
switch e := e.(type) {
91-
case *expression.SetField:
92-
gf := e.LeftChild.(*expression.GetField)
93-
ret[strings.ToLower(gf.Table())] = struct{}{}
94-
return false
95-
}
96-
97-
return true
98-
})
99-
100-
return ret
101-
}

sql/plan/common.go

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ package plan
1616

1717
import (
1818
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/expression"
1920
"github.com/dolthub/go-mysql-server/sql/mysql_db"
2021
"github.com/dolthub/go-mysql-server/sql/transform"
22+
"strings"
2123
)
2224

2325
// IsUnary returns whether the node is unary or not.
@@ -161,29 +163,56 @@ func getTableName(nodeToSearch sql.Node) string {
161163
return ""
162164
}
163165

164-
// GetDatabaseName attempts to fetch the database name from the node. If not found directly on the node, searches the
165-
// children. Returns the first database name found, regardless of whether there are more, therefore this is only
166-
// intended to be used in situations where only a single database is expected to be found. Unlike how tables are handled
167-
// in most nodes, databases may be stored as a string field therefore there will be situations where a database name
168-
// exists on a node, but cannot be found through inspection.
169-
func GetDatabaseName(nodeToSearch sql.Node) string {
170-
nodeStack := []sql.Node{nodeToSearch}
166+
// GetTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
167+
func GetTablesToBeUpdated(node sql.Node) map[string]struct{} {
168+
ret := make(map[string]struct{})
169+
170+
transform.InspectExpressions(node, func(e sql.Expression) bool {
171+
switch e := e.(type) {
172+
case *expression.SetField:
173+
gf := e.LeftChild.(*expression.GetField)
174+
ret[strings.ToLower(gf.Table())] = struct{}{}
175+
return false
176+
}
177+
178+
return true
179+
})
180+
181+
return ret
182+
}
183+
184+
// GetDatabaseName attempts to fetch the database name from the node.
185+
func GetDatabaseName(node sql.Node) string {
186+
database := GetDatabase(node)
187+
if database != nil {
188+
return database.Name()
189+
}
190+
return ""
191+
}
192+
193+
// GetDatabase attempts to fetch the database from the node. If not found directly on the node, searches the children.
194+
// Returns the first database found, regardless of whether there are more, therefore this is only intended to be used in
195+
// situations where only a single database is expected to be found. Unlike how tables are handled in most nodes,
196+
// databases may be stored as a string field. Therefore, there will be situations where a database exists on a node but
197+
// cannot be found through inspection.
198+
func GetDatabase(node sql.Node) sql.Database {
199+
nodeStack := []sql.Node{node}
171200
for len(nodeStack) > 0 {
172-
node := nodeStack[len(nodeStack)-1]
201+
n := nodeStack[len(nodeStack)-1]
173202
nodeStack = nodeStack[:len(nodeStack)-1]
174-
switch n := node.(type) {
203+
switch n := n.(type) {
175204
case sql.Databaser:
176-
return n.Database().Name()
205+
return n.Database()
177206
case *ResolvedTable:
178-
return n.SqlDatabase.Name()
207+
return n.SqlDatabase
179208
case *UnresolvedTable:
180-
return n.Database().Name()
209+
return n.Database()
181210
case *IndexedTableAccess:
182-
return n.Database().Name()
211+
return n.Database()
183212
}
184-
nodeStack = append(nodeStack, node.Children()...)
213+
nodeStack = append(nodeStack, n.Children()...)
185214
}
186-
return ""
215+
return nil
187216
}
188217

189218
// CheckPrivilegeNameForDatabase returns the name of the database to check privileges for, which may not be the result

sql/plan/update.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,24 +101,6 @@ func getUpdatableTable(t sql.Table) (sql.UpdatableTable, error) {
101101
}
102102
}
103103

104-
// GetDatabase returns the first database found in the node tree given
105-
func GetDatabase(node sql.Node) sql.Database {
106-
switch node := node.(type) {
107-
case *IndexedTableAccess:
108-
return GetDatabase(node.TableNode)
109-
case *ResolvedTable:
110-
return node.Database()
111-
case *UnresolvedTable:
112-
return node.Database()
113-
}
114-
115-
for _, child := range node.Children() {
116-
return GetDatabase(child)
117-
}
118-
119-
return nil
120-
}
121-
122104
// Schema implements the sql.Node interface.
123105
func (u *Update) Schema() sql.Schema {
124106
// Postgres allows the returned values of the update statement to be controlled, so if returning

sql/planbuilder/dml.go

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ func hasJoinNode(node sql.Node) bool {
593593
}
594594

595595
func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
596-
namesOfTablesToBeUpdated := getTablesToBeUpdated(node)
596+
namesOfTablesToBeUpdated := plan.GetTablesToBeUpdated(node)
597597
resolvedTablesMap := getTablesByName(ij)
598598

599599
for tableToBeUpdated, _ := range namesOfTablesToBeUpdated {
@@ -662,24 +662,6 @@ func getResolvedTable(node sql.Node) *plan.ResolvedTable {
662662
return table
663663
}
664664

665-
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
666-
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
667-
ret := make(map[string]struct{})
668-
669-
transform.InspectExpressions(node, func(e sql.Expression) bool {
670-
switch e := e.(type) {
671-
case *expression.SetField:
672-
gf := e.LeftChild.(*expression.GetField)
673-
ret[gf.Table()] = struct{}{}
674-
return false
675-
}
676-
677-
return true
678-
})
679-
680-
return ret
681-
}
682-
683665
func (b *Builder) buildInto(inScope *scope, into *ast.Into) {
684666
if into.Dumpfile != "" {
685667
inScope.node = plan.NewInto(inScope.node, nil, "", into.Dumpfile)

0 commit comments

Comments
 (0)