@@ -187,11 +187,16 @@ func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignK
187
187
return rowIterWithOkResultWithZeroRowsAffected (), nil
188
188
}
189
189
190
- func (b * BaseBuilder ) buildDropTable (ctx * sql.Context , n * plan.DropTable , row sql.Row ) (sql.RowIter , error ) {
190
+ func (b * BaseBuilder ) buildDropTable (ctx * sql.Context , n * plan.DropTable , _ sql.Row ) (sql.RowIter , error ) {
191
191
var err error
192
192
var curdb sql.Database
193
193
194
- for _ , table := range n .Tables {
194
+ sortedTables , err := sortTablesByFKDependencies (ctx , n .Tables )
195
+ if err != nil {
196
+ return nil , err
197
+ }
198
+
199
+ for _ , table := range sortedTables {
195
200
tbl := table .(* plan.ResolvedTable )
196
201
curdb = tbl .SqlDatabase
197
202
@@ -255,6 +260,53 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sq
255
260
return rowIterWithOkResultWithZeroRowsAffected (), nil
256
261
}
257
262
263
+ // sortTablesByFKDependencies examines the specified |tableNodes| and returns a slice of sql.Table instances, sorted
264
+ // by their foreign key dependencies. Tables that have a foreign key reference to another table in the list will be
265
+ // sorted first in the list, so that foreign key constraints can be dropped in the correct order.
266
+ func sortTablesByFKDependencies (ctx * sql.Context , tableNodes []sql.Node ) (sortedTables []sql.Table , err error ) {
267
+ for _ , tableNode := range tableNodes {
268
+ table , ok := tableNode .(sql.Table )
269
+ if ! ok {
270
+ return nil , fmt .Errorf ("encountered unexpected table type `%T` during DROP TABLE" , table )
271
+ }
272
+
273
+ if fkTable , err := getForeignKeyTable (table ); err == nil {
274
+ foreignKeys , err := fkTable .GetDeclaredForeignKeys (ctx )
275
+ if err != nil {
276
+ return nil , err
277
+ }
278
+
279
+ parentTables := make (map [string ]struct {})
280
+ for _ , foreignKey := range foreignKeys {
281
+ qualifiedTableName := foreignKey .ParentTable
282
+ parentTables [qualifiedTableName ] = struct {}{}
283
+ }
284
+
285
+ inserted := false
286
+ for i , sortedTable := range sortedTables {
287
+ qualifiedTableName := sortedTable .Name ()
288
+ if _ , ok := parentTables [qualifiedTableName ]; ok {
289
+ if i == 0 {
290
+ sortedTables = append ([]sql.Table {table }, sortedTables [i :]... )
291
+ } else {
292
+ sortedTables = append (sortedTables [:i - 1 ], append ([]sql.Table {table }, sortedTables [i :]... )... )
293
+ }
294
+ inserted = true
295
+ break
296
+ }
297
+ }
298
+
299
+ if ! inserted {
300
+ sortedTables = append (sortedTables , table )
301
+ }
302
+ } else {
303
+ sortedTables = append (sortedTables , table )
304
+ }
305
+ }
306
+
307
+ return sortedTables , nil
308
+ }
309
+
258
310
func (b * BaseBuilder ) buildAlterIndex (ctx * sql.Context , n * plan.AlterIndex , row sql.Row ) (sql.RowIter , error ) {
259
311
err := b .executeAlterIndex (ctx , n )
260
312
if err != nil {
0 commit comments