Skip to content

Commit 447fbd1

Browse files
authored
Merge pull request #2890 from dolthub/fulghum/drop_tables
Sort tables by FK dependencies for `DROP TABLES`
2 parents 73b3865 + 19dfef0 commit 447fbd1

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

enginetest/queries/foreign_key_queries.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,22 @@ var ForeignKeyTests = []ScriptTest{
380380
},
381381
},
382382
},
383+
{
384+
Name: "DROP TABLE, with multiple tables, sorts by foreign key dependencies",
385+
SetUpScript: []string{
386+
"create table grandparent1 (pk int primary key);",
387+
"create table parent1 (pk int primary key, c1 int references grandparent(pk));",
388+
"create table parent2 (pk int primary key);",
389+
"create table child1 (pk int primary key, c1 int, c2 int, foreign key (c1) references parent1(pk), foreign key (c2) references parent2(pk));",
390+
"create table selfref (pk int primary key, c1 int, foreign key (c1) references selfref(pk));",
391+
},
392+
Assertions: []ScriptTestAssertion{
393+
{
394+
Query: "DROP TABLE grandparent1, parent1, parent2, selfref, child1;",
395+
Expected: []sql.Row{{types.NewOkResult(0)}},
396+
},
397+
},
398+
},
383399
{
384400
Name: "Indexes used by foreign keys can't be dropped",
385401
SetUpScript: []string{

sql/rowexec/dml.go

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,16 @@ func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignK
187187
return rowIterWithOkResultWithZeroRowsAffected(), nil
188188
}
189189

190-
func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sql.Row) (sql.RowIter, error) {
190+
func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, _ sql.Row) (sql.RowIter, error) {
191191
var err error
192192
var curdb sql.Database
193193

194-
for _, table := range n.Tables {
194+
sortedTables, err := sortTablesByFKDependencies(ctx, n.Tables)
195+
if err != nil {
196+
return nil, err
197+
}
198+
199+
for _, table := range sortedTables {
195200
tbl := table.(*plan.ResolvedTable)
196201
curdb = tbl.SqlDatabase
197202

@@ -255,6 +260,53 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, row sq
255260
return rowIterWithOkResultWithZeroRowsAffected(), nil
256261
}
257262

263+
// sortTablesByFKDependencies examines the specified |tableNodes| and returns a slice of sql.Table instances, sorted
264+
// by their foreign key dependencies. Tables that have a foreign key reference to another table in the list will be
265+
// sorted first in the list, so that foreign key constraints can be dropped in the correct order.
266+
func sortTablesByFKDependencies(ctx *sql.Context, tableNodes []sql.Node) (sortedTables []sql.Table, err error) {
267+
for _, tableNode := range tableNodes {
268+
table, ok := tableNode.(sql.Table)
269+
if !ok {
270+
return nil, fmt.Errorf("encountered unexpected table type `%T` during DROP TABLE", table)
271+
}
272+
273+
if fkTable, err := getForeignKeyTable(table); err == nil {
274+
foreignKeys, err := fkTable.GetDeclaredForeignKeys(ctx)
275+
if err != nil {
276+
return nil, err
277+
}
278+
279+
parentTables := make(map[string]struct{})
280+
for _, foreignKey := range foreignKeys {
281+
qualifiedTableName := foreignKey.ParentTable
282+
parentTables[qualifiedTableName] = struct{}{}
283+
}
284+
285+
inserted := false
286+
for i, sortedTable := range sortedTables {
287+
qualifiedTableName := sortedTable.Name()
288+
if _, ok := parentTables[qualifiedTableName]; ok {
289+
if i == 0 {
290+
sortedTables = append([]sql.Table{table}, sortedTables[i:]...)
291+
} else {
292+
sortedTables = append(sortedTables[:i-1], append([]sql.Table{table}, sortedTables[i:]...)...)
293+
}
294+
inserted = true
295+
break
296+
}
297+
}
298+
299+
if !inserted {
300+
sortedTables = append(sortedTables, table)
301+
}
302+
} else {
303+
sortedTables = append(sortedTables, table)
304+
}
305+
}
306+
307+
return sortedTables, nil
308+
}
309+
258310
func (b *BaseBuilder) buildAlterIndex(ctx *sql.Context, n *plan.AlterIndex, row sql.Row) (sql.RowIter, error) {
259311
err := b.executeAlterIndex(ctx, n)
260312
if err != nil {

0 commit comments

Comments
 (0)