@@ -3,18 +3,12 @@ package distributedlocks
33import (
44 "context"
55 "fmt"
6- "strconv"
7- "strings"
86 "time"
97
108 "github.com/google/uuid"
119 "github.com/redis/go-redis/v9"
1210)
1311
14- const (
15- nodeLockKey = "locked"
16- )
17-
1812type RedisLocker struct {
1913 client * redis.Client
2014 lockTimeout time.Duration
@@ -28,116 +22,128 @@ func NewRedisLocker(client *redis.Client, lockTimeout time.Duration) *RedisLocke
2822 }
2923}
3024
31- // AcquireNodesLocks acquires locks for the given node IDs.
32- func (l * RedisLocker ) AcquireNodesLocks (ctx context.Context , nodeIDs []uint32 ) (map [string ]string , error ) {
33- lockedKeys , err := l .acquireKeys (ctx , nodeLockKeys (nodeIDs ))
34- if err != nil {
35- return nil , err
25+ func (l * RedisLocker ) AcquireLocks (ctx context.Context , resourceKeys []string ) (map [string ]string , error ) {
26+ if len (resourceKeys ) == 0 {
27+ return nil , fmt .Errorf ("no resource keys provided" )
3628 }
37- return lockedKeys , nil
38- }
3929
40- func nodeLockKeys (nodeIDs []uint32 ) []string {
41- keys := make ([]string , len (nodeIDs ))
42- for i , id := range nodeIDs {
43- keys [i ] = fmt .Sprintf ("%s:%d" , nodeLockKey , id )
44- }
45- return keys
46- }
30+ expiry := int64 (l .lockTimeout / time .Millisecond )
4731
48- func (l * RedisLocker ) acquireKeys (ctx context.Context , keys []string ) (map [string ]string , error ) {
49- locked := make (map [string ]string , len (keys ))
50-
51- for _ , key := range keys {
52- keyValue := uuid .New ().String ()
53- ok , err := l .client .SetNX (ctx , key , keyValue , l .lockTimeout ).Result ()
54- if err != nil {
55- if rollErr := l .rollbackLocks (ctx , locked ); rollErr != nil {
56- return nil , rollErr
57- }
58- return nil , fmt .Errorf ("redis error while acquiring lock for key %s: %w" , key , err )
59- }
32+ values := make ([]string , len (resourceKeys ))
33+ argv := make ([]interface {}, 0 , len (resourceKeys )+ 1 )
34+ //expiry of locks
35+ argv = append (argv , expiry )
6036
61- if ! ok {
62- if rollErr := l .rollbackLocks (ctx , locked ); rollErr != nil {
63- return nil , rollErr
64- }
65- return nil , fmt .Errorf ("%w: %s" , ErrNodeLocked , key )
66- }
67-
68- locked [key ] = keyValue
37+ // uuid values for each key
38+ for i := range resourceKeys {
39+ val := uuid .New ().String ()
40+ values [i ] = val
41+ argv = append (argv , val )
6942 }
7043
71- return locked , nil
72- }
44+ lua := redis .NewScript (`
45+ local expiry = tonumber(ARGV[1])
46+ local locked = {}
47+
48+ for i = 1, #KEYS do
49+ local ok = redis.call("SET", KEYS[i], ARGV[i+1], "PX", expiry, "NX")
50+ if not ok then
51+ for j = 1, #locked do
52+ redis.call("DEL", KEYS[j])
53+ end
54+ return {"LOCKED", KEYS[i]}
55+ end
56+ table.insert(locked, KEYS[i])
57+ end
58+
59+ return {"OK"}
60+ ` )
61+
62+ res , err := lua .Run (ctx , l .client , resourceKeys , argv ... ).Result ()
63+ if err != nil {
64+ return nil , err
65+ }
7366
74- func ( l * RedisLocker ) rollbackLocks ( ctx context. Context , locked map [ string ] string ) error {
75- if len (locked ) == 0 {
76- return nil
67+ out , ok := res .([] interface {})
68+ if ! ok || len (out ) == 0 {
69+ return nil , fmt . Errorf ( "unexpected script output: %v" , res )
7770 }
78- keys := make ([]string , 0 , len (locked ))
79- for k := range locked {
80- keys = append (keys , k )
71+
72+ status , _ := out [0 ].(string )
73+ if status == "LOCKED" {
74+ conflict := out [1 ].(string )
75+ return nil , fmt .Errorf ("%w: %s" , ErrResourceLocked , conflict )
8176 }
8277
83- if err := l .client .Del (ctx , keys ... ).Err (); err != nil {
84- return fmt .Errorf ("redis error while rolling back locks: %w" , err )
78+ locked := map [string ]string {}
79+ for i , k := range resourceKeys {
80+ locked [k ] = values [i ]
8581 }
8682
87- return nil
83+ return locked , nil
8884}
8985
90- // ReleaseLock releases the locks for the given keys.
91- func (l * RedisLocker ) ReleaseLock (ctx context.Context , lockedKeys map [string ]string ) error {
86+ // ReleaseLocks releases the locks for the given keys.
87+ func (l * RedisLocker ) ReleaseLocks (ctx context.Context , lockedKeys map [string ]string ) error {
9288 if len (lockedKeys ) == 0 {
9389 return nil
9490 }
91+ keys := make ([]string , 0 , len (lockedKeys ))
92+ values := make ([]interface {}, 0 , len (lockedKeys ))
9593
96- var failedKeys []string
97- for key , expectedValue := range lockedKeys {
98- storedValue , err := l .client .Get (ctx , key ).Result ()
99- if err == redis .Nil {
100- continue
101- }
102- if err != nil {
103- return fmt .Errorf ("failed to get lock value for key %s: %w" , key , err )
104- }
105-
106- if storedValue != expectedValue {
107- failedKeys = append (failedKeys , key )
108- continue
109- }
94+ for k , v := range lockedKeys {
95+ keys = append (keys , k )
96+ values = append (values , v )
97+ }
11098
111- if err := l .client .Del (ctx , key ).Err (); err != nil {
112- return fmt .Errorf ("failed to delete lock for key %s: %w" , key , err )
113- }
99+ luaScript := redis .NewScript (`
100+ local failed = {}
101+ for i = 1, #KEYS do
102+ local key = KEYS[i]
103+ local expected = ARGV[i]
104+ local actual = redis.call("GET", key)
105+
106+ if actual ~= false then
107+ if actual ~= expected then
108+ table.insert(failed, key)
109+ else
110+ redis.call("DEL", key)
111+ end
112+ end
113+ end
114+ return failed
115+ ` )
116+
117+ // Run the script
118+ res , err := luaScript .Run (ctx , l .client , keys , values ... ).Result ()
119+ if err != nil {
120+ return err
114121 }
115122
123+ failedKeys , _ := res .([]interface {})
116124 if len (failedKeys ) > 0 {
117- return fmt .Errorf ("lock value mismatch for keys: %v" , failedKeys )
125+ mismatches := make ([]string , len (failedKeys ))
126+ for i , v := range failedKeys {
127+ mismatches [i ] = v .(string )
128+ }
129+ return fmt .Errorf ("lock value mismatch for keys: %v" , mismatches )
118130 }
119131
120132 return nil
121133}
122134
123- // GetLockedNodes returns the list of locked nodes.
124- func (l * RedisLocker ) GetLockedNodes (ctx context.Context ) ([]uint32 , error ) {
125- iter := l .client .Scan (ctx , 0 , "locked:*" , 0 ).Iterator ()
126-
127- nodes := make ([]uint32 , 0 )
135+ // GetLockedResources returns all currently locked resource keys matching the given pattern.
136+ func (l * RedisLocker ) GetLockedResources (ctx context.Context , keyPattern string ) ([]string , error ) {
137+ if keyPattern == "" {
138+ keyPattern = "*"
139+ }
140+ iter := l .client .Scan (ctx , 0 , keyPattern , 0 ).Iterator ()
141+ resources := make ([]string , 0 )
128142 for iter .Next (ctx ) {
129- key := iter .Val ()
130- nodeID := strings .Split (key , ":" )[1 ]
131- value , parseErr := strconv .ParseUint (nodeID , 10 , 32 )
132- if parseErr != nil {
133- return nil , fmt .Errorf ("failed to parse locked node id from %s: %w" , key , parseErr )
134- }
135- nodes = append (nodes , uint32 (value ))
143+ resources = append (resources , iter .Val ())
136144 }
137-
138145 if err := iter .Err (); err != nil {
139146 return nil , err
140147 }
141-
142- return nodes , nil
148+ return resources , nil
143149}
0 commit comments