@@ -17,6 +17,7 @@ package enginetest
1717import (
1818 "context"
1919 "fmt"
20+ "github.com/dolthub/go-mysql-server/sql/mysql_db/serial"
2021 "net"
2122 "strings"
2223 "testing"
@@ -1073,7 +1074,7 @@ func TestUserPrivileges(t *testing.T, h Harness) {
10731074 defer engine .Close ()
10741075
10751076 engine .Analyzer .Catalog .MySQLDb .AddRootAccount ()
1076- engine .Analyzer .Catalog .MySQLDb .SetPersister (mysql_db.NoopPersister {})
1077+ engine .Analyzer .Catalog .MySQLDb .SetPersister (& mysql_db.NoopPersister {})
10771078 rootCtx := harness .NewContextWithClient (sql.Client {
10781079 User : "root" ,
10791080 Address : "localhost" ,
@@ -5391,6 +5392,47 @@ func TestPrepared(t *testing.T, harness Harness) {
53915392 }
53925393}
53935394
5395+ type memoryPersister struct {
5396+ users []* mysql_db.User
5397+ roles []* mysql_db.RoleEdge
5398+ }
5399+
5400+ var _ mysql_db.MySQLDbPersistence = & memoryPersister {}
5401+
5402+ func (p * memoryPersister ) CanPersist () bool {
5403+ return true
5404+ }
5405+
5406+ func (p * memoryPersister ) Persist (ctx * sql.Context , data []byte ) error {
5407+ //erase everything from users and roles
5408+ p .users = make ([]* mysql_db.User , 0 )
5409+ p .roles = make ([]* mysql_db.RoleEdge , 0 )
5410+
5411+ // Deserialize the flatbuffer
5412+ serialMySQLDb := serial .GetRootAsMySQLDb (data , 0 )
5413+
5414+ // Fill in users
5415+ for i := 0 ; i < serialMySQLDb .UserLength (); i ++ {
5416+ serialUser := new (serial.User )
5417+ if ! serialMySQLDb .User (serialUser , i ) {
5418+ continue
5419+ }
5420+ user := mysql_db .LoadUser (serialUser )
5421+ p .users = append (p .users , user )
5422+ }
5423+
5424+ // Fill in roles
5425+ for i := 0 ; i < serialMySQLDb .RoleEdgesLength (); i ++ {
5426+ serialRoleEdge := new (serial.RoleEdge )
5427+ if ! serialMySQLDb .RoleEdges (serialRoleEdge , i ) {
5428+ continue
5429+ }
5430+ role := mysql_db .LoadRoleEdge (serialRoleEdge )
5431+ p .roles = append (p .roles , role )
5432+ }
5433+ return nil
5434+ }
5435+
53945436func TestPrivilegePersistence (t * testing.T , h Harness ) {
53955437 harness , ok := h .(ClientHarness )
53965438 if ! ok {
@@ -5399,106 +5441,73 @@ func TestPrivilegePersistence(t *testing.T, h Harness) {
53995441
54005442 engine := mustNewEngine (t , harness )
54015443 defer engine .Close ()
5444+
5445+ persister := & memoryPersister {}
54025446 engine .Analyzer .Catalog .MySQLDb .AddRootAccount ()
5403- engine .Analyzer .Catalog .MySQLDb .SetPersister (& mysql_db. NoopPersister {} )
5447+ engine .Analyzer .Catalog .MySQLDb .SetPersister (persister )
54045448 ctx := NewContextWithClient (harness , sql.Client {
54055449 User : "root" ,
54065450 Address : "localhost" ,
54075451 })
54085452
5409- var users []* mysql_db.User
5410- var roles []* mysql_db.RoleEdge
5411-
5412- // TODO: do I need this?
5413- //engine.Analyzer.Catalog.MySQLDb.SetPersistCallback(
5414- // func(ctx *sql.Context, buf []byte) error {
5415- // // erase everything from users and roles
5416- // users = make([]*mysql_db.User, 0)
5417- // roles = make([]*mysql_db.RoleEdge, 0)
5418- //
5419- // // Deserialize the flatbuffer
5420- // serialMySQLDb := serial.GetRootAsMySQLDb(buf, 0)
5421- //
5422- // // Fill in users
5423- // for i := 0; i < serialMySQLDb.UserLength(); i++ {
5424- // serialUser := new(serial.User)
5425- // if !serialMySQLDb.User(serialUser, i) {
5426- // continue
5427- // }
5428- // user := mysql_db.LoadUser(serialUser)
5429- // users = append(users, user)
5430- // }
5431- //
5432- // // Fill in roles
5433- // for i := 0; i < serialMySQLDb.RoleEdgesLength(); i++ {
5434- // serialRoleEdge := new(serial.RoleEdge)
5435- // if !serialMySQLDb.RoleEdges(serialRoleEdge, i) {
5436- // continue
5437- // }
5438- // role := mysql_db.LoadRoleEdge(serialRoleEdge)
5439- // roles = append(roles, role)
5440- // }
5441- // return nil
5442- // },
5443- //)
5444-
5453+ // TODO: need to clear user table after every RunQueryWithContext
54455454 RunQueryWithContext (t , engine , harness , ctx , "CREATE USER tester@localhost" )
54465455 // If the user exists in []*mysql_db.User, then it must be NOT nil.
5447- require .NotNil (t , findUser ("tester" , "localhost" , users ))
5456+ require .NotNil (t , findUser ("tester" , "localhost" , persister . users ))
54485457
54495458 RunQueryWithContext (t , engine , harness , ctx , "INSERT INTO mysql.user (Host, User) VALUES ('localhost', 'tester1')" )
5450- require .Nil (t , findUser ("tester1" , "localhost" , users ))
5459+ require .Nil (t , findUser ("tester1" , "localhost" , persister . users ))
54515460
54525461 RunQueryWithContext (t , engine , harness , ctx , "UPDATE mysql.user SET User = 'test_user' WHERE User = 'tester'" )
5453- require .NotNil (t , findUser ("tester" , "localhost" , users ))
5462+ require .NotNil (t , findUser ("tester" , "localhost" , persister . users ))
54545463
54555464 RunQueryWithContext (t , engine , harness , ctx , "FLUSH PRIVILEGES" )
5456- require .NotNil (t , findUser ("tester1" , "localhost" , users ))
5457- require .Nil (t , findUser ("tester" , "localhost" , users ))
5458- require .NotNil (t , findUser ("test_user" , "localhost" , users ))
5465+ require .NotNil (t , findUser ("tester1" , "localhost" , persister . users ))
5466+ require .Nil (t , findUser ("tester" , "localhost" , persister . users ))
5467+ require .NotNil (t , findUser ("test_user" , "localhost" , persister . users ))
54595468
54605469 RunQueryWithContext (t , engine , harness , ctx , "DELETE FROM mysql.user WHERE User = 'tester1'" )
5461- require .NotNil (t , findUser ("tester1" , "localhost" , users ))
5470+ require .NotNil (t , findUser ("tester1" , "localhost" , persister . users ))
54625471
54635472 RunQueryWithContext (t , engine , harness , ctx , "GRANT SELECT ON mydb.* TO test_user@localhost" )
5464- user := findUser ("test_user" , "localhost" , users )
5473+ user := findUser ("test_user" , "localhost" , persister . users )
54655474 require .True (t , user .PrivilegeSet .Database ("mydb" ).Has (sql .PrivilegeType_Select ))
54665475
54675476 RunQueryWithContext (t , engine , harness , ctx , "UPDATE mysql.db SET Insert_priv = 'Y' WHERE User = 'test_user'" )
54685477 require .False (t , user .PrivilegeSet .Database ("mydb" ).Has (sql .PrivilegeType_Insert ))
54695478
54705479 RunQueryWithContext (t , engine , harness , ctx , "CREATE USER dolt@localhost" )
54715480 RunQueryWithContext (t , engine , harness , ctx , "INSERT INTO mysql.db (Host, Db, User, Select_priv) VALUES ('localhost', 'mydb', 'dolt', 'Y')" )
5472- user1 := findUser ("dolt" , "localhost" , users )
5481+ user1 := findUser ("dolt" , "localhost" , persister . users )
54735482 require .NotNil (t , user1 )
54745483 require .False (t , user1 .PrivilegeSet .Database ("mydb" ).Has (sql .PrivilegeType_Select ))
54755484
54765485 RunQueryWithContext (t , engine , harness , ctx , "FLUSH PRIVILEGES" )
5477- require .Nil (t , findUser ("tester1" , "localhost" , users ))
5478- user = findUser ("test_user" , "localhost" , users )
5486+ require .Nil (t , findUser ("tester1" , "localhost" , persister . users ))
5487+ user = findUser ("test_user" , "localhost" , persister . users )
54795488 require .True (t , user .PrivilegeSet .Database ("mydb" ).Has (sql .PrivilegeType_Insert ))
5480- user1 = findUser ("dolt" , "localhost" , users )
5489+ user1 = findUser ("dolt" , "localhost" , persister . users )
54815490 require .True (t , user1 .PrivilegeSet .Database ("mydb" ).Has (sql .PrivilegeType_Select ))
54825491
54835492 RunQueryWithContext (t , engine , harness , ctx , "CREATE ROLE test_role" )
54845493 RunQueryWithContext (t , engine , harness , ctx , "GRANT SELECT ON *.* TO test_role" )
5485- require .Zero (t , len (roles ))
5494+ require .Zero (t , len (persister . roles ))
54865495 RunQueryWithContext (t , engine , harness , ctx , "GRANT test_role TO test_user@localhost" )
5487- require .NotZero (t , len (roles ))
5496+ require .NotZero (t , len (persister . roles ))
54885497
54895498 RunQueryWithContext (t , engine , harness , ctx , "UPDATE mysql.role_edges SET to_user = 'tester2' WHERE to_user = 'test_user'" )
5490- require .NotNil (t , findRole ("test_user" , roles ))
5491- require .Nil (t , findRole ("tester2" , roles ))
5499+ require .NotNil (t , findRole ("test_user" , persister . roles ))
5500+ require .Nil (t , findRole ("tester2" , persister . roles ))
54925501
54935502 RunQueryWithContext (t , engine , harness , ctx , "FLUSH PRIVILEGES" )
5494- require .Nil (t , findRole ("test_user" , roles ))
5495- require .NotNil (t , findRole ("tester2" , roles ))
5503+ require .Nil (t , findRole ("test_user" , persister . roles ))
5504+ require .NotNil (t , findRole ("tester2" , persister . roles ))
54965505
54975506 RunQueryWithContext (t , engine , harness , ctx , "INSERT INTO mysql.role_edges VALUES ('%', 'test_role', 'localhost', 'test_user', 'N')" )
5498- require .Nil (t , findRole ("test_user" , roles ))
5507+ require .Nil (t , findRole ("test_user" , persister . roles ))
54995508
55005509 RunQueryWithContext (t , engine , harness , ctx , "FLUSH PRIVILEGES" )
5501- require .NotNil (t , findRole ("test_user" , roles ))
5510+ require .NotNil (t , findRole ("test_user" , persister . roles ))
55025511
55035512 _ , _ , err := engine .Query (ctx , "FLUSH NO_WRITE_TO_BINLOG PRIVILEGES" )
55045513 require .Error (t , err )
0 commit comments