Skip to content

Commit e39af03

Browse files
ldmonsterCopilot
andcommitted
[fix] Fix possibly deadlock (#818)
Signed-off-by: Pavel Okhlopkov <[email protected]> Co-authored-by: Copilot <[email protected]> Signed-off-by: Pavel Okhlopkov <[email protected]>
1 parent f0cf2fb commit e39af03

File tree

10 files changed

+307
-113
lines changed

10 files changed

+307
-113
lines changed

pkg/shell-operator/combine_binding_context.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func (op *ShellOperator) combineBindingContextForHook(tqs *queue.TaskQueueSet, q
3838

3939
otherTasks := make([]task.Task, 0)
4040
stopIterate := false
41-
q.Iterate(func(tsk task.Task) {
41+
q.IterateSnapshot(func(tsk task.Task) {
4242
if stopIterate {
4343
return
4444
}
@@ -91,7 +91,7 @@ func (op *ShellOperator) combineBindingContextForHook(tqs *queue.TaskQueueSet, q
9191
}
9292

9393
// Delete tasks with false in tasksFilter map
94-
tqs.GetByName(t.GetQueueName()).Filter(func(tsk task.Task) bool {
94+
tqs.GetByName(t.GetQueueName()).DeleteFunc(func(tsk task.Task) bool {
9595
if v, ok := tasksFilter[tsk.GetId()]; ok {
9696
return v
9797
}

pkg/shell-operator/manager_events_handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (m *ManagerEventsHandler) Start() {
8888

8989
m.taskQueues.DoWithLock(func(tqs *queue.TaskQueueSet) {
9090
for _, resTask := range tailTasks {
91-
if q := tqs.Queues[resTask.GetQueueName()]; q == nil {
91+
if q, ok := tqs.Queues.Get(resTask.GetQueueName()); !ok {
9292
log.Error("Possible bug!!! Got task for queue but queue is not created yet.",
9393
slog.String("queueName", resTask.GetQueueName()),
9494
slog.String("description", resTask.GetDescription()))

pkg/shell-operator/operator.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ func (op *ShellOperator) CombineBindingContextForHook(q *queue.TaskQueue, t task
731731

732732
otherTasks := make([]task.Task, 0)
733733
stopIterate := false
734-
q.Iterate(func(tsk task.Task) {
734+
q.IterateSnapshot(func(tsk task.Task) {
735735
if stopIterate {
736736
return
737737
}
@@ -782,7 +782,7 @@ func (op *ShellOperator) CombineBindingContextForHook(q *queue.TaskQueue, t task
782782
}
783783

784784
// Delete tasks with false in tasksFilter map
785-
op.TaskQueues.GetByName(t.GetQueueName()).Filter(func(tsk task.Task) bool {
785+
op.TaskQueues.GetByName(t.GetQueueName()).DeleteFunc(func(tsk task.Task) bool {
786786
if v, ok := tasksFilter[tsk.GetId()]; ok {
787787
return v
788788
}
@@ -949,7 +949,7 @@ func (op *ShellOperator) runMetrics() {
949949
// task queue length
950950
go func() {
951951
for {
952-
op.TaskQueues.Iterate(func(queue *queue.TaskQueue) {
952+
op.TaskQueues.Iterate(context.TODO(), func(_ context.Context, queue *queue.TaskQueue) {
953953
queueLen := float64(queue.Length())
954954
op.MetricStorage.GaugeSet("{PREFIX}tasks_queue_length", queueLen, map[string]string{"queue": queue.Name})
955955
})

pkg/shell-operator/operator_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func Test_Operator_startup_tasks(t *testing.T) {
6666
}
6767

6868
i := 0
69-
op.TaskQueues.GetMain().Iterate(func(tsk task.Task) {
69+
op.TaskQueues.GetMain().IterateSnapshot(func(tsk task.Task) {
7070
// Stop checking if no expects left.
7171
if i >= len(expectTasks) {
7272
return

pkg/task/dump/dump.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dump
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"sort"
@@ -78,7 +79,7 @@ func TaskQueues(tqs *queue.TaskQueueSet, format string, showEmpty bool) interfac
7879
tasksCount := 0
7980
mainTasksCount := 0
8081

81-
tqs.Iterate(func(queue *queue.TaskQueue) {
82+
tqs.Iterate(context.TODO(), func(_ context.Context, queue *queue.TaskQueue) {
8283
if queue == nil {
8384
return
8485
}
@@ -160,7 +161,7 @@ func getTasksForQueue(q *queue.TaskQueue) []dumpTask {
160161
tasks := make([]dumpTask, 0, q.Length())
161162

162163
index := 1
163-
q.Iterate(func(task task.Task) {
164+
q.IterateSnapshot(func(task task.Task) {
164165
tasks = append(tasks, dumpTask{
165166
Index: index,
166167
Description: task.GetDescription(),

pkg/task/queue/queue_set.go

Lines changed: 86 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,58 @@ import (
1313

1414
const MainQueueName = "main"
1515

16+
// queueStorage is a thread-safe storage for task queues with basic Get/Set/Delete operations
17+
type queueStorage struct {
18+
mu sync.RWMutex
19+
queues map[string]*TaskQueue
20+
}
21+
22+
func newQueueStorage() *queueStorage {
23+
return &queueStorage{
24+
queues: make(map[string]*TaskQueue),
25+
}
26+
}
27+
28+
// Get retrieves a queue by name, returns nil if not found
29+
func (qs *queueStorage) Get(name string) (*TaskQueue, bool) {
30+
qs.mu.RLock()
31+
defer qs.mu.RUnlock()
32+
queue, exists := qs.queues[name]
33+
return queue, exists
34+
}
35+
36+
// List retrieves all queues
37+
func (qs *queueStorage) List() []*TaskQueue {
38+
qs.mu.RLock()
39+
defer qs.mu.RUnlock()
40+
41+
queues := make([]*TaskQueue, 0, len(qs.queues))
42+
for _, queue := range qs.queues {
43+
queues = append(queues, queue)
44+
}
45+
46+
return queues
47+
}
48+
49+
// Set stores a queue with the given name
50+
func (qs *queueStorage) Set(name string, queue *TaskQueue) {
51+
qs.mu.Lock()
52+
defer qs.mu.Unlock()
53+
qs.queues[name] = queue
54+
}
55+
56+
// Delete removes a queue by name
57+
func (qs *queueStorage) Delete(name string) {
58+
qs.mu.Lock()
59+
defer qs.mu.Unlock()
60+
delete(qs.queues, name)
61+
}
62+
63+
// Len returns the number of tasks in a queue by name
64+
func (qs *queueStorage) Len() int {
65+
return len(qs.queues)
66+
}
67+
1668
// TaskQueueSet is a manager for a set of named queues
1769
type TaskQueueSet struct {
1870
MainName string
@@ -23,12 +75,12 @@ type TaskQueueSet struct {
2375
cancel context.CancelFunc
2476

2577
m sync.RWMutex
26-
Queues map[string]*TaskQueue
78+
Queues *queueStorage
2779
}
2880

2981
func NewTaskQueueSet() *TaskQueueSet {
3082
return &TaskQueueSet{
31-
Queues: make(map[string]*TaskQueue),
83+
Queues: newQueueStorage(),
3284
MainName: MainQueueName,
3385
}
3486
}
@@ -61,18 +113,14 @@ func (tqs *TaskQueueSet) StartMain(ctx context.Context) {
61113
}
62114

63115
func (tqs *TaskQueueSet) Start(ctx context.Context) {
64-
tqs.m.RLock()
65-
for _, q := range tqs.Queues {
66-
q.Start(ctx)
67-
}
68-
69-
tqs.m.RUnlock()
116+
tqs.Iterate(ctx, func(ctx context.Context, queue *TaskQueue) {
117+
queue.Start(ctx)
118+
})
70119
}
71120

121+
// Add register a new queue for TaskQueueSet.
72122
func (tqs *TaskQueueSet) Add(queue *TaskQueue) {
73-
tqs.m.Lock()
74-
tqs.Queues[queue.Name] = queue
75-
tqs.m.Unlock()
123+
tqs.Queues.Set(queue.Name, queue)
76124
}
77125

78126
func (tqs *TaskQueueSet) NewNamedQueue(name string, handler func(ctx context.Context, t task.Task) TaskResult, opts ...TaskQueueOption) {
@@ -91,92 +139,88 @@ func (tqs *TaskQueueSet) NewNamedQueue(name string, handler func(ctx context.Con
91139
q.logger = log.NewLogger().Named("task_queue")
92140
}
93141

94-
tqs.m.Lock()
95-
tqs.Queues[name] = q
96-
tqs.m.Unlock()
142+
tqs.Queues.Set(q.Name, q)
97143
}
98144

99145
func (tqs *TaskQueueSet) GetByName(name string) *TaskQueue {
100-
tqs.m.RLock()
101-
defer tqs.m.RUnlock()
102-
ts, exists := tqs.Queues[name]
103-
if exists {
104-
return ts
146+
q, ok := tqs.Queues.Get(name)
147+
if !ok {
148+
return nil
105149
}
106-
return nil
150+
151+
return q
107152
}
108153

109154
func (tqs *TaskQueueSet) GetMain() *TaskQueue {
110155
return tqs.GetByName(tqs.MainName)
111156
}
112157

113-
/*
114-
taskQueueSet.DoWithLock(func(tqs *TaskQueueSet){
115-
tqs.GetMain().Pop()
116-
})
117-
*/
118158
func (tqs *TaskQueueSet) DoWithLock(fn func(tqs *TaskQueueSet)) {
119159
tqs.m.Lock()
120160
defer tqs.m.Unlock()
161+
121162
if fn != nil {
122163
fn(tqs)
123164
}
124165
}
125166

126167
// Iterate run doFn for every task.
127-
func (tqs *TaskQueueSet) Iterate(doFn func(queue *TaskQueue)) {
168+
func (tqs *TaskQueueSet) Iterate(ctx context.Context, doFn func(ctx context.Context, queue *TaskQueue)) {
128169
if doFn == nil {
129170
return
130171
}
131172

132173
tqs.m.RLock()
133174
defer tqs.m.RUnlock()
134-
if len(tqs.Queues) == 0 {
175+
if tqs.Queues.Len() == 0 {
135176
return
136177
}
137178

138179
main := tqs.GetMain()
139180
if main != nil {
140-
doFn(main)
181+
doFn(ctx, main)
141182
}
142183
// TODO sort names
143184

144-
for _, q := range tqs.Queues {
185+
for _, q := range tqs.Queues.List() {
145186
if q.Name != tqs.MainName {
146-
doFn(q)
187+
doFn(ctx, q)
147188
}
148189
}
149190
}
150191

151192
func (tqs *TaskQueueSet) Remove(name string) {
152-
tqs.m.Lock()
153-
ts, exists := tqs.Queues[name]
193+
ts, exists := tqs.Queues.Get(name)
154194
if exists {
155195
ts.Stop()
156196
}
157197

158-
delete(tqs.Queues, name)
159-
tqs.m.Unlock()
198+
tqs.Queues.Delete(name)
160199
}
161200

162201
func (tqs *TaskQueueSet) WaitStopWithTimeout(timeout time.Duration) {
163202
checkTick := time.NewTicker(time.Millisecond * 100)
164203
defer checkTick.Stop()
204+
165205
timeoutTick := time.NewTicker(timeout)
166206
defer timeoutTick.Stop()
167207

168208
for {
169209
select {
170210
case <-checkTick.C:
171-
stopped := true
172-
tqs.m.RLock()
173-
for _, q := range tqs.Queues {
174-
if q.Status != "stop" {
175-
stopped = false
176-
break
211+
stopped := func() bool {
212+
tqs.m.RLock()
213+
defer tqs.m.RUnlock()
214+
215+
for _, q := range tqs.Queues.List() {
216+
if q.GetStatusType() != QueueStatusStop {
217+
return false
218+
}
177219
}
178-
}
179-
tqs.m.RUnlock()
220+
221+
return true
222+
}()
223+
180224
if stopped {
181225
return
182226
}

0 commit comments

Comments
 (0)