Skip to content

Commit 2a6003a

Browse files
authored
Merge pull request #387 from cschleiden/copilot/fix-386
Allow setting custom worker names for SQL backends
2 parents 7973e53 + 4a7489d commit 2a6003a

File tree

8 files changed

+130
-13
lines changed

8 files changed

+130
-13
lines changed

backend/mysql/mysql.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func NewMysqlBackend(host string, port int, user, password, database string, opt
5555
b := &mysqlBackend{
5656
dsn: dsn,
5757
db: db,
58-
workerName: fmt.Sprintf("worker-%v", uuid.NewString()),
58+
workerName: getWorkerName(options),
5959
options: options,
6060
}
6161

@@ -974,3 +974,11 @@ func scheduleActivity(ctx context.Context, tx *sql.Tx, queue workflow.Queue, ins
974974

975975
return err
976976
}
977+
978+
// getWorkerName returns the worker name from options, or generates a UUID-based name if not set.
979+
func getWorkerName(options *options) string {
980+
if options.Options.WorkerName != "" {
981+
return options.Options.WorkerName
982+
}
983+
return fmt.Sprintf("worker-%v", uuid.NewString())
984+
}

backend/mysql/mysql_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,53 @@ func (mb *mysqlBackend) GetFutureEvents(ctx context.Context) ([]*history.Event,
162162

163163
return f, nil
164164
}
165+
166+
func Test_MysqlBackend_WorkerName(t *testing.T) {
167+
if testing.Short() {
168+
t.Skip()
169+
}
170+
171+
t.Run("DefaultWorkerName", func(t *testing.T) {
172+
// Create a backend without specifying worker name
173+
// Since we can't connect to MySQL without it being available, we'll test the getWorkerName function directly
174+
options := &options{
175+
Options: backend.ApplyOptions(),
176+
}
177+
workerName := getWorkerName(options)
178+
179+
// The default worker name should be in the format "worker-<uuid>"
180+
if !strings.Contains(workerName, "worker-") {
181+
t.Errorf("Expected worker name to contain 'worker-', got: %s", workerName)
182+
}
183+
if len(workerName) != 43 { // "worker-" (7) + UUID (36)
184+
t.Errorf("Expected worker name length to be 43, got: %d", len(workerName))
185+
}
186+
})
187+
188+
t.Run("CustomWorkerName", func(t *testing.T) {
189+
customWorkerName := "test-worker-123"
190+
options := &options{
191+
Options: backend.ApplyOptions(backend.WithWorkerName(customWorkerName)),
192+
}
193+
workerName := getWorkerName(options)
194+
195+
if workerName != customWorkerName {
196+
t.Errorf("Expected worker name to be '%s', got: %s", customWorkerName, workerName)
197+
}
198+
})
199+
200+
t.Run("EmptyWorkerNameUsesDefault", func(t *testing.T) {
201+
options := &options{
202+
Options: backend.ApplyOptions(backend.WithWorkerName("")),
203+
}
204+
workerName := getWorkerName(options)
205+
206+
// Empty worker name should fall back to UUID generation
207+
if !strings.Contains(workerName, "worker-") {
208+
t.Errorf("Expected worker name to contain 'worker-', got: %s", workerName)
209+
}
210+
if len(workerName) != 43 { // "worker-" (7) + UUID (36)
211+
t.Errorf("Expected worker name length to be 43, got: %d", len(workerName))
212+
}
213+
})
214+
}

backend/options.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ type Options struct {
4646

4747
// MaxHistorySize is the maximum size of a workflow history. If a workflow exceeds this size, it will be failed.
4848
MaxHistorySize int64
49+
50+
// WorkerName allows setting a custom worker name. If not set, backends will generate a default name.
51+
WorkerName string
4952
}
5053

5154
var DefaultOptions Options = Options{
@@ -115,6 +118,12 @@ func WithMaxHistorySize(size int64) BackendOption {
115118
}
116119
}
117120

121+
func WithWorkerName(workerName string) BackendOption {
122+
return func(o *Options) {
123+
o.WorkerName = workerName
124+
}
125+
}
126+
118127
func ApplyOptions(opts ...BackendOption) *Options {
119128
options := DefaultOptions
120129

backend/redis/queue.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,22 @@ type KeyInfo struct {
4444
SetKey string
4545
}
4646

47-
func newTaskQueue[T any](ctx context.Context, rdb redis.UniversalClient, keyPrefix string, tasktype string) (*taskQueue[T], error) {
47+
func newTaskQueue[T any](ctx context.Context, rdb redis.UniversalClient, keyPrefix string, tasktype string, workerName string) (*taskQueue[T], error) {
4848
// Ensure the key prefix ends with a colon
4949
if keyPrefix != "" && keyPrefix[len(keyPrefix)-1] != ':' {
5050
keyPrefix += ":"
5151
}
5252

53+
// Use provided worker name or generate UUID if empty
54+
if workerName == "" {
55+
workerName = uuid.NewString()
56+
}
57+
5358
tq := &taskQueue[T]{
5459
keyPrefix: keyPrefix,
5560
tasktype: tasktype,
5661
groupName: "task-workers",
57-
workerName: uuid.NewString(),
62+
workerName: workerName,
5863
queueSetKey: fmt.Sprintf("%s%s:queues", keyPrefix, tasktype),
5964
}
6065

backend/redis/queue_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func Test_TaskQueue(t *testing.T) {
106106

107107
ctx := context.Background()
108108

109-
q, err := newTaskQueue[foo](context.Background(), client, "prefix", taskType)
109+
q, err := newTaskQueue[foo](context.Background(), client, "prefix", taskType, "")
110110
require.NoError(t, err)
111111

112112
_, err = client.Pipelined(ctx, func(p redis.Pipeliner) error {
@@ -135,7 +135,7 @@ func Test_TaskQueue(t *testing.T) {
135135
})
136136
require.NoError(t, err)
137137

138-
q2, _ := newTaskQueue[any](context.Background(), client, "prefix", taskType)
138+
q2, _ := newTaskQueue[any](context.Background(), client, "prefix", taskType, "")
139139
require.NoError(t, err)
140140

141141
// Dequeue using second worker
@@ -148,7 +148,7 @@ func Test_TaskQueue(t *testing.T) {
148148
{
149149
name: "Complete removes task",
150150
f: func(t *testing.T, q *taskQueue[any]) {
151-
q2, _ := newTaskQueue[any](context.Background(), client, "prefix", taskType)
151+
q2, _ := newTaskQueue[any](context.Background(), client, "prefix", taskType, "")
152152

153153
ctx := context.Background()
154154

@@ -182,7 +182,7 @@ func Test_TaskQueue(t *testing.T) {
182182
type taskData struct {
183183
Count int `json:"count"`
184184
}
185-
q, _ := newTaskQueue[taskData](context.Background(), client, "prefix", taskType)
185+
q, _ := newTaskQueue[taskData](context.Background(), client, "prefix", taskType, "")
186186

187187
ctx := context.Background()
188188

@@ -193,7 +193,7 @@ func Test_TaskQueue(t *testing.T) {
193193
})
194194
require.NoError(t, err)
195195

196-
q2, _ := newTaskQueue[taskData](context.Background(), client, "prefix", taskType)
196+
q2, _ := newTaskQueue[taskData](context.Background(), client, "prefix", taskType, "")
197197
require.NoError(t, err)
198198

199199
task, err := q2.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout)
@@ -221,7 +221,7 @@ func Test_TaskQueue(t *testing.T) {
221221
require.NoError(t, err)
222222

223223
// Create second worker (with different name)
224-
q2, _ := newTaskQueue[any](context.Background(), client, "prefix", taskType)
224+
q2, _ := newTaskQueue[any](context.Background(), client, "prefix", taskType, "")
225225
require.NoError(t, err)
226226

227227
task, err := q2.Dequeue(ctx, client, []workflow.Queue{workflow.QueueDefault}, lockTimeout, blockTimeout)
@@ -281,7 +281,7 @@ func Test_TaskQueue(t *testing.T) {
281281

282282
ctx := context.Background()
283283

284-
q, err := newTaskQueue[any](ctx, client, "prefix", taskType)
284+
q, err := newTaskQueue[any](ctx, client, "prefix", taskType, "")
285285
require.NoError(t, err)
286286

287287
q.Prepare(ctx, client, []workflow.Queue{workflow.QueueDefault})

backend/redis/redis.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ func NewRedisBackend(client redis.UniversalClient, opts ...RedisBackendOption) (
4242

4343
ctx := context.Background()
4444

45-
workflowQueue, err := newTaskQueue[workflowData](ctx, client, options.KeyPrefix, "workflows")
45+
workflowQueue, err := newTaskQueue[workflowData](ctx, client, options.KeyPrefix, "workflows", options.WorkerName)
4646
if err != nil {
4747
return nil, fmt.Errorf("creating workflow task queue: %w", err)
4848
}
4949

50-
activityQueue, err := newTaskQueue[activityData](ctx, client, options.KeyPrefix, "activities")
50+
activityQueue, err := newTaskQueue[activityData](ctx, client, options.KeyPrefix, "activities", options.WorkerName)
5151
if err != nil {
5252
return nil, fmt.Errorf("creating activity task queue: %w", err)
5353
}

backend/sqlite/sqlite.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func newSqliteBackend(dsn string, opts ...option) *sqliteBackend {
7676

7777
b := &sqliteBackend{
7878
db: db,
79-
workerName: fmt.Sprintf("worker-%v", uuid.NewString()),
79+
workerName: getWorkerName(options),
8080
options: options,
8181
}
8282

@@ -862,3 +862,11 @@ func (sb *sqliteBackend) ExtendActivityTask(ctx context.Context, task *backend.A
862862

863863
return tx.Commit()
864864
}
865+
866+
// getWorkerName returns the worker name from options, or generates a UUID-based name if not set.
867+
func getWorkerName(options *options) string {
868+
if options.Options.WorkerName != "" {
869+
return options.Options.WorkerName
870+
}
871+
return fmt.Sprintf("worker-%v", uuid.NewString())
872+
}

backend/sqlite/sqlite_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,40 @@ func Test_EndToEndSqliteBackend(t *testing.T) {
3636
require.NoError(t, b.Close())
3737
})
3838
}
39+
40+
func Test_SqliteBackend_WorkerName(t *testing.T) {
41+
t.Run("DefaultWorkerName", func(t *testing.T) {
42+
backend := NewInMemoryBackend()
43+
defer backend.Close()
44+
45+
// The default worker name should be in the format "worker-<uuid>"
46+
require.Contains(t, backend.workerName, "worker-")
47+
require.Len(t, backend.workerName, 43) // "worker-" (7) + UUID (36)
48+
})
49+
50+
t.Run("CustomWorkerName", func(t *testing.T) {
51+
customWorkerName := "test-worker-123"
52+
backend := NewInMemoryBackend(WithBackendOptions(backend.WithWorkerName(customWorkerName)))
53+
defer backend.Close()
54+
55+
require.Equal(t, customWorkerName, backend.workerName)
56+
})
57+
58+
t.Run("EmptyWorkerNameUsesDefault", func(t *testing.T) {
59+
backend := NewInMemoryBackend(WithBackendOptions(backend.WithWorkerName("")))
60+
defer backend.Close()
61+
62+
// Empty worker name should fall back to UUID generation
63+
require.Contains(t, backend.workerName, "worker-")
64+
require.Len(t, backend.workerName, 43) // "worker-" (7) + UUID (36)
65+
})
66+
67+
t.Run("CustomWorkerNameIsUsedInDatabase", func(t *testing.T) {
68+
customWorkerName := "integration-test-worker"
69+
backend := NewInMemoryBackend(WithBackendOptions(backend.WithWorkerName(customWorkerName)))
70+
defer backend.Close()
71+
72+
// Verify the worker name is stored correctly
73+
require.Equal(t, customWorkerName, backend.workerName)
74+
})
75+
}

0 commit comments

Comments
 (0)