@@ -7,9 +7,14 @@ import (
77 "strings"
88 "time"
99
10+ "github.com/google/uuid"
1011 "github.com/redis/go-redis/v9"
1112)
1213
14+ const (
15+ nodeLockKey = "locked"
16+ )
17+
1318type RedisLocker struct {
1419 client * redis.Client
1520 lockTimeout time.Duration
@@ -24,77 +29,56 @@ func NewRedisLocker(client *redis.Client, lockTimeout time.Duration) *RedisLocke
2429}
2530
2631// AcquireNodesLocks acquires locks for the given node IDs.
27- func (l * RedisLocker ) AcquireNodesLocks (ctx context.Context , nodeIDs []uint32 ) error {
28- if err := l .acquireKeys (ctx , lockKeys (nodeIDs , nodeLockKey )); err != nil {
29- return err
30- }
31-
32- return nil
33- }
34-
35- // AcquireWorkflowLock acquires a lock for the given workflow ID.
36- func (l * RedisLocker ) AcquireWorkflowLock (ctx context.Context , nodeIDs []uint32 , workflowID string ) error {
37- keys := lockKeys (nodeIDs , func (id uint32 ) string {
38- return workflowLockKey (id , workflowID )
39- })
40-
41- if err := l .acquireKeys (ctx , keys ); err != nil {
42- //rollback nodes locks
43- nodeLockKeys := lockKeys (nodeIDs , nodeLockKey )
44- if rollErr := l .rollbackLocks (ctx , nodeLockKeys ); rollErr != nil {
45- return rollErr
46- }
47- return err
32+ func (l * RedisLocker ) AcquireNodesLocks (ctx context.Context , nodeIDs []uint32 ) (map [string ]string , error ) {
33+ lockedKeys , err := l .acquireKeys (ctx , nodeLockKeys (nodeIDs ))
34+ if err != nil {
35+ return nil , err
4836 }
49-
50- return nil
51- }
52-
53- func nodeLockKey (nodeID uint32 ) string {
54- return fmt .Sprintf ("locked:%d" , nodeID )
55- }
56-
57- func workflowLockKey (nodeID uint32 , workflowID string ) string {
58- return fmt .Sprintf ("used:%d:%s" , nodeID , workflowID )
37+ return lockedKeys , nil
5938}
6039
61- func lockKeys ( ids []uint32 , keyFunc func ( uint32 ) string ) []string {
62- keys := make ([]string , len (ids ))
63- for i , id := range ids {
64- keys [i ] = keyFunc ( id )
40+ func nodeLockKeys ( nodeIDs []uint32 ) []string {
41+ keys := make ([]string , len (nodeIDs ))
42+ for i , id := range nodeIDs {
43+ keys [i ] = fmt . Sprintf ( "%s:%d" , nodeLockKey , id )
6544 }
6645 return keys
6746}
6847
69- func (l * RedisLocker ) acquireKeys (ctx context.Context , keys []string ) error {
70- locked := make ([ ]string , 0 , len (keys ))
48+ func (l * RedisLocker ) acquireKeys (ctx context.Context , keys []string ) ( map [ string ] string , error ) {
49+ locked := make (map [ string ]string , len (keys ))
7150
7251 for _ , key := range keys {
73- ok , err := l .client .SetNX (ctx , key , 1 , l .lockTimeout ).Result ()
52+ keyValue := uuid .New ().String ()
53+ ok , err := l .client .SetNX (ctx , key , keyValue , l .lockTimeout ).Result ()
7454 if err != nil {
7555 if rollErr := l .rollbackLocks (ctx , locked ); rollErr != nil {
76- return rollErr
56+ return nil , rollErr
7757 }
78- return fmt .Errorf ("redis error while acquiring lock for key %s: %w" , key , err )
58+ return nil , fmt .Errorf ("redis error while acquiring lock for key %s: %w" , key , err )
7959 }
8060
8161 if ! ok {
8262 if rollErr := l .rollbackLocks (ctx , locked ); rollErr != nil {
83- return rollErr
63+ return nil , rollErr
8464 }
85- return fmt .Errorf ("%w: %s" , ErrNodeLocked , key )
65+ return nil , fmt .Errorf ("%w: %s" , ErrNodeLocked , key )
8666 }
8767
88- locked = append ( locked , key )
68+ locked [ key ] = keyValue
8969 }
9070
91- return nil
71+ return locked , nil
9272}
9373
94- func (l * RedisLocker ) rollbackLocks (ctx context.Context , keys [ ]string ) error {
95- if len (keys ) == 0 {
74+ func (l * RedisLocker ) rollbackLocks (ctx context.Context , locked map [ string ]string ) error {
75+ if len (locked ) == 0 {
9676 return nil
9777 }
78+ keys := make ([]string , 0 , len (locked ))
79+ for k := range locked {
80+ keys = append (keys , k )
81+ }
9882
9983 if err := l .client .Del (ctx , keys ... ).Err (); err != nil {
10084 return fmt .Errorf ("redis error while rolling back locks: %w" , err )
@@ -103,18 +87,36 @@ func (l *RedisLocker) rollbackLocks(ctx context.Context, keys []string) error {
10387 return nil
10488}
10589
106- func (l * RedisLocker ) ReleaseLock (ctx context.Context , nodeIDs []uint32 , workflowID string ) error {
107- lockedKeys := lockKeys (nodeIDs , nodeLockKey )
108- usedKeys := lockKeys (nodeIDs , func (id uint32 ) string {
109- return workflowLockKey (id , workflowID )
110- })
111- allWorkflowsLocks := append (lockedKeys , usedKeys ... )
112- return l .client .Del (ctx , allWorkflowsLocks ... ).Err ()
113- }
90+ func (l * RedisLocker ) ReleaseLock (ctx context.Context , lockedKeys map [string ]string ) error {
91+ if len (lockedKeys ) == 0 {
92+ return nil
93+ }
94+
95+ var failedKeys []string
96+ for key , expectedValue := range lockedKeys {
97+ storedValue , err := l .client .Get (ctx , key ).Result ()
98+ if err == redis .Nil {
99+ continue
100+ }
101+ if err != nil {
102+ return fmt .Errorf ("failed to get lock value for key %s: %w" , key , err )
103+ }
114104
115- // GetAllWorkflowsLocks gets all workflow locks.
116- func (l * RedisLocker ) GetAllWorkflowsLocks (ctx context.Context ) ([]string , error ) {
117- return l .client .Keys (ctx , "used:*" ).Result ()
105+ if storedValue != expectedValue {
106+ failedKeys = append (failedKeys , key )
107+ continue
108+ }
109+
110+ if err := l .client .Del (ctx , key ).Err (); err != nil {
111+ return fmt .Errorf ("failed to delete lock for key %s: %w" , key , err )
112+ }
113+ }
114+
115+ if len (failedKeys ) > 0 {
116+ return fmt .Errorf ("lock value mismatch for keys: %v" , failedKeys )
117+ }
118+
119+ return nil
118120}
119121
120122func (l * RedisLocker ) GetLockedNodes (ctx context.Context ) ([]uint32 , error ) {
0 commit comments