@@ -187,11 +187,16 @@ func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignK
187187 return rowIterWithOkResultWithZeroRowsAffected (), nil
188188}
189189
190- func (b * BaseBuilder ) buildDropTable (ctx * sql.Context , n * plan.DropTable , row sql.Row ) (sql.RowIter , error ) {
190+ func (b * BaseBuilder ) buildDropTable (ctx * sql.Context , n * plan.DropTable , _ sql.Row ) (sql.RowIter , error ) {
191191 var err error
192192 var curdb sql.Database
193193
194- for _ , table := range n .Tables {
194+ sortedTables , err := sortTablesByFKDependencies (ctx , n .Tables )
195+ if err != nil {
196+ return nil , err
197+ }
198+
199+ for _ , table := range sortedTables {
195200 tbl := table .(* plan.ResolvedTable )
196201 curdb = tbl .SqlDatabase
197202
@@ -255,6 +260,53 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sq
255260 return rowIterWithOkResultWithZeroRowsAffected (), nil
256261}
257262
263+ // sortTablesByFKDependencies examines the specified |tableNodes| and returns a slice of sql.Table instances, sorted
264+ // by their foreign key dependencies. Tables that have a foreign key reference to another table in the list will be
265+ // sorted first in the list, so that foreign key constraints can be dropped in the correct order.
266+ func sortTablesByFKDependencies (ctx * sql.Context , tableNodes []sql.Node ) (sortedTables []sql.Table , err error ) {
267+ for _ , tableNode := range tableNodes {
268+ table , ok := tableNode .(sql.Table )
269+ if ! ok {
270+ return nil , fmt .Errorf ("encountered unexpected table type `%T` during DROP TABLE" , table )
271+ }
272+
273+ if fkTable , err := getForeignKeyTable (table ); err == nil {
274+ foreignKeys , err := fkTable .GetDeclaredForeignKeys (ctx )
275+ if err != nil {
276+ return nil , err
277+ }
278+
279+ parentTables := make (map [string ]struct {})
280+ for _ , foreignKey := range foreignKeys {
281+ qualifiedTableName := foreignKey .ParentTable
282+ parentTables [qualifiedTableName ] = struct {}{}
283+ }
284+
285+ inserted := false
286+ for i , sortedTable := range sortedTables {
287+ qualifiedTableName := sortedTable .Name ()
288+ if _ , ok := parentTables [qualifiedTableName ]; ok {
289+ if i == 0 {
290+ sortedTables = append ([]sql.Table {table }, sortedTables [i :]... )
291+ } else {
292+ sortedTables = append (sortedTables [:i - 1 ], append ([]sql.Table {table }, sortedTables [i :]... )... )
293+ }
294+ inserted = true
295+ break
296+ }
297+ }
298+
299+ if ! inserted {
300+ sortedTables = append (sortedTables , table )
301+ }
302+ } else {
303+ sortedTables = append (sortedTables , table )
304+ }
305+ }
306+
307+ return sortedTables , nil
308+ }
309+
258310func (b * BaseBuilder ) buildAlterIndex (ctx * sql.Context , n * plan.AlterIndex , row sql.Row ) (sql.RowIter , error ) {
259311 err := b .executeAlterIndex (ctx , n )
260312 if err != nil {
0 commit comments