Skip to content

Commit 9c3ea0a

Browse files
committed
Move redis task completion to script
1 parent 34a5c67 commit 9c3ea0a

File tree

12 files changed

+417
-168
lines changed

12 files changed

+417
-168
lines changed

backend/redis/delete.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@ import (
1414
// KEYS[4] - payload key
1515
// KEYS[5] - active-instance-execution key
1616
// KEYS[6] - instances-by-creation key
17+
// KEYS[7] - instances
1718
// ARGV[1] - instance segment
19+
// ARGV[2] - instance id
1820
var deleteCmd = redis.NewScript(
1921
`redis.call("DEL", KEYS[1], KEYS[2], KEYS[3], KEYS[4], KEYS[5])
22+
redis.call("HDEL", KEYS[7], ARGV[1])
2023
return redis.call("ZREM", KEYS[6], ARGV[1])`)
2124

2225
// deleteInstance deletes an instance from Redis. It does not attempt to remove any future events or pending
@@ -31,7 +34,8 @@ func deleteInstance(ctx context.Context, rdb redis.UniversalClient, instance *co
3134
payloadKey(instance),
3235
activeInstanceExecutionKey(instance.InstanceID),
3336
instancesByCreation(),
34-
}, instanceSegment(instance)).Err(); err != nil {
37+
instanceIDs(),
38+
}, instanceSegment(instance), instance.InstanceID).Err(); err != nil {
3539
return fmt.Errorf("failed to delete instance: %w", err)
3640
}
3741

backend/redis/events.go

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -72,34 +72,3 @@ func addEventToStreamP(ctx context.Context, p redis.Pipeliner, streamKey string,
7272
},
7373
}).Err()
7474
}
75-
76-
// addEventsToStream adds the given events to the given event stream. If successful, the message id of the last event added
77-
// is returned
78-
// KEYS[1] - stream key
79-
// ARGV[1] - event data as serialized strings
80-
var addEventsToStreamCmd = redis.NewScript(`
81-
local msgID = ""
82-
for i = 1, #ARGV, 2 do
83-
msgID = redis.call("XADD", KEYS[1], ARGV[i], "event", ARGV[i + 1])
84-
end
85-
return msgID
86-
`)
87-
88-
func addEventsToStreamP(ctx context.Context, p redis.Pipeliner, streamKey string, events []*history.Event) error {
89-
eventsData := make([]string, 0)
90-
for _, event := range events {
91-
eventData, err := marshalEventWithoutAttributes(event)
92-
if err != nil {
93-
return err
94-
}
95-
96-
// log.Println("addEventsToHistoryStreamP:", event.SequenceID, string(eventData))
97-
98-
eventsData = append(eventsData, historyID(event.SequenceID))
99-
eventsData = append(eventsData, string(eventData))
100-
}
101-
102-
addEventsToStreamCmd.Run(ctx, p, []string{streamKey}, eventsData)
103-
104-
return nil
105-
}

backend/redis/expire.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@ import (
1818
// KEYS[4] - pending events key
1919
// KEYS[5] - history key
2020
// KEYS[6] - payload key
21+
// KEYS[7] - instances key
2122
// ARGV[1] - current timestamp
2223
// ARGV[2] - expiration time in seconds
2324
// ARGV[3] - expiration timestamp in unix milliseconds
2425
// ARGV[4] - instance segment
26+
// ARGV[5] - instance id
2527
var expireCmd = redis.NewScript(
2628
`-- Find instances which have already expired and remove from the index set
2729
local expiredInstances = redis.call("ZRANGE", KEYS[2], "-inf", ARGV[1], "BYSCORE")
2830
for i = 1, #expiredInstances do
2931
local instanceSegment = expiredInstances[i]
3032
redis.call("ZREM", KEYS[1], instanceSegment) -- index set
3133
redis.call("ZREM", KEYS[2], instanceSegment) -- expiration set
34+
redis.call("HDEL", KEYS[7], ARGV[5])
3235
end
3336
3437
-- Add expiration time for future cleanup
@@ -57,10 +60,12 @@ func setWorkflowInstanceExpiration(ctx context.Context, rdb redis.UniversalClien
5760
pendingEventsKey(instance),
5861
historyKey(instance),
5962
payloadKey(instance),
63+
instanceIDs(),
6064
},
6165
nowStr,
6266
expiration.Seconds(),
6367
expStr,
6468
instanceSegment(instance),
69+
instance.InstanceID,
6570
).Err()
6671
}

backend/redis/instance.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (rb *redisBackend) CreateWorkflowInstance(ctx context.Context, instance *wo
2626

2727
p := rb.rdb.TxPipeline()
2828

29-
if err := createInstanceP(ctx, p, instance, event.Attributes.(*history.ExecutionStartedAttributes).Metadata, false); err != nil {
29+
if err := createInstanceP(ctx, p, instance, event.Attributes.(*history.ExecutionStartedAttributes).Metadata); err != nil {
3030
return err
3131
}
3232

@@ -145,7 +145,7 @@ type instanceState struct {
145145
LastSequenceID int64 `json:"last_sequence_id,omitempty"`
146146
}
147147

148-
func createInstanceP(ctx context.Context, p redis.Pipeliner, instance *core.WorkflowInstance, metadata *metadata.WorkflowMetadata, ignoreDuplicate bool) error {
148+
func createInstanceP(ctx context.Context, p redis.Pipeliner, instance *core.WorkflowInstance, metadata *metadata.WorkflowMetadata) error {
149149
key := instanceKey(instance)
150150

151151
createdAt := time.Now()
@@ -165,6 +165,9 @@ func createInstanceP(ctx context.Context, p redis.Pipeliner, instance *core.Work
165165
// The newly created instance is going to be the active execution
166166
setActiveInstanceExecutionP(ctx, p, instance)
167167

168+
// Record instance id
169+
p.HSet(ctx, instanceIDs(), instance.InstanceID, 1)
170+
168171
p.ZAdd(ctx, instancesByCreation(), redis.Z{
169172
Member: instanceSegment(instance),
170173
Score: float64(createdAt.UnixMilli()),

backend/redis/keys.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ func instancesExpiring() string {
3737
return "instances-expiring"
3838
}
3939

40+
func instanceIDs() string {
41+
return "instances"
42+
}
43+
4044
func pendingEventsKey(instance *core.WorkflowInstance) string {
4145
return fmt.Sprintf("pending-events:%v", instanceSegment(instance))
4246
}

backend/redis/queue.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"log"
78
"time"
89

910
"github.com/google/uuid"
@@ -119,7 +120,6 @@ var createGroupCmd = redis.NewScript(`
119120
return true
120121
`)
121122

122-
123123
func (q *taskQueue[T]) Enqueue(ctx context.Context, p redis.Pipeliner, id string, data *T) error {
124124
ds, err := json.Marshal(data)
125125
if err != nil {
@@ -139,6 +139,7 @@ func (q *taskQueue[T]) Dequeue(ctx context.Context, rdb redis.UniversalClient, l
139139
}
140140

141141
if task != nil {
142+
log.Println("Recovered task", task.ID)
142143
return task, nil
143144
}
144145

backend/redis/redis.go

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package redis
22

33
import (
44
"context"
5+
"embed"
56
"fmt"
7+
"io/fs"
68
"log/slog"
79
"time"
810

@@ -19,6 +21,11 @@ import (
1921

2022
var _ backend.Backend = (*redisBackend)(nil)
2123

24+
//go:embed scripts/*.lua
25+
var luaScripts embed.FS
26+
27+
var completeWorkflowTaskCmd *redis.Script
28+
2229
func NewRedisBackend(client redis.UniversalClient, opts ...RedisBackendOption) (*redisBackend, error) {
2330
workflowQueue, err := newTaskQueue[any](client, "workflows")
2431
if err != nil {
@@ -52,15 +59,13 @@ func NewRedisBackend(client redis.UniversalClient, opts ...RedisBackendOption) (
5259
// them, loads them. This doesn't work when using (transactional) pipelines, so eagerly load them on startup.
5360
ctx := context.Background()
5461
cmds := map[string]*redis.StringCmd{
55-
"addEventsToStreamCmd": addEventsToStreamCmd.Load(ctx, rb.rdb),
56-
"addFutureEventCmd": addFutureEventCmd.Load(ctx, rb.rdb),
57-
"futureEventsCmd": futureEventsCmd.Load(ctx, rb.rdb),
58-
"removeFutureEventCmd": removeFutureEventCmd.Load(ctx, rb.rdb),
59-
"removePendingEventsCmd": removePendingEventsCmd.Load(ctx, rb.rdb),
60-
"requeueInstanceCmd": requeueInstanceCmd.Load(ctx, rb.rdb),
61-
"deleteInstanceCmd": deleteCmd.Load(ctx, rb.rdb),
62-
"expireInstanceCmd": expireCmd.Load(ctx, rb.rdb),
63-
"addPayloadsCmd": addPayloadsCmd.Load(ctx, rb.rdb),
62+
"addEventsToStreamCmd": addEventsToStreamCmd.Load(ctx, rb.rdb),
63+
"addFutureEventCmd": addFutureEventCmd.Load(ctx, rb.rdb),
64+
"futureEventsCmd": futureEventsCmd.Load(ctx, rb.rdb),
65+
"removeFutureEventCmd": removeFutureEventCmd.Load(ctx, rb.rdb),
66+
"deleteInstanceCmd": deleteCmd.Load(ctx, rb.rdb),
67+
"expireInstanceCmd": expireCmd.Load(ctx, rb.rdb),
68+
"addPayloadsCmd": addPayloadsCmd.Load(ctx, rb.rdb),
6469
}
6570
for name, cmd := range cmds {
6671
// fmt.Println(name, cmd.Val())
@@ -70,6 +75,25 @@ func NewRedisBackend(client redis.UniversalClient, opts ...RedisBackendOption) (
7075
}
7176
}
7277

78+
// Load all Lua scripts
79+
80+
cmdMapping := map[string]**redis.Script{
81+
"complete_workflow_task.lua": &completeWorkflowTaskCmd,
82+
}
83+
84+
for scriptFile, cmd := range cmdMapping {
85+
scriptContent, err := fs.ReadFile(luaScripts, "scripts/"+scriptFile)
86+
if err != nil {
87+
return nil, fmt.Errorf("reading Lua script %s: %w", scriptFile, err)
88+
}
89+
90+
*cmd = redis.NewScript(string(scriptContent))
91+
92+
if c := (*cmd).Load(ctx, rb.rdb); c.Err() != nil {
93+
return nil, fmt.Errorf("loading Lua script %s: %w", scriptFile, c.Err())
94+
}
95+
}
96+
7397
return rb, nil
7498
}
7599

0 commit comments

Comments
 (0)