Skip to content

Commit dac7262

Browse files
authored
allow renaming views with RENAME TABLE statement (#1712)
1 parent 8725719 commit dac7262

File tree

9 files changed

+197
-90
lines changed

9 files changed

+197
-90
lines changed

enginetest/enginetests.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,10 +1897,6 @@ func TestRecursiveViewDefinition(t *testing.T, harness Harness) {
18971897
db, err := e.Analyzer.Catalog.Database(ctx, "mydb")
18981898
require.NoError(t, err)
18991899

1900-
if pdb, ok := db.(mysql_db.PrivilegedDatabase); ok {
1901-
db = pdb.Unwrap()
1902-
}
1903-
19041900
vdb, ok := db.(sql.ViewDatabase)
19051901
require.True(t, ok, "expected sql.ViewDatabase")
19061902

enginetest/queries/script_queries.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2902,6 +2902,35 @@ var ScriptTests = []ScriptTest{
29022902
},
29032903
},
29042904
},
2905+
{
2906+
Name: "rename views with RENAME TABLE ... TO .. statement",
2907+
SetUpScript: []string{
2908+
"create table t1 (id int primary key, v1 int);",
2909+
"create view v1 as select * from t1;",
2910+
},
2911+
Assertions: []ScriptTestAssertion{
2912+
{
2913+
Query: "show tables;",
2914+
Expected: []sql.Row{{"myview"}, {"t1"}, {"v1"}},
2915+
},
2916+
{
2917+
Query: "rename table v1 to view1",
2918+
Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}},
2919+
},
2920+
{
2921+
Query: "show tables;",
2922+
Expected: []sql.Row{{"myview"}, {"t1"}, {"view1"}},
2923+
},
2924+
{
2925+
Query: "rename table view1 to newViewName, t1 to newTableName",
2926+
Expected: []sql.Row{{types.OkResult{RowsAffected: 0}}},
2927+
},
2928+
{
2929+
Query: "show tables;",
2930+
Expected: []sql.Row{{"myview"}, {"newTableName"}, {"newViewName"}},
2931+
},
2932+
},
2933+
},
29052934
}
29062935

29072936
var SpatialScriptTests = []ScriptTest{

sql/analyzer/resolve_views.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"fmt"
1919

2020
"github.com/dolthub/go-mysql-server/sql"
21-
"github.com/dolthub/go-mysql-server/sql/mysql_db"
2221
"github.com/dolthub/go-mysql-server/sql/parse"
2322
"github.com/dolthub/go-mysql-server/sql/plan"
2423
"github.com/dolthub/go-mysql-server/sql/transform"
@@ -51,11 +50,7 @@ func resolveViews(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope, sel R
5150
return nil, transform.SameTree, err
5251
}
5352

54-
maybeVdb := db
55-
if privilegedDatabase, ok := maybeVdb.(mysql_db.PrivilegedDatabase); ok {
56-
maybeVdb = privilegedDatabase.Unwrap()
57-
}
58-
if vdb, vok := maybeVdb.(sql.ViewDatabase); vok {
53+
if vdb, vok := db.(sql.ViewDatabase); vok {
5954
viewDef, vdok, verr := vdb.GetViewDefinition(ctx, viewName)
6055
if verr != nil {
6156
return nil, transform.SameTree, verr

sql/errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,9 @@ var (
555555
// are automatically rolled back. Clients receiving this error must retry the transaction.
556556
ErrLockDeadlock = errors.NewKind("serialization failure: %s, try restarting transaction.")
557557

558+
// ErrViewsNotSupported is returned when attempting to access a view on a database that doesn't support them.
559+
ErrViewsNotSupported = errors.NewKind("database '%s' doesn't support views")
560+
558561
// ErrExistingView is returned when a CREATE VIEW statement uses a name that already exists
559562
ErrExistingView = errors.NewKind("the view %s.%s already exists")
560563

sql/information_schema/information_schema.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2793,9 +2793,6 @@ func viewsInDatabase(ctx *Context, db Database) ([]ViewDefinition, error) {
27932793
var views []ViewDefinition
27942794
dbName := db.Name()
27952795

2796-
if privilegedDatabase, ok := db.(mysql_db.PrivilegedDatabase); ok {
2797-
db = privilegedDatabase.Unwrap()
2798-
}
27992796
if vdb, ok := db.(ViewDatabase); ok {
28002797
dbViews, err := vdb.AllViews(ctx)
28012798
if err != nil {

sql/mysql_db/privileged_database_provider.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ var _ sql.TableCopierDatabase = PrivilegedDatabase{}
116116
var _ sql.ReadOnlyDatabase = PrivilegedDatabase{}
117117
var _ sql.TemporaryTableDatabase = PrivilegedDatabase{}
118118
var _ sql.CollatedDatabase = PrivilegedDatabase{}
119+
var _ sql.ViewDatabase = PrivilegedDatabase{}
119120

120121
// NewPrivilegedDatabase returns a new PrivilegedDatabase.
121122
func NewPrivilegedDatabase(grantTables *MySQLDb, db sql.Database) sql.Database {
@@ -384,6 +385,38 @@ func (pdb PrivilegedDatabase) UpdateEvent(ctx *sql.Context, ed sql.EventDefiniti
384385
return sql.ErrEventsNotSupported.New(pdb.db.Name())
385386
}
386387

388+
// CreateView implements sql.ViewDatabase
389+
func (pdb PrivilegedDatabase) CreateView(ctx *sql.Context, name string, selectStatement, createViewStmt string) error {
390+
if db, ok := pdb.db.(sql.ViewDatabase); ok {
391+
return db.CreateView(ctx, name, selectStatement, createViewStmt)
392+
}
393+
return sql.ErrViewsNotSupported.New(pdb.db.Name())
394+
}
395+
396+
// DropView implements sql.ViewDatabase
397+
func (pdb PrivilegedDatabase) DropView(ctx *sql.Context, name string) error {
398+
if db, ok := pdb.db.(sql.ViewDatabase); ok {
399+
return db.DropView(ctx, name)
400+
}
401+
return sql.ErrViewsNotSupported.New(pdb.db.Name())
402+
}
403+
404+
// GetViewDefinition implements sql.ViewDatabase
405+
func (pdb PrivilegedDatabase) GetViewDefinition(ctx *sql.Context, viewName string) (sql.ViewDefinition, bool, error) {
406+
if db, ok := pdb.db.(sql.ViewDatabase); ok {
407+
return db.GetViewDefinition(ctx, viewName)
408+
}
409+
return sql.ViewDefinition{}, false, sql.ErrViewsNotSupported.New(pdb.db.Name())
410+
}
411+
412+
// AllViews implements sql.ViewDatabase
413+
func (pdb PrivilegedDatabase) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) {
414+
if db, ok := pdb.db.(sql.ViewDatabase); ok {
415+
return db.AllViews(ctx)
416+
}
417+
return nil, sql.ErrViewsNotSupported.New(pdb.db.Name())
418+
}
419+
387420
// CopyTableData implements the interface sql.TableCopierDatabase.
388421
func (pdb PrivilegedDatabase) CopyTableData(ctx *sql.Context, sourceTable string, destinationTable string) (uint64, error) {
389422
if db, ok := pdb.db.(sql.TableCopierDatabase); ok {

sql/plan/alter_table.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,31 @@ func (r *RenameTable) String() string {
5353
return fmt.Sprintf("Rename table %s to %s", r.OldNames, r.NewNames)
5454
}
5555

56+
func (r *RenameTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
57+
// TODO: 'ALTER TABLE newViewName RENAME view1' should fail on renaming views - should be fixed in vitess
58+
renamer, _ := r.Db.(sql.TableRenamer)
59+
viewDb, _ := r.Db.(sql.ViewDatabase)
60+
viewRegistry := ctx.GetViewRegistry()
61+
62+
for i, oldName := range r.OldNames {
63+
if tbl, exists := r.tableExists(ctx, oldName); exists {
64+
err := r.renameTable(ctx, renamer, tbl, oldName, r.NewNames[i])
65+
if err != nil {
66+
return nil, err
67+
}
68+
} else {
69+
success, err := r.renameView(ctx, viewDb, viewRegistry, oldName, r.NewNames[i])
70+
if err != nil {
71+
return nil, err
72+
} else if !success {
73+
return nil, sql.ErrTableNotFound.New(oldName)
74+
}
75+
}
76+
}
77+
78+
return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(0))), nil
79+
}
80+
5681
func (r *RenameTable) WithChildren(children ...sql.Node) (sql.Node, error) {
5782
return NillaryWithChildren(r, children...)
5883
}
@@ -74,6 +99,106 @@ func (*RenameTable) CollationCoercibility(ctx *sql.Context) (collation sql.Colla
7499
return sql.Collation_binary, 7
75100
}
76101

102+
func (r *RenameTable) tableExists(ctx *sql.Context, name string) (sql.Table, bool) {
103+
tbl, ok, err := r.Db.GetTableInsensitive(ctx, name)
104+
if err != nil || !ok {
105+
return nil, false
106+
}
107+
return tbl, true
108+
}
109+
110+
func (r *RenameTable) renameTable(ctx *sql.Context, renamer sql.TableRenamer, tbl sql.Table, oldName, newName string) error {
111+
if renamer == nil {
112+
return sql.ErrRenameTableNotSupported.New(r.Db.Name())
113+
}
114+
115+
if fkTable, ok := tbl.(sql.ForeignKeyTable); ok {
116+
parentFks, err := fkTable.GetReferencedForeignKeys(ctx)
117+
if err != nil {
118+
return err
119+
}
120+
for _, parentFk := range parentFks {
121+
//TODO: support renaming tables across databases for foreign keys
122+
if strings.ToLower(parentFk.Database) != strings.ToLower(parentFk.ParentDatabase) {
123+
return fmt.Errorf("updating foreign key table names across databases is not yet supported")
124+
}
125+
parentFk.ParentTable = newName
126+
childTbl, ok, err := r.Db.GetTableInsensitive(ctx, parentFk.Table)
127+
if err != nil {
128+
return err
129+
}
130+
if !ok {
131+
return sql.ErrTableNotFound.New(parentFk.Table)
132+
}
133+
childFkTbl, ok := childTbl.(sql.ForeignKeyTable)
134+
if !ok {
135+
return fmt.Errorf("referenced table `%s` supports foreign keys but declaring table `%s` does not", parentFk.ParentTable, parentFk.Table)
136+
}
137+
err = childFkTbl.UpdateForeignKey(ctx, parentFk.Name, parentFk)
138+
if err != nil {
139+
return err
140+
}
141+
}
142+
143+
fks, err := fkTable.GetDeclaredForeignKeys(ctx)
144+
if err != nil {
145+
return err
146+
}
147+
for _, fk := range fks {
148+
fk.Table = newName
149+
err = fkTable.UpdateForeignKey(ctx, fk.Name, fk)
150+
if err != nil {
151+
return err
152+
}
153+
}
154+
}
155+
156+
err := renamer.RenameTable(ctx, oldName, newName)
157+
if err != nil {
158+
return err
159+
}
160+
161+
return nil
162+
}
163+
164+
func (r *RenameTable) renameView(ctx *sql.Context, viewDb sql.ViewDatabase, vr *sql.ViewRegistry, oldName, newName string) (bool, error) {
165+
if viewDb != nil {
166+
oldView, exists, err := viewDb.GetViewDefinition(ctx, oldName)
167+
if err != nil {
168+
return false, err
169+
} else if !exists {
170+
return false, nil
171+
}
172+
173+
err = viewDb.DropView(ctx, oldName)
174+
if err != nil {
175+
return false, err
176+
}
177+
178+
err = viewDb.CreateView(ctx, newName, oldView.TextDefinition, oldView.CreateViewStatement)
179+
if err != nil {
180+
return false, err
181+
}
182+
183+
return true, nil
184+
} else {
185+
view, exists := vr.View(r.Db.Name(), oldName)
186+
if !exists {
187+
return false, nil
188+
}
189+
190+
err := vr.Delete(r.Db.Name(), oldName)
191+
if err != nil {
192+
return false, nil
193+
}
194+
err = vr.Register(r.Db.Name(), sql.NewView(newName, view.Definition(), view.TextDefinition(), view.CreateStatement()))
195+
if err != nil {
196+
return false, nil
197+
}
198+
return true, nil
199+
}
200+
}
201+
77202
type AddColumn struct {
78203
ddlNode
79204
Table sql.Node

sql/rowexec/ddl.go

Lines changed: 3 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -233,72 +233,7 @@ func (b *BaseBuilder) buildDropCheck(ctx *sql.Context, n *plan.DropCheck, row sq
233233
}
234234

235235
func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, row sql.Row) (sql.RowIter, error) {
236-
renamer, ok := n.Db.(sql.TableRenamer)
237-
if !ok {
238-
return nil, sql.ErrRenameTableNotSupported.New(n.Db.Name())
239-
}
240-
241-
var err error
242-
for i, oldName := range n.OldNames {
243-
var tbl sql.Table
244-
var ok bool
245-
tbl, ok, err = n.Db.GetTableInsensitive(ctx, oldName)
246-
if err != nil {
247-
return nil, err
248-
}
249-
250-
if !ok {
251-
return nil, sql.ErrTableNotFound.New(oldName)
252-
}
253-
254-
if fkTable, ok := tbl.(sql.ForeignKeyTable); ok {
255-
parentFks, err := fkTable.GetReferencedForeignKeys(ctx)
256-
if err != nil {
257-
return nil, err
258-
}
259-
for _, parentFk := range parentFks {
260-
//TODO: support renaming tables across databases for foreign keys
261-
if strings.ToLower(parentFk.Database) != strings.ToLower(parentFk.ParentDatabase) {
262-
return nil, fmt.Errorf("updating foreign key table names across databases is not yet supported")
263-
}
264-
parentFk.ParentTable = n.NewNames[i]
265-
childTbl, ok, err := n.Db.GetTableInsensitive(ctx, parentFk.Table)
266-
if err != nil {
267-
return nil, err
268-
}
269-
if !ok {
270-
return nil, sql.ErrTableNotFound.New(parentFk.Table)
271-
}
272-
childFkTbl, ok := childTbl.(sql.ForeignKeyTable)
273-
if !ok {
274-
return nil, fmt.Errorf("referenced table `%s` supports foreign keys but declaring table `%s` does not", parentFk.ParentTable, parentFk.Table)
275-
}
276-
err = childFkTbl.UpdateForeignKey(ctx, parentFk.Name, parentFk)
277-
if err != nil {
278-
return nil, err
279-
}
280-
}
281-
282-
fks, err := fkTable.GetDeclaredForeignKeys(ctx)
283-
if err != nil {
284-
return nil, err
285-
}
286-
for _, fk := range fks {
287-
fk.Table = n.NewNames[i]
288-
err = fkTable.UpdateForeignKey(ctx, fk.Name, fk)
289-
if err != nil {
290-
return nil, err
291-
}
292-
}
293-
}
294-
295-
err = renamer.RenameTable(ctx, tbl.Name(), n.NewNames[i])
296-
if err != nil {
297-
return nil, err
298-
}
299-
}
300-
301-
return sql.RowsToRowIter(sql.NewRow(types.NewOkResult(0))), nil
236+
return n.RowIter(ctx, row)
302237
}
303238

304239
func (b *BaseBuilder) buildModifyColumn(ctx *sql.Context, n *plan.ModifyColumn, row sql.Row) (sql.RowIter, error) {
@@ -856,7 +791,6 @@ func (b *BaseBuilder) buildAlterDB(ctx *sql.Context, n *plan.AlterDB, row sql.Ro
856791

857792
func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, row sql.Row) (sql.RowIter, error) {
858793
var err error
859-
var vd sql.ViewDatabase
860794

861795
// If it's set to Invalid, then no collation has been explicitly defined
862796
if n.Collation == sql.Collation_Unspecified {
@@ -928,13 +862,11 @@ func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, ro
928862
return sql.RowsToRowIter(), err
929863
}
930864

931-
vd, _ = maybePrivDb.(sql.ViewDatabase)
932-
if vd != nil {
933-
_, ok, err := vd.GetViewDefinition(ctx, n.Name())
865+
if vdb, vok := n.Db.(sql.ViewDatabase); vok {
866+
_, ok, err := vdb.GetViewDefinition(ctx, n.Name())
934867
if err != nil {
935868
return nil, err
936869
}
937-
938870
if ok {
939871
return nil, sql.ErrTableAlreadyExists.New(n.Name())
940872
}

sql/rowexec/show.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,8 @@ func (b *BaseBuilder) buildShowTables(ctx *sql.Context, n *plan.ShowTables, row
193193
}
194194

195195
// TODO: currently there is no way to see views AS OF a particular time
196-
maybeVdb := n.Database()
197-
if privilegedDatabase, ok := maybeVdb.(mysql_db.PrivilegedDatabase); ok {
198-
maybeVdb = privilegedDatabase.Unwrap()
199-
}
200-
if vdb, ok := maybeVdb.(sql.ViewDatabase); ok {
196+
db := n.Database()
197+
if vdb, ok := db.(sql.ViewDatabase); ok {
201198
views, err := vdb.AllViews(ctx)
202199
if err != nil {
203200
return nil, err
@@ -211,7 +208,7 @@ func (b *BaseBuilder) buildShowTables(ctx *sql.Context, n *plan.ShowTables, row
211208
}
212209
}
213210

214-
for _, view := range ctx.GetViewRegistry().ViewsInDatabase(maybeVdb.Name()) {
211+
for _, view := range ctx.GetViewRegistry().ViewsInDatabase(db.Name()) {
215212
row := sql.Row{view.Name()}
216213
if n.Full {
217214
row = append(row, "VIEW")

0 commit comments

Comments
 (0)