Skip to content

Commit 97a1fef

Browse files
authored
Merge pull request #1049 from dolthub/james/mysql
[no-release-notes] add persist interface
2 parents 37da4c1 + 1e6a1d3 commit 97a1fef

File tree

5 files changed

+111
-67
lines changed

5 files changed

+111
-67
lines changed

enginetest/enginetests.go

Lines changed: 67 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,7 @@ func TestUserPrivileges(t *testing.T, h Harness) {
10141014
Address: "localhost",
10151015
})
10161016
engine.Analyzer.Catalog.MySQLDb.AddRootAccount()
1017+
engine.Analyzer.Catalog.MySQLDb.SetPersister(&mysql_db.NoopPersister{})
10171018

10181019
for _, statement := range script.SetUpScript {
10191020
if sh, ok := harness.(SkippingHarness); ok {
@@ -1073,6 +1074,7 @@ func TestUserPrivileges(t *testing.T, h Harness) {
10731074
defer engine.Close()
10741075

10751076
engine.Analyzer.Catalog.MySQLDb.AddRootAccount()
1077+
engine.Analyzer.Catalog.MySQLDb.SetPersister(&mysql_db.NoopPersister{})
10761078
rootCtx := harness.NewContextWithClient(sql.Client{
10771079
User: "root",
10781080
Address: "localhost",
@@ -1173,6 +1175,7 @@ func TestUserAuthentication(t *testing.T, h Harness) {
11731175
engine := mustNewEngine(t, harness)
11741176
defer engine.Close()
11751177
engine.Analyzer.Catalog.MySQLDb.AddRootAccount()
1178+
engine.Analyzer.Catalog.MySQLDb.SetPersister(&mysql_db.NoopPersister{})
11761179
if script.SetUpFunc != nil {
11771180
script.SetUpFunc(ctx, t, engine)
11781181
}
@@ -5389,6 +5392,47 @@ func TestPrepared(t *testing.T, harness Harness) {
53895392
}
53905393
}
53915394

5395+
type memoryPersister struct {
5396+
users []*mysql_db.User
5397+
roles []*mysql_db.RoleEdge
5398+
}
5399+
5400+
var _ mysql_db.MySQLDbPersistence = &memoryPersister{}
5401+
5402+
func (p *memoryPersister) ValidateCanPersist() error {
5403+
return nil
5404+
}
5405+
5406+
func (p *memoryPersister) Persist(ctx *sql.Context, data []byte) error {
5407+
//erase everything from users and roles
5408+
p.users = make([]*mysql_db.User, 0)
5409+
p.roles = make([]*mysql_db.RoleEdge, 0)
5410+
5411+
// Deserialize the flatbuffer
5412+
serialMySQLDb := serial.GetRootAsMySQLDb(data, 0)
5413+
5414+
// Fill in users
5415+
for i := 0; i < serialMySQLDb.UserLength(); i++ {
5416+
serialUser := new(serial.User)
5417+
if !serialMySQLDb.User(serialUser, i) {
5418+
continue
5419+
}
5420+
user := mysql_db.LoadUser(serialUser)
5421+
p.users = append(p.users, user)
5422+
}
5423+
5424+
// Fill in roles
5425+
for i := 0; i < serialMySQLDb.RoleEdgesLength(); i++ {
5426+
serialRoleEdge := new(serial.RoleEdge)
5427+
if !serialMySQLDb.RoleEdges(serialRoleEdge, i) {
5428+
continue
5429+
}
5430+
role := mysql_db.LoadRoleEdge(serialRoleEdge)
5431+
p.roles = append(p.roles, role)
5432+
}
5433+
return nil
5434+
}
5435+
53925436
func TestPrivilegePersistence(t *testing.T, h Harness) {
53935437
harness, ok := h.(ClientHarness)
53945438
if !ok {
@@ -5397,103 +5441,72 @@ func TestPrivilegePersistence(t *testing.T, h Harness) {
53975441

53985442
engine := mustNewEngine(t, harness)
53995443
defer engine.Close()
5444+
5445+
persister := &memoryPersister{}
54005446
engine.Analyzer.Catalog.MySQLDb.AddRootAccount()
5447+
engine.Analyzer.Catalog.MySQLDb.SetPersister(persister)
54015448
ctx := NewContextWithClient(harness, sql.Client{
54025449
User: "root",
54035450
Address: "localhost",
54045451
})
54055452

5406-
var users []*mysql_db.User
5407-
var roles []*mysql_db.RoleEdge
5408-
engine.Analyzer.Catalog.MySQLDb.SetPersistCallback(
5409-
func(ctx *sql.Context, buf []byte) error {
5410-
// erase everything from users and roles
5411-
users = make([]*mysql_db.User, 0)
5412-
roles = make([]*mysql_db.RoleEdge, 0)
5413-
5414-
// Deserialize the flatbuffer
5415-
serialMySQLDb := serial.GetRootAsMySQLDb(buf, 0)
5416-
5417-
// Fill in users
5418-
for i := 0; i < serialMySQLDb.UserLength(); i++ {
5419-
serialUser := new(serial.User)
5420-
if !serialMySQLDb.User(serialUser, i) {
5421-
continue
5422-
}
5423-
user := mysql_db.LoadUser(serialUser)
5424-
users = append(users, user)
5425-
}
5426-
5427-
// Fill in roles
5428-
for i := 0; i < serialMySQLDb.RoleEdgesLength(); i++ {
5429-
serialRoleEdge := new(serial.RoleEdge)
5430-
if !serialMySQLDb.RoleEdges(serialRoleEdge, i) {
5431-
continue
5432-
}
5433-
role := mysql_db.LoadRoleEdge(serialRoleEdge)
5434-
roles = append(roles, role)
5435-
}
5436-
return nil
5437-
},
5438-
)
5439-
54405453
RunQueryWithContext(t, engine, harness, ctx, "CREATE USER tester@localhost")
54415454
// If the user exists in []*mysql_db.User, then it must be NOT nil.
5442-
require.NotNil(t, findUser("tester", "localhost", users))
5455+
require.NotNil(t, findUser("tester", "localhost", persister.users))
54435456

54445457
RunQueryWithContext(t, engine, harness, ctx, "INSERT INTO mysql.user (Host, User) VALUES ('localhost', 'tester1')")
5445-
require.Nil(t, findUser("tester1", "localhost", users))
5458+
require.Nil(t, findUser("tester1", "localhost", persister.users))
54465459

54475460
RunQueryWithContext(t, engine, harness, ctx, "UPDATE mysql.user SET User = 'test_user' WHERE User = 'tester'")
5448-
require.NotNil(t, findUser("tester", "localhost", users))
5461+
require.NotNil(t, findUser("tester", "localhost", persister.users))
54495462

54505463
RunQueryWithContext(t, engine, harness, ctx, "FLUSH PRIVILEGES")
5451-
require.NotNil(t, findUser("tester1", "localhost", users))
5452-
require.Nil(t, findUser("tester", "localhost", users))
5453-
require.NotNil(t, findUser("test_user", "localhost", users))
5464+
require.NotNil(t, findUser("tester1", "localhost", persister.users))
5465+
require.Nil(t, findUser("tester", "localhost", persister.users))
5466+
require.NotNil(t, findUser("test_user", "localhost", persister.users))
54545467

54555468
RunQueryWithContext(t, engine, harness, ctx, "DELETE FROM mysql.user WHERE User = 'tester1'")
5456-
require.NotNil(t, findUser("tester1", "localhost", users))
5469+
require.NotNil(t, findUser("tester1", "localhost", persister.users))
54575470

54585471
RunQueryWithContext(t, engine, harness, ctx, "GRANT SELECT ON mydb.* TO test_user@localhost")
5459-
user := findUser("test_user", "localhost", users)
5472+
user := findUser("test_user", "localhost", persister.users)
54605473
require.True(t, user.PrivilegeSet.Database("mydb").Has(sql.PrivilegeType_Select))
54615474

54625475
RunQueryWithContext(t, engine, harness, ctx, "UPDATE mysql.db SET Insert_priv = 'Y' WHERE User = 'test_user'")
54635476
require.False(t, user.PrivilegeSet.Database("mydb").Has(sql.PrivilegeType_Insert))
54645477

54655478
RunQueryWithContext(t, engine, harness, ctx, "CREATE USER dolt@localhost")
54665479
RunQueryWithContext(t, engine, harness, ctx, "INSERT INTO mysql.db (Host, Db, User, Select_priv) VALUES ('localhost', 'mydb', 'dolt', 'Y')")
5467-
user1 := findUser("dolt", "localhost", users)
5480+
user1 := findUser("dolt", "localhost", persister.users)
54685481
require.NotNil(t, user1)
54695482
require.False(t, user1.PrivilegeSet.Database("mydb").Has(sql.PrivilegeType_Select))
54705483

54715484
RunQueryWithContext(t, engine, harness, ctx, "FLUSH PRIVILEGES")
5472-
require.Nil(t, findUser("tester1", "localhost", users))
5473-
user = findUser("test_user", "localhost", users)
5485+
require.Nil(t, findUser("tester1", "localhost", persister.users))
5486+
user = findUser("test_user", "localhost", persister.users)
54745487
require.True(t, user.PrivilegeSet.Database("mydb").Has(sql.PrivilegeType_Insert))
5475-
user1 = findUser("dolt", "localhost", users)
5488+
user1 = findUser("dolt", "localhost", persister.users)
54765489
require.True(t, user1.PrivilegeSet.Database("mydb").Has(sql.PrivilegeType_Select))
54775490

54785491
RunQueryWithContext(t, engine, harness, ctx, "CREATE ROLE test_role")
54795492
RunQueryWithContext(t, engine, harness, ctx, "GRANT SELECT ON *.* TO test_role")
5480-
require.Zero(t, len(roles))
5493+
require.Zero(t, len(persister.roles))
54815494
RunQueryWithContext(t, engine, harness, ctx, "GRANT test_role TO test_user@localhost")
5482-
require.NotZero(t, len(roles))
5495+
require.NotZero(t, len(persister.roles))
54835496

54845497
RunQueryWithContext(t, engine, harness, ctx, "UPDATE mysql.role_edges SET to_user = 'tester2' WHERE to_user = 'test_user'")
5485-
require.NotNil(t, findRole("test_user", roles))
5486-
require.Nil(t, findRole("tester2", roles))
5498+
require.NotNil(t, findRole("test_user", persister.roles))
5499+
require.Nil(t, findRole("tester2", persister.roles))
54875500

54885501
RunQueryWithContext(t, engine, harness, ctx, "FLUSH PRIVILEGES")
5489-
require.Nil(t, findRole("test_user", roles))
5490-
require.NotNil(t, findRole("tester2", roles))
5502+
require.Nil(t, findRole("test_user", persister.roles))
5503+
require.NotNil(t, findRole("tester2", persister.roles))
54915504

54925505
RunQueryWithContext(t, engine, harness, ctx, "INSERT INTO mysql.role_edges VALUES ('%', 'test_role', 'localhost', 'test_user', 'N')")
5493-
require.Nil(t, findRole("test_user", roles))
5506+
require.Nil(t, findRole("test_user", persister.roles))
54945507

54955508
RunQueryWithContext(t, engine, harness, ctx, "FLUSH PRIVILEGES")
5496-
require.NotNil(t, findRole("test_user", roles))
5509+
require.NotNil(t, findRole("test_user", persister.roles))
54975510

54985511
_, _, err := engine.Query(ctx, "FLUSH NO_WRITE_TO_BINLOG PRIVILEGES")
54995512
require.Error(t, err)

sql/mysql_db/mysql_db.go

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,26 @@ import (
3030
"github.com/dolthub/go-mysql-server/sql/mysql_db/serial"
3131
)
3232

33-
// PersistCallback represents the callback that will be called when the Grant Tables have been updated and need to be
34-
// persisted.
35-
type PersistCallback func(ctx *sql.Context, data []byte) error
33+
// MySQLDbPersistence is used to determine the behavior of how certain tables in MySQLDb will be persisted.
34+
type MySQLDbPersistence interface {
35+
ValidateCanPersist() error
36+
Persist(ctx *sql.Context, data []byte) error
37+
}
38+
39+
// NoopPersister is used when nothing in mysql db should be persisted
40+
type NoopPersister struct{}
41+
42+
var _ MySQLDbPersistence = &NoopPersister{}
43+
44+
// CanPersist implements the MySQLDbPersistence interface
45+
func (p *NoopPersister) ValidateCanPersist() error {
46+
return nil
47+
}
48+
49+
// Persist implements the MySQLDbPersistence interface
50+
func (p *NoopPersister) Persist(ctx *sql.Context, data []byte) error {
51+
return nil
52+
}
3653

3754
// MySQLDb are the collection of tables that are in the MySQL database
3855
type MySQLDb struct {
@@ -50,7 +67,7 @@ type MySQLDb struct {
5067
//default_roles *mysqlTable
5168
//password_history *mysqlTable
5269

53-
persistFunc PersistCallback
70+
persister MySQLDbPersistence
5471
}
5572

5673
var _ sql.Database = (*MySQLDb)(nil)
@@ -136,9 +153,9 @@ func (t *MySQLDb) LoadData(ctx *sql.Context, buf []byte) error {
136153
return nil
137154
}
138155

139-
// SetPersistCallback sets the callback to be used when the MySQL Db tables have been updated and need to be persisted.
140-
func (t *MySQLDb) SetPersistCallback(persistFunc PersistCallback) {
141-
t.persistFunc = persistFunc
156+
// SetPersister sets the custom persister to be used when the MySQL Db tables have been updated and need to be persisted.
157+
func (t *MySQLDb) SetPersister(persister MySQLDbPersistence) {
158+
t.persister = persister
142159
}
143160

144161
// AddRootAccount adds the root account to the list of accounts.
@@ -331,13 +348,13 @@ func (t *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.Ge
331348
return nil, fmt.Errorf(`the only user login interface currently supported is "mysql_native_password"`)
332349
}
333350

351+
// CanPersist calls the persister's CanPersist method
352+
func (t *MySQLDb) ValidateCanPersist() error {
353+
return t.persister.ValidateCanPersist()
354+
}
355+
334356
// Persist passes along all changes to the integrator.
335357
func (t *MySQLDb) Persist(ctx *sql.Context) error {
336-
// Do nothing if persist function is nil
337-
if t.persistFunc == nil {
338-
return nil
339-
}
340-
341358
// Extract all user entries from table, and sort
342359
userEntries := t.user.data.ToSlice(ctx)
343360
users := make([]*User, len(userEntries))
@@ -387,7 +404,7 @@ func (t *MySQLDb) Persist(ctx *sql.Context) error {
387404
b.Finish(mysqlDbOffset)
388405

389406
// Persist data
390-
return t.persistFunc(ctx, b.FinishedBytes())
407+
return t.persister.Persist(ctx, b.FinishedBytes())
391408
}
392409

393410
// UserTable returns the "user" table.

sql/plan/create_role.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ func (n *CreateRole) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error)
106106
if !ok {
107107
return nil, sql.ErrDatabaseNotFound.New("mysql")
108108
}
109+
110+
// Check if you can even persist in the first place
111+
if err := mysqlDb.ValidateCanPersist(); err != nil {
112+
return nil, err
113+
}
114+
109115
userTableData := mysqlDb.UserTable().Data()
110116
for _, role := range n.Roles {
111117
userPk := mysql_db.UserPrimaryKey{

sql/plan/create_user.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ func (n *CreateUser) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error)
100100
if !ok {
101101
return nil, sql.ErrDatabaseNotFound.New("mysql")
102102
}
103+
// Check if you can even persist in the first place
104+
if err := mysqlDb.ValidateCanPersist(); err != nil {
105+
return nil, err
106+
}
103107
userTableData := mysqlDb.UserTable().Data()
104108
for _, user := range n.Users {
105109
userPk := mysql_db.UserPrimaryKey{

sql/plan/grant.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ func (n *Grant) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
201201
if !ok {
202202
return nil, sql.ErrDatabaseNotFound.New("mysql")
203203
}
204+
// Check if you can even persist in the first place
205+
if err := mysqlDb.ValidateCanPersist(); err != nil {
206+
return nil, err
207+
}
204208
if n.PrivilegeLevel.Database == "*" && n.PrivilegeLevel.TableRoutine == "*" {
205209
if n.ObjectType != ObjectType_Any {
206210
return nil, sql.ErrGrantRevokeIllegalPrivilege.New()

0 commit comments

Comments
 (0)