Skip to content

Commit 020a58e

Browse files
authored
Merge pull request #3088 from dolthub/angela/refactor
[no-release-notes] Refactored some shared functions into common.go
2 parents af68faa + b7b5151 commit 020a58e

File tree

4 files changed

+47
-74
lines changed

4 files changed

+47
-74
lines changed

sql/analyzer/assign_update_join.go

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
package analyzer
22

33
import (
4-
"strings"
5-
64
"github.com/dolthub/go-mysql-server/sql"
7-
"github.com/dolthub/go-mysql-server/sql/expression"
85
"github.com/dolthub/go-mysql-server/sql/plan"
96
"github.com/dolthub/go-mysql-server/sql/transform"
107
)
@@ -53,7 +50,7 @@ func modifyUpdateExprsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *
5350

5451
// getUpdateTargetsByTable maps a set of table names and aliases to their corresponding update target Node
5552
func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[string]sql.Node, error) {
56-
namesOfTableToBeUpdated := getTablesToBeUpdated(node)
53+
namesOfTableToBeUpdated := plan.GetTablesToBeUpdated(node)
5754
resolvedTables := getTablesByName(ij)
5855

5956
updateTargets := make(map[string]sql.Node)
@@ -81,21 +78,3 @@ func getUpdateTargetsByTable(node sql.Node, ij sql.Node, isJoin bool) (map[strin
8178

8279
return updateTargets, nil
8380
}
84-
85-
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
86-
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
87-
ret := make(map[string]struct{})
88-
89-
transform.InspectExpressions(node, func(e sql.Expression) bool {
90-
switch e := e.(type) {
91-
case *expression.SetField:
92-
gf := e.LeftChild.(*expression.GetField)
93-
ret[strings.ToLower(gf.Table())] = struct{}{}
94-
return false
95-
}
96-
97-
return true
98-
})
99-
100-
return ret
101-
}

sql/plan/common.go

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
package plan
1616

1717
import (
18+
"strings"
19+
1820
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/expression"
1922
"github.com/dolthub/go-mysql-server/sql/mysql_db"
2023
"github.com/dolthub/go-mysql-server/sql/transform"
2124
)
@@ -161,29 +164,56 @@ func getTableName(nodeToSearch sql.Node) string {
161164
return ""
162165
}
163166

164-
// GetDatabaseName attempts to fetch the database name from the node. If not found directly on the node, searches the
165-
// children. Returns the first database name found, regardless of whether there are more, therefore this is only
166-
// intended to be used in situations where only a single database is expected to be found. Unlike how tables are handled
167-
// in most nodes, databases may be stored as a string field therefore there will be situations where a database name
168-
// exists on a node, but cannot be found through inspection.
169-
func GetDatabaseName(nodeToSearch sql.Node) string {
170-
nodeStack := []sql.Node{nodeToSearch}
167+
// GetTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
168+
func GetTablesToBeUpdated(node sql.Node) map[string]struct{} {
169+
ret := make(map[string]struct{})
170+
171+
transform.InspectExpressions(node, func(e sql.Expression) bool {
172+
switch e := e.(type) {
173+
case *expression.SetField:
174+
gf := e.LeftChild.(*expression.GetField)
175+
ret[strings.ToLower(gf.Table())] = struct{}{}
176+
return false
177+
}
178+
179+
return true
180+
})
181+
182+
return ret
183+
}
184+
185+
// GetDatabaseName attempts to fetch the database name from the node.
186+
func GetDatabaseName(node sql.Node) string {
187+
database := GetDatabase(node)
188+
if database != nil {
189+
return database.Name()
190+
}
191+
return ""
192+
}
193+
194+
// GetDatabase attempts to fetch the database from the node. If not found directly on the node, searches the children.
195+
// Returns the first database found, regardless of whether there are more, therefore this is only intended to be used in
196+
// situations where only a single database is expected to be found. Unlike how tables are handled in most nodes,
197+
// databases may be stored as a string field. Therefore, there will be situations where a database exists on a node but
198+
// cannot be found through inspection.
199+
func GetDatabase(node sql.Node) sql.Database {
200+
nodeStack := []sql.Node{node}
171201
for len(nodeStack) > 0 {
172-
node := nodeStack[len(nodeStack)-1]
202+
n := nodeStack[len(nodeStack)-1]
173203
nodeStack = nodeStack[:len(nodeStack)-1]
174-
switch n := node.(type) {
204+
switch n := n.(type) {
175205
case sql.Databaser:
176-
return n.Database().Name()
206+
return n.Database()
177207
case *ResolvedTable:
178-
return n.SqlDatabase.Name()
208+
return n.SqlDatabase
179209
case *UnresolvedTable:
180-
return n.Database().Name()
210+
return n.Database()
181211
case *IndexedTableAccess:
182-
return n.Database().Name()
212+
return n.Database()
183213
}
184-
nodeStack = append(nodeStack, node.Children()...)
214+
nodeStack = append(nodeStack, n.Children()...)
185215
}
186-
return ""
216+
return nil
187217
}
188218

189219
// CheckPrivilegeNameForDatabase returns the name of the database to check privileges for, which may not be the result

sql/plan/update.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,24 +101,6 @@ func getUpdatableTable(t sql.Table) (sql.UpdatableTable, error) {
101101
}
102102
}
103103

104-
// GetDatabase returns the first database found in the node tree given
105-
func GetDatabase(node sql.Node) sql.Database {
106-
switch node := node.(type) {
107-
case *IndexedTableAccess:
108-
return GetDatabase(node.TableNode)
109-
case *ResolvedTable:
110-
return node.Database()
111-
case *UnresolvedTable:
112-
return node.Database()
113-
}
114-
115-
for _, child := range node.Children() {
116-
return GetDatabase(child)
117-
}
118-
119-
return nil
120-
}
121-
122104
// Schema implements the sql.Node interface.
123105
func (u *Update) Schema() sql.Schema {
124106
// Postgres allows the returned values of the update statement to be controlled, so if returning

sql/planbuilder/dml.go

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ func hasJoinNode(node sql.Node) bool {
593593
}
594594

595595
func getResolvedTablesToUpdate(_ *sql.Context, node sql.Node, ij sql.Node) (resolvedTables []*plan.ResolvedTable, err error) {
596-
namesOfTablesToBeUpdated := getTablesToBeUpdated(node)
596+
namesOfTablesToBeUpdated := plan.GetTablesToBeUpdated(node)
597597
resolvedTablesMap := getTablesByName(ij)
598598

599599
for tableToBeUpdated, _ := range namesOfTablesToBeUpdated {
@@ -662,24 +662,6 @@ func getResolvedTable(node sql.Node) *plan.ResolvedTable {
662662
return table
663663
}
664664

665-
// getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
666-
func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
667-
ret := make(map[string]struct{})
668-
669-
transform.InspectExpressions(node, func(e sql.Expression) bool {
670-
switch e := e.(type) {
671-
case *expression.SetField:
672-
gf := e.LeftChild.(*expression.GetField)
673-
ret[gf.Table()] = struct{}{}
674-
return false
675-
}
676-
677-
return true
678-
})
679-
680-
return ret
681-
}
682-
683665
func (b *Builder) buildInto(inScope *scope, into *ast.Into) {
684666
if into.Dumpfile != "" {
685667
inScope.node = plan.NewInto(inScope.node, nil, "", into.Dumpfile)

0 commit comments

Comments
 (0)