Skip to content

Commit 98ef89d

Browse files
committed
refactor: make interface more resource type independent
1 parent bf99482 commit 98ef89d

File tree

9 files changed

+220
-194
lines changed

9 files changed

+220
-194
lines changed

backend/internal/api/handlers/deployment_handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) {
241241
wfUUID, wfStatus, err := h.svc.AsyncDeployCluster(config, cluster)
242242
if err != nil {
243243
reqLog.Error().Err(err).Msg("failed to start deployment workflow")
244-
if errors.Is(err, distributedlocks.ErrNodeLocked) {
244+
if errors.Is(err, distributedlocks.ErrResourceLocked) {
245245
Conflict(c, "Node is busy serving another request")
246246
return
247247
}
@@ -417,7 +417,7 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) {
417417
wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0])
418418
if err != nil {
419419
reqLog.Error().Err(err).Msg("failed to start add node workflow")
420-
if errors.Is(err, distributedlocks.ErrNodeLocked) {
420+
if errors.Is(err, distributedlocks.ErrResourceLocked) {
421421
Conflict(c, "Node is busy serving another request")
422422
return
423423
}

backend/internal/api/handlers/node_handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) {
298298
wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID)
299299
if err != nil {
300300
reqLog.Error().Err(err).Msg("failed to start workflow to reserve node")
301-
if errors.Is(err, distributedlocks.ErrNodeLocked) {
301+
if errors.Is(err, distributedlocks.ErrResourceLocked) {
302302
Conflict(c, "Node is busy serving another request")
303303
return
304304
}

backend/internal/core/distributed_locks/distributed_locks.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@ import (
55
"errors"
66
)
77

8-
var ErrNodeLocked = errors.New("node is currently locked by another request")
8+
var ErrResourceLocked = errors.New("resource is currently locked by another request")
9+
10+
const (
11+
NodeLockPrefix = "node:"
12+
)
913

10-
// DistributedLocks is an interface that defines the methods for distributed locks.
1114
type DistributedLocks interface {
12-
AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error)
13-
ReleaseLock(ctx context.Context, lockedKeys map[string]string) error
14-
GetLockedNodes(ctx context.Context) ([]uint32, error)
15+
AcquireLocks(ctx context.Context, resourceKeys []string) (map[string]string, error)
16+
ReleaseLocks(ctx context.Context, lockedKeys map[string]string) error
17+
GetLockedResources(ctx context.Context, keyPattern string) ([]string, error)
1518
}

backend/internal/core/distributed_locks/redis_locker.go

Lines changed: 92 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,12 @@ package distributedlocks
33
import (
44
"context"
55
"fmt"
6-
"strconv"
7-
"strings"
86
"time"
97

108
"github.com/google/uuid"
119
"github.com/redis/go-redis/v9"
1210
)
1311

14-
const (
15-
nodeLockKey = "locked"
16-
)
17-
1812
type RedisLocker struct {
1913
client *redis.Client
2014
lockTimeout time.Duration
@@ -28,116 +22,128 @@ func NewRedisLocker(client *redis.Client, lockTimeout time.Duration) *RedisLocke
2822
}
2923
}
3024

31-
// AcquireNodesLocks acquires locks for the given node IDs.
32-
func (l *RedisLocker) AcquireNodesLocks(ctx context.Context, nodeIDs []uint32) (map[string]string, error) {
33-
lockedKeys, err := l.acquireKeys(ctx, nodeLockKeys(nodeIDs))
34-
if err != nil {
35-
return nil, err
25+
func (l *RedisLocker) AcquireLocks(ctx context.Context, resourceKeys []string) (map[string]string, error) {
26+
if len(resourceKeys) == 0 {
27+
return nil, fmt.Errorf("no resource keys provided")
3628
}
37-
return lockedKeys, nil
38-
}
3929

40-
func nodeLockKeys(nodeIDs []uint32) []string {
41-
keys := make([]string, len(nodeIDs))
42-
for i, id := range nodeIDs {
43-
keys[i] = fmt.Sprintf("%s:%d", nodeLockKey, id)
44-
}
45-
return keys
46-
}
30+
expiry := int64(l.lockTimeout / time.Millisecond)
4731

48-
func (l *RedisLocker) acquireKeys(ctx context.Context, keys []string) (map[string]string, error) {
49-
locked := make(map[string]string, len(keys))
50-
51-
for _, key := range keys {
52-
keyValue := uuid.New().String()
53-
ok, err := l.client.SetNX(ctx, key, keyValue, l.lockTimeout).Result()
54-
if err != nil {
55-
if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil {
56-
return nil, rollErr
57-
}
58-
return nil, fmt.Errorf("redis error while acquiring lock for key %s: %w", key, err)
59-
}
32+
values := make([]string, len(resourceKeys))
33+
argv := make([]interface{}, 0, len(resourceKeys)+1)
34+
//expiry of locks
35+
argv = append(argv, expiry)
6036

61-
if !ok {
62-
if rollErr := l.rollbackLocks(ctx, locked); rollErr != nil {
63-
return nil, rollErr
64-
}
65-
return nil, fmt.Errorf("%w: %s", ErrNodeLocked, key)
66-
}
67-
68-
locked[key] = keyValue
37+
// uuid values for each key
38+
for i := range resourceKeys {
39+
val := uuid.New().String()
40+
values[i] = val
41+
argv = append(argv, val)
6942
}
7043

71-
return locked, nil
72-
}
44+
lua := redis.NewScript(`
45+
local expiry = tonumber(ARGV[1])
46+
local locked = {}
47+
48+
for i = 1, #KEYS do
49+
local ok = redis.call("SET", KEYS[i], ARGV[i+1], "PX", expiry, "NX")
50+
if not ok then
51+
for j = 1, #locked do
52+
redis.call("DEL", KEYS[j])
53+
end
54+
return {"LOCKED", KEYS[i]}
55+
end
56+
table.insert(locked, KEYS[i])
57+
end
58+
59+
return {"OK"}
60+
`)
61+
62+
res, err := lua.Run(ctx, l.client, resourceKeys, argv...).Result()
63+
if err != nil {
64+
return nil, err
65+
}
7366

74-
func (l *RedisLocker) rollbackLocks(ctx context.Context, locked map[string]string) error {
75-
if len(locked) == 0 {
76-
return nil
67+
out, ok := res.([]interface{})
68+
if !ok || len(out) == 0 {
69+
return nil, fmt.Errorf("unexpected script output: %v", res)
7770
}
78-
keys := make([]string, 0, len(locked))
79-
for k := range locked {
80-
keys = append(keys, k)
71+
72+
status, _ := out[0].(string)
73+
if status == "LOCKED" {
74+
conflict := out[1].(string)
75+
return nil, fmt.Errorf("%w: %s", ErrResourceLocked, conflict)
8176
}
8277

83-
if err := l.client.Del(ctx, keys...).Err(); err != nil {
84-
return fmt.Errorf("redis error while rolling back locks: %w", err)
78+
locked := map[string]string{}
79+
for i, k := range resourceKeys {
80+
locked[k] = values[i]
8581
}
8682

87-
return nil
83+
return locked, nil
8884
}
8985

90-
// ReleaseLock releases the locks for the given keys.
91-
func (l *RedisLocker) ReleaseLock(ctx context.Context, lockedKeys map[string]string) error {
86+
// ReleaseLocks releases the locks for the given keys.
87+
func (l *RedisLocker) ReleaseLocks(ctx context.Context, lockedKeys map[string]string) error {
9288
if len(lockedKeys) == 0 {
9389
return nil
9490
}
91+
keys := make([]string, 0, len(lockedKeys))
92+
values := make([]interface{}, 0, len(lockedKeys))
9593

96-
var failedKeys []string
97-
for key, expectedValue := range lockedKeys {
98-
storedValue, err := l.client.Get(ctx, key).Result()
99-
if err == redis.Nil {
100-
continue
101-
}
102-
if err != nil {
103-
return fmt.Errorf("failed to get lock value for key %s: %w", key, err)
104-
}
105-
106-
if storedValue != expectedValue {
107-
failedKeys = append(failedKeys, key)
108-
continue
109-
}
94+
for k, v := range lockedKeys {
95+
keys = append(keys, k)
96+
values = append(values, v)
97+
}
11098

111-
if err := l.client.Del(ctx, key).Err(); err != nil {
112-
return fmt.Errorf("failed to delete lock for key %s: %w", key, err)
113-
}
99+
luaScript := redis.NewScript(`
100+
local failed = {}
101+
for i = 1, #KEYS do
102+
local key = KEYS[i]
103+
local expected = ARGV[i]
104+
local actual = redis.call("GET", key)
105+
106+
if actual ~= false then
107+
if actual ~= expected then
108+
table.insert(failed, key)
109+
else
110+
redis.call("DEL", key)
111+
end
112+
end
113+
end
114+
return failed
115+
`)
116+
117+
// Run the script
118+
res, err := luaScript.Run(ctx, l.client, keys, values...).Result()
119+
if err != nil {
120+
return err
114121
}
115122

123+
failedKeys, _ := res.([]interface{})
116124
if len(failedKeys) > 0 {
117-
return fmt.Errorf("lock value mismatch for keys: %v", failedKeys)
125+
mismatches := make([]string, len(failedKeys))
126+
for i, v := range failedKeys {
127+
mismatches[i] = v.(string)
128+
}
129+
return fmt.Errorf("lock value mismatch for keys: %v", mismatches)
118130
}
119131

120132
return nil
121133
}
122134

123-
// GetLockedNodes returns the list of locked nodes.
124-
func (l *RedisLocker) GetLockedNodes(ctx context.Context) ([]uint32, error) {
125-
iter := l.client.Scan(ctx, 0, "locked:*", 0).Iterator()
126-
127-
nodes := make([]uint32, 0)
135+
// GetLockedResources returns all currently locked resource keys matching the given pattern.
136+
func (l *RedisLocker) GetLockedResources(ctx context.Context, keyPattern string) ([]string, error) {
137+
if keyPattern == "" {
138+
keyPattern = "*"
139+
}
140+
iter := l.client.Scan(ctx, 0, keyPattern, 0).Iterator()
141+
resources := make([]string, 0)
128142
for iter.Next(ctx) {
129-
key := iter.Val()
130-
nodeID := strings.Split(key, ":")[1]
131-
value, parseErr := strconv.ParseUint(nodeID, 10, 32)
132-
if parseErr != nil {
133-
return nil, fmt.Errorf("failed to parse locked node id from %s: %w", key, parseErr)
134-
}
135-
nodes = append(nodes, uint32(value))
143+
resources = append(resources, iter.Val())
136144
}
137-
138145
if err := iter.Err(); err != nil {
139146
return nil, err
140147
}
141-
142-
return nodes, nil
148+
return resources, nil
143149
}

0 commit comments

Comments
 (0)