@@ -16,8 +16,10 @@ package plan
16
16
17
17
import (
18
18
"github.com/dolthub/go-mysql-server/sql"
19
+ "github.com/dolthub/go-mysql-server/sql/expression"
19
20
"github.com/dolthub/go-mysql-server/sql/mysql_db"
20
21
"github.com/dolthub/go-mysql-server/sql/transform"
22
+ "strings"
21
23
)
22
24
23
25
// IsUnary returns whether the node is unary or not.
@@ -161,29 +163,56 @@ func getTableName(nodeToSearch sql.Node) string {
161
163
return ""
162
164
}
163
165
164
- // GetDatabaseName attempts to fetch the database name from the node. If not found directly on the node, searches the
165
- // children. Returns the first database name found, regardless of whether there are more, therefore this is only
166
- // intended to be used in situations where only a single database is expected to be found. Unlike how tables are handled
167
- // in most nodes, databases may be stored as a string field therefore there will be situations where a database name
168
- // exists on a node, but cannot be found through inspection.
169
- func GetDatabaseName (nodeToSearch sql.Node ) string {
170
- nodeStack := []sql.Node {nodeToSearch }
166
+ // GetTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
167
+ func GetTablesToBeUpdated (node sql.Node ) map [string ]struct {} {
168
+ ret := make (map [string ]struct {})
169
+
170
+ transform .InspectExpressions (node , func (e sql.Expression ) bool {
171
+ switch e := e .(type ) {
172
+ case * expression.SetField :
173
+ gf := e .LeftChild .(* expression.GetField )
174
+ ret [strings .ToLower (gf .Table ())] = struct {}{}
175
+ return false
176
+ }
177
+
178
+ return true
179
+ })
180
+
181
+ return ret
182
+ }
183
+
184
+ // GetDatabaseName attempts to fetch the database name from the node.
185
+ func GetDatabaseName (node sql.Node ) string {
186
+ database := GetDatabase (node )
187
+ if database != nil {
188
+ return database .Name ()
189
+ }
190
+ return ""
191
+ }
192
+
193
+ // GetDatabase attempts to fetch the database from the node. If not found directly on the node, searches the children.
194
+ // Returns the first database found, regardless of whether there are more, therefore this is only intended to be used in
195
+ // situations where only a single database is expected to be found. Unlike how tables are handled in most nodes,
196
+ // databases may be stored as a string field. Therefore, there will be situations where a database exists on a node but
197
+ // cannot be found through inspection.
198
+ func GetDatabase (node sql.Node ) sql.Database {
199
+ nodeStack := []sql.Node {node }
171
200
for len (nodeStack ) > 0 {
172
- node := nodeStack [len (nodeStack )- 1 ]
201
+ n := nodeStack [len (nodeStack )- 1 ]
173
202
nodeStack = nodeStack [:len (nodeStack )- 1 ]
174
- switch n := node .(type ) {
203
+ switch n := n .(type ) {
175
204
case sql.Databaser :
176
- return n .Database (). Name ()
205
+ return n .Database ()
177
206
case * ResolvedTable :
178
- return n .SqlDatabase . Name ()
207
+ return n .SqlDatabase
179
208
case * UnresolvedTable :
180
- return n .Database (). Name ()
209
+ return n .Database ()
181
210
case * IndexedTableAccess :
182
- return n .Database (). Name ()
211
+ return n .Database ()
183
212
}
184
- nodeStack = append (nodeStack , node .Children ()... )
213
+ nodeStack = append (nodeStack , n .Children ()... )
185
214
}
186
- return ""
215
+ return nil
187
216
}
188
217
189
218
// CheckPrivilegeNameForDatabase returns the name of the database to check privileges for, which may not be the result
0 commit comments