Skip to content

Commit a942e4a

Browse files
[management] use readlock on add peer (#4308)
1 parent 1022a50 commit a942e4a

File tree

4 files changed

+91
-19
lines changed

4 files changed

+91
-19
lines changed

management/server/grpcserver.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
913913

914914
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
915915
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
916+
start := time.Now()
916917

917918
empty := &proto.Empty{}
918919
peerKey, err := s.parseRequest(ctx, req, empty)
@@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
944945

945946
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
946947

947-
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String())
948+
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
948949

949950
return &proto.Empty{}, nil
950951
}

management/server/peer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
609609
newPeer.DNSLabel = freeLabel
610610
newPeer.IP = freeIP
611611

612-
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
612+
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
613613
defer func() {
614614
if unlock != nil {
615615
unlock()

management/server/peer_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
14761476

14771477
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
14781478
engine := os.Getenv("NETBIRD_STORE_ENGINE")
1479-
if engine == "sqlite" || engine == "" {
1480-
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
1479+
if engine == "sqlite" || engine == "mysql" || engine == "" {
1480+
// we intentionally disabled foreign keys in mysql
1481+
t.Skip("Skipping test because store is not respecting foreign keys")
14811482
}
14821483
if runtime.GOOS == "windows" {
14831484
t.Skip("The SQLite store is not properly supported by Windows yet")

management/server/store/sql_store.go

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"gorm.io/gorm/logger"
2525

2626
nbdns "github.com/netbirdio/netbird/dns"
27+
nbcontext "github.com/netbirdio/netbird/management/server/context"
2728
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
2829
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
2930
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
7677
conns = runtime.NumCPU()
7778
}
7879

79-
if storeEngine == types.SqliteStoreEngine {
80+
switch storeEngine {
81+
case types.MysqlStoreEngine:
82+
if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil {
83+
return nil, err
84+
}
85+
case types.SqliteStoreEngine:
8086
if err == nil {
8187
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
8288
}
@@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
142148
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
143149
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
144150

145-
start := time.Now()
151+
startWait := time.Now()
146152
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
147153
mtx := value.(*sync.RWMutex)
148154
mtx.Lock()
155+
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
156+
startHold := time.Now()
149157

150158
unlock = func() {
151159
mtx.Unlock()
152-
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
160+
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
153161
}
154162

155163
return unlock
@@ -159,14 +167,16 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
159167
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
160168
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
161169

162-
start := time.Now()
170+
startWait := time.Now()
163171
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
164172
mtx := value.(*sync.RWMutex)
165173
mtx.RLock()
174+
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
175+
startHold := time.Now()
166176

167177
unlock = func() {
168178
mtx.RUnlock()
169-
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
179+
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
170180
}
171181

172182
return unlock
@@ -604,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
604614
}
605615

606616
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
617+
ctx, cancel := getDebuggingCtx(ctx)
618+
defer cancel()
619+
607620
tx := s.db
608621
if lockStrength != LockingStrengthNone {
609622
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
610623
}
611624

612625
var user types.User
613-
result := tx.Take(&user, idQueryCondition, userID)
626+
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
614627
if result.Error != nil {
615628
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
616629
return nil, status.NewUserNotFoundError(userID)
@@ -1076,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
10761089
}
10771090

10781091
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
1092+
ctx, cancel := getDebuggingCtx(ctx)
1093+
defer cancel()
1094+
10791095
tx := s.db
10801096
if lockStrength != LockingStrengthNone {
10811097
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
10821098
}
10831099

10841100
var accountNetwork types.AccountNetwork
1085-
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
1101+
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
10861102
if errors.Is(err, gorm.ErrRecordNotFound) {
10871103
return nil, status.NewAccountNotFoundError(accountID)
10881104
}
@@ -1092,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
10921108
}
10931109

10941110
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
1111+
ctx, cancel := getDebuggingCtx(ctx)
1112+
defer cancel()
1113+
10951114
tx := s.db
10961115
if lockStrength != LockingStrengthNone {
10971116
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
10981117
}
10991118

11001119
var peer nbpeer.Peer
1101-
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
1120+
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
11021121

11031122
if result.Error != nil {
11041123
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -1147,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
11471166

11481167
// SaveUserLastLogin stores the last login time for a user in DB.
11491168
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
1169+
ctx, cancel := getDebuggingCtx(ctx)
1170+
defer cancel()
1171+
11501172
var user types.User
1151-
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
1173+
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
11521174
if result.Error != nil {
11531175
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
11541176
return status.NewUserNotFoundError(userID)
@@ -1329,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
13291351
}
13301352

13311353
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
1354+
ctx, cancel := getDebuggingCtx(ctx)
1355+
defer cancel()
1356+
13321357
tx := s.db
13331358
if lockStrength != LockingStrengthNone {
13341359
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
13351360
}
13361361

13371362
var setupKey types.SetupKey
1338-
result := tx.
1363+
result := tx.WithContext(ctx).
13391364
Take(&setupKey, GetKeyQueryCondition(s), key)
13401365

13411366
if result.Error != nil {
@@ -1349,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
13491374
}
13501375

13511376
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
1352-
result := s.db.Model(&types.SetupKey{}).
1377+
ctx, cancel := getDebuggingCtx(ctx)
1378+
defer cancel()
1379+
1380+
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
13531381
Where(idQueryCondition, setupKeyID).
13541382
Updates(map[string]interface{}{
13551383
"used_times": gorm.Expr("used_times + 1"),
@@ -1369,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
13691397

13701398
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
13711399
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
1400+
ctx, cancel := getDebuggingCtx(ctx)
1401+
defer cancel()
1402+
13721403
var groupID string
1373-
_ = s.db.Model(types.Group{}).
1404+
_ = s.db.WithContext(ctx).Model(types.Group{}).
13741405
Select("id").
13751406
Where("account_id = ? AND name = ?", accountID, "All").
13761407
Limit(1).
@@ -1398,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
13981429

13991430
// AddPeerToGroup adds a peer to a group
14001431
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
1432+
ctx, cancel := getDebuggingCtx(ctx)
1433+
defer cancel()
1434+
14011435
peer := &types.GroupPeer{
14021436
AccountID: accountID,
14031437
GroupID: groupID,
14041438
PeerID: peerID,
14051439
}
14061440

1407-
err := s.db.Clauses(clause.OnConflict{
1441+
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
14081442
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
14091443
DoNothing: true,
14101444
}).Create(peer).Error
@@ -1594,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
15941628
}
15951629

15961630
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
1597-
if err := s.db.Create(peer).Error; err != nil {
1631+
ctx, cancel := getDebuggingCtx(ctx)
1632+
defer cancel()
1633+
1634+
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
15981635
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
15991636
}
16001637

@@ -1720,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
17201757
}
17211758

17221759
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
1723-
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
1760+
ctx, cancel := getDebuggingCtx(ctx)
1761+
defer cancel()
1762+
1763+
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
17241764
if result.Error != nil {
17251765
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
17261766
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -2762,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
27622802

27632803
return groupPeers, nil
27642804
}
2805+
2806+
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
2807+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
2808+
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
2809+
if ok {
2810+
//nolint
2811+
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
2812+
}
2813+
2814+
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
2815+
if ok {
2816+
//nolint
2817+
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
2818+
}
2819+
2820+
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
2821+
if ok {
2822+
//nolint
2823+
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
2824+
}
2825+
2826+
go func() {
2827+
select {
2828+
case <-ctx.Done():
2829+
case <-grpcCtx.Done():
2830+
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
2831+
}
2832+
}()
2833+
return ctx, cancel
2834+
}

0 commit comments

Comments
 (0)