@@ -16,8 +16,10 @@ package plan
1616
1717import (
1818 "github.com/dolthub/go-mysql-server/sql"
19+ "github.com/dolthub/go-mysql-server/sql/expression"
1920 "github.com/dolthub/go-mysql-server/sql/mysql_db"
2021 "github.com/dolthub/go-mysql-server/sql/transform"
22+ "strings"
2123)
2224
2325// IsUnary returns whether the node is unary or not.
@@ -161,29 +163,56 @@ func getTableName(nodeToSearch sql.Node) string {
161163 return ""
162164}
163165
164- // GetDatabaseName attempts to fetch the database name from the node. If not found directly on the node, searches the
165- // children. Returns the first database name found, regardless of whether there are more, therefore this is only
166- // intended to be used in situations where only a single database is expected to be found. Unlike how tables are handled
167- // in most nodes, databases may be stored as a string field therefore there will be situations where a database name
168- // exists on a node, but cannot be found through inspection.
169- func GetDatabaseName (nodeToSearch sql.Node ) string {
170- nodeStack := []sql.Node {nodeToSearch }
166+ // GetTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
167+ func GetTablesToBeUpdated (node sql.Node ) map [string ]struct {} {
168+ ret := make (map [string ]struct {})
169+
170+ transform .InspectExpressions (node , func (e sql.Expression ) bool {
171+ switch e := e .(type ) {
172+ case * expression.SetField :
173+ gf := e .LeftChild .(* expression.GetField )
174+ ret [strings .ToLower (gf .Table ())] = struct {}{}
175+ return false
176+ }
177+
178+ return true
179+ })
180+
181+ return ret
182+ }
183+
184+ // GetDatabaseName attempts to fetch the database name from the node.
185+ func GetDatabaseName (node sql.Node ) string {
186+ database := GetDatabase (node )
187+ if database != nil {
188+ return database .Name ()
189+ }
190+ return ""
191+ }
192+
193+ // GetDatabase attempts to fetch the database from the node. If not found directly on the node, searches the children.
194+ // Returns the first database found, regardless of whether there are more, therefore this is only intended to be used in
195+ // situations where only a single database is expected to be found. Unlike how tables are handled in most nodes,
196+ // databases may be stored as a string field. Therefore, there will be situations where a database exists on a node but
197+ // cannot be found through inspection.
198+ func GetDatabase (node sql.Node ) sql.Database {
199+ nodeStack := []sql.Node {node }
171200 for len (nodeStack ) > 0 {
172- node := nodeStack [len (nodeStack )- 1 ]
201+ n := nodeStack [len (nodeStack )- 1 ]
173202 nodeStack = nodeStack [:len (nodeStack )- 1 ]
174- switch n := node .(type ) {
203+ switch n := n .(type ) {
175204 case sql.Databaser :
176- return n .Database (). Name ()
205+ return n .Database ()
177206 case * ResolvedTable :
178- return n .SqlDatabase . Name ()
207+ return n .SqlDatabase
179208 case * UnresolvedTable :
180- return n .Database (). Name ()
209+ return n .Database ()
181210 case * IndexedTableAccess :
182- return n .Database (). Name ()
211+ return n .Database ()
183212 }
184- nodeStack = append (nodeStack , node .Children ()... )
213+ nodeStack = append (nodeStack , n .Children ()... )
185214 }
186- return ""
215+ return nil
187216}
188217
189218// CheckPrivilegeNameForDatabase returns the name of the database to check privileges for, which may not be the result
0 commit comments