@@ -24,6 +24,7 @@ import (
24
24
"gorm.io/gorm/logger"
25
25
26
26
nbdns "github.com/netbirdio/netbird/dns"
27
+ nbcontext "github.com/netbirdio/netbird/management/server/context"
27
28
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
28
29
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
29
30
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
76
77
conns = runtime .NumCPU ()
77
78
}
78
79
79
- if storeEngine == types .SqliteStoreEngine {
80
+ switch storeEngine {
81
+ case types .MysqlStoreEngine :
82
+ if err := db .Exec ("SET GLOBAL FOREIGN_KEY_CHECKS = 0" ).Error ; err != nil {
83
+ return nil , err
84
+ }
85
+ case types .SqliteStoreEngine :
80
86
if err == nil {
81
87
log .WithContext (ctx ).Warnf ("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1" )
82
88
}
@@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
142
148
func (s * SqlStore ) AcquireWriteLockByUID (ctx context.Context , uniqueID string ) (unlock func ()) {
143
149
log .WithContext (ctx ).Tracef ("acquiring write lock for ID %s" , uniqueID )
144
150
145
- start := time .Now ()
151
+ startWait := time .Now ()
146
152
value , _ := s .resourceLocks .LoadOrStore (uniqueID , & sync.RWMutex {})
147
153
mtx := value .(* sync.RWMutex )
148
154
mtx .Lock ()
155
+ log .WithContext (ctx ).Tracef ("waiting to acquire write lock for ID %s in %v" , uniqueID , time .Since (startWait ))
156
+ startHold := time .Now ()
149
157
150
158
unlock = func () {
151
159
mtx .Unlock ()
152
- log .WithContext (ctx ).Tracef ("released write lock for ID %s in %v" , uniqueID , time .Since (start ))
160
+ log .WithContext (ctx ).Tracef ("released write lock for ID %s in %v" , uniqueID , time .Since (startHold ))
153
161
}
154
162
155
163
return unlock
@@ -159,14 +167,16 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
159
167
func (s * SqlStore ) AcquireReadLockByUID (ctx context.Context , uniqueID string ) (unlock func ()) {
160
168
log .WithContext (ctx ).Tracef ("acquiring read lock for ID %s" , uniqueID )
161
169
162
- start := time .Now ()
170
+ startWait := time .Now ()
163
171
value , _ := s .resourceLocks .LoadOrStore (uniqueID , & sync.RWMutex {})
164
172
mtx := value .(* sync.RWMutex )
165
173
mtx .RLock ()
174
+ log .WithContext (ctx ).Tracef ("waiting to acquire read lock for ID %s in %v" , uniqueID , time .Since (startWait ))
175
+ startHold := time .Now ()
166
176
167
177
unlock = func () {
168
178
mtx .RUnlock ()
169
- log .WithContext (ctx ).Tracef ("released read lock for ID %s in %v" , uniqueID , time .Since (start ))
179
+ log .WithContext (ctx ).Tracef ("released read lock for ID %s in %v" , uniqueID , time .Since (startHold ))
170
180
}
171
181
172
182
return unlock
@@ -604,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
604
614
}
605
615
606
616
func (s * SqlStore ) GetUserByUserID (ctx context.Context , lockStrength LockingStrength , userID string ) (* types.User , error ) {
617
+ ctx , cancel := getDebuggingCtx (ctx )
618
+ defer cancel ()
619
+
607
620
tx := s .db
608
621
if lockStrength != LockingStrengthNone {
609
622
tx = tx .Clauses (clause.Locking {Strength : string (lockStrength )})
610
623
}
611
624
612
625
var user types.User
613
- result := tx .Take (& user , idQueryCondition , userID )
626
+ result := tx .WithContext ( ctx ). Take (& user , idQueryCondition , userID )
614
627
if result .Error != nil {
615
628
if errors .Is (result .Error , gorm .ErrRecordNotFound ) {
616
629
return nil , status .NewUserNotFoundError (userID )
@@ -1076,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
1076
1089
}
1077
1090
1078
1091
func (s * SqlStore ) GetAccountNetwork (ctx context.Context , lockStrength LockingStrength , accountID string ) (* types.Network , error ) {
1092
+ ctx , cancel := getDebuggingCtx (ctx )
1093
+ defer cancel ()
1094
+
1079
1095
tx := s .db
1080
1096
if lockStrength != LockingStrengthNone {
1081
1097
tx = tx .Clauses (clause.Locking {Strength : string (lockStrength )})
1082
1098
}
1083
1099
1084
1100
var accountNetwork types.AccountNetwork
1085
- if err := tx .Model (& types.Account {}).Where (idQueryCondition , accountID ).Take (& accountNetwork ).Error ; err != nil {
1101
+ if err := tx .WithContext ( ctx ). Model (& types.Account {}).Where (idQueryCondition , accountID ).Take (& accountNetwork ).Error ; err != nil {
1086
1102
if errors .Is (err , gorm .ErrRecordNotFound ) {
1087
1103
return nil , status .NewAccountNotFoundError (accountID )
1088
1104
}
@@ -1092,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
1092
1108
}
1093
1109
1094
1110
func (s * SqlStore ) GetPeerByPeerPubKey (ctx context.Context , lockStrength LockingStrength , peerKey string ) (* nbpeer.Peer , error ) {
1111
+ ctx , cancel := getDebuggingCtx (ctx )
1112
+ defer cancel ()
1113
+
1095
1114
tx := s .db
1096
1115
if lockStrength != LockingStrengthNone {
1097
1116
tx = tx .Clauses (clause.Locking {Strength : string (lockStrength )})
1098
1117
}
1099
1118
1100
1119
var peer nbpeer.Peer
1101
- result := tx .Take (& peer , GetKeyQueryCondition (s ), peerKey )
1120
+ result := tx .WithContext ( ctx ). Take (& peer , GetKeyQueryCondition (s ), peerKey )
1102
1121
1103
1122
if result .Error != nil {
1104
1123
if errors .Is (result .Error , gorm .ErrRecordNotFound ) {
@@ -1147,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
1147
1166
1148
1167
// SaveUserLastLogin stores the last login time for a user in DB.
1149
1168
func (s * SqlStore ) SaveUserLastLogin (ctx context.Context , accountID , userID string , lastLogin time.Time ) error {
1169
+ ctx , cancel := getDebuggingCtx (ctx )
1170
+ defer cancel ()
1171
+
1150
1172
var user types.User
1151
- result := s .db .Take (& user , accountAndIDQueryCondition , accountID , userID )
1173
+ result := s .db .WithContext ( ctx ). Take (& user , accountAndIDQueryCondition , accountID , userID )
1152
1174
if result .Error != nil {
1153
1175
if errors .Is (result .Error , gorm .ErrRecordNotFound ) {
1154
1176
return status .NewUserNotFoundError (userID )
@@ -1329,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
1329
1351
}
1330
1352
1331
1353
func (s * SqlStore ) GetSetupKeyBySecret (ctx context.Context , lockStrength LockingStrength , key string ) (* types.SetupKey , error ) {
1354
+ ctx , cancel := getDebuggingCtx (ctx )
1355
+ defer cancel ()
1356
+
1332
1357
tx := s .db
1333
1358
if lockStrength != LockingStrengthNone {
1334
1359
tx = tx .Clauses (clause.Locking {Strength : string (lockStrength )})
1335
1360
}
1336
1361
1337
1362
var setupKey types.SetupKey
1338
- result := tx .
1363
+ result := tx .WithContext ( ctx ).
1339
1364
Take (& setupKey , GetKeyQueryCondition (s ), key )
1340
1365
1341
1366
if result .Error != nil {
@@ -1349,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
1349
1374
}
1350
1375
1351
1376
func (s * SqlStore ) IncrementSetupKeyUsage (ctx context.Context , setupKeyID string ) error {
1352
- result := s .db .Model (& types.SetupKey {}).
1377
+ ctx , cancel := getDebuggingCtx (ctx )
1378
+ defer cancel ()
1379
+
1380
+ result := s .db .WithContext (ctx ).Model (& types.SetupKey {}).
1353
1381
Where (idQueryCondition , setupKeyID ).
1354
1382
Updates (map [string ]interface {}{
1355
1383
"used_times" : gorm .Expr ("used_times + 1" ),
@@ -1369,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
1369
1397
1370
1398
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
1371
1399
func (s * SqlStore ) AddPeerToAllGroup (ctx context.Context , accountID string , peerID string ) error {
1400
+ ctx , cancel := getDebuggingCtx (ctx )
1401
+ defer cancel ()
1402
+
1372
1403
var groupID string
1373
- _ = s .db .Model (types.Group {}).
1404
+ _ = s .db .WithContext ( ctx ). Model (types.Group {}).
1374
1405
Select ("id" ).
1375
1406
Where ("account_id = ? AND name = ?" , accountID , "All" ).
1376
1407
Limit (1 ).
@@ -1398,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
1398
1429
1399
1430
// AddPeerToGroup adds a peer to a group
1400
1431
func (s * SqlStore ) AddPeerToGroup (ctx context.Context , accountID , peerID , groupID string ) error {
1432
+ ctx , cancel := getDebuggingCtx (ctx )
1433
+ defer cancel ()
1434
+
1401
1435
peer := & types.GroupPeer {
1402
1436
AccountID : accountID ,
1403
1437
GroupID : groupID ,
1404
1438
PeerID : peerID ,
1405
1439
}
1406
1440
1407
- err := s .db .Clauses (clause.OnConflict {
1441
+ err := s .db .WithContext ( ctx ). Clauses (clause.OnConflict {
1408
1442
Columns : []clause.Column {{Name : "group_id" }, {Name : "peer_id" }},
1409
1443
DoNothing : true ,
1410
1444
}).Create (peer ).Error
@@ -1594,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
1594
1628
}
1595
1629
1596
1630
func (s * SqlStore ) AddPeerToAccount (ctx context.Context , peer * nbpeer.Peer ) error {
1597
- if err := s .db .Create (peer ).Error ; err != nil {
1631
+ ctx , cancel := getDebuggingCtx (ctx )
1632
+ defer cancel ()
1633
+
1634
+ if err := s .db .WithContext (ctx ).Create (peer ).Error ; err != nil {
1598
1635
return status .Errorf (status .Internal , "issue adding peer to account: %s" , err )
1599
1636
}
1600
1637
@@ -1720,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
1720
1757
}
1721
1758
1722
1759
func (s * SqlStore ) IncrementNetworkSerial (ctx context.Context , accountId string ) error {
1723
- result := s .db .Model (& types.Account {}).Where (idQueryCondition , accountId ).Update ("network_serial" , gorm .Expr ("network_serial + 1" ))
1760
+ ctx , cancel := getDebuggingCtx (ctx )
1761
+ defer cancel ()
1762
+
1763
+ result := s .db .WithContext (ctx ).Model (& types.Account {}).Where (idQueryCondition , accountId ).Update ("network_serial" , gorm .Expr ("network_serial + 1" ))
1724
1764
if result .Error != nil {
1725
1765
log .WithContext (ctx ).Errorf ("failed to increment network serial count in store: %v" , result .Error )
1726
1766
return status .Errorf (status .Internal , "failed to increment network serial count in store" )
@@ -2762,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
2762
2802
2763
2803
return groupPeers , nil
2764
2804
}
2805
+
2806
+ func getDebuggingCtx (grpcCtx context.Context ) (context.Context , context.CancelFunc ) {
2807
+ ctx , cancel := context .WithTimeout (context .Background (), time .Minute )
2808
+ userID , ok := grpcCtx .Value (nbcontext .UserIDKey ).(string )
2809
+ if ok {
2810
+ //nolint
2811
+ ctx = context .WithValue (ctx , nbcontext .UserIDKey , userID )
2812
+ }
2813
+
2814
+ requestID , ok := grpcCtx .Value (nbcontext .RequestIDKey ).(string )
2815
+ if ok {
2816
+ //nolint
2817
+ ctx = context .WithValue (ctx , nbcontext .RequestIDKey , requestID )
2818
+ }
2819
+
2820
+ accountID , ok := grpcCtx .Value (nbcontext .AccountIDKey ).(string )
2821
+ if ok {
2822
+ //nolint
2823
+ ctx = context .WithValue (ctx , nbcontext .AccountIDKey , accountID )
2824
+ }
2825
+
2826
+ go func () {
2827
+ select {
2828
+ case <- ctx .Done ():
2829
+ case <- grpcCtx .Done ():
2830
+ log .WithContext (grpcCtx ).Warnf ("grpc context ended early, error: %v" , grpcCtx .Err ())
2831
+ }
2832
+ }()
2833
+ return ctx , cancel
2834
+ }
0 commit comments