Skip to content

Commit 3e7eae2

Browse files
authored
Merge pull request #3 from lxzan/dev
完善单元测试和边界处理
2 parents d4f4a15 + c32c890 commit 3e7eae2

File tree

6 files changed

+267
-10
lines changed

6 files changed

+267
-10
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module github.com/lxzan/concurrency
33
go 1.20
44

55
require (
6-
github.com/lxzan/dao v1.1.7
6+
github.com/lxzan/dao v1.1.12
77
github.com/pkg/errors v0.9.1
88
github.com/stretchr/testify v1.8.4
99
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
22
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3-
github.com/lxzan/dao v1.1.7 h1:I049e67buJIpr4QJ/vJbHSjKMLN4ZJlSMeK3Rq+CJl8=
4-
github.com/lxzan/dao v1.1.7/go.mod h1:5ChTIo7RSZ4upqRo16eicJ3XdJWhGwgMIsyuGLMUofM=
3+
github.com/lxzan/dao v1.1.12 h1:TMvCwhFVzZV6c9upFxXXoiPD5wDKIYgzIYoC5KE//yc=
4+
github.com/lxzan/dao v1.1.12/go.mod h1:5ChTIo7RSZ4upqRo16eicJ3XdJWhGwgMIsyuGLMUofM=
55
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
66
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
77
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

groups/group.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ package groups
33
import (
44
"context"
55
"errors"
6-
"github.com/lxzan/concurrency/internal"
76
"sync"
87
"sync/atomic"
98
"time"
9+
10+
"github.com/lxzan/concurrency/internal"
1011
)
1112

1213
const (
@@ -149,7 +150,10 @@ func (c *Group[T]) Update(f func()) {
149150

150151
// Start 启动并等待所有任务执行完成
151152
func (c *Group[T]) Start() error {
152-
var taskTotal = int64(c.Len())
153+
c.mu.Lock()
154+
var taskTotal = c.taskTotal
155+
c.mu.Unlock()
156+
153157
if taskTotal == 0 {
154158
return nil
155159
}

groups/group_test.go

Lines changed: 135 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package groups
22

33
import (
4-
"github.com/pkg/errors"
5-
"github.com/stretchr/testify/assert"
64
"sync"
75
"sync/atomic"
86
"testing"
97
"time"
8+
9+
"github.com/pkg/errors"
10+
"github.com/stretchr/testify/assert"
1011
)
1112

1213
func TestNewTaskGroup(t *testing.T) {
@@ -103,4 +104,136 @@ func TestNewTaskGroup(t *testing.T) {
103104
as.Error(err)
104105
as.ElementsMatch(arr, []int{1, 3})
105106
})
107+
108+
t.Run("push after start", func(t *testing.T) {
109+
// 测试在Start之后Push任务的行为
110+
ctl := New[int](WithConcurrency(2))
111+
ctl.Push(1, 2)
112+
var processed = make([]int, 0)
113+
var mu = sync.Mutex{}
114+
ctl.OnMessage = func(args int) error {
115+
mu.Lock()
116+
processed = append(processed, args)
117+
mu.Unlock()
118+
time.Sleep(10 * time.Millisecond)
119+
return nil
120+
}
121+
122+
go func() {
123+
time.Sleep(5 * time.Millisecond)
124+
ctl.Push(3, 4) // 在Start之后Push
125+
}()
126+
127+
err := ctl.Start()
128+
as.NoError(err)
129+
// 验证所有任务都被处理
130+
as.GreaterOrEqual(len(processed), 2)
131+
})
132+
133+
t.Run("multiple errors", func(t *testing.T) {
134+
// 测试多个任务都出错的情况
135+
ctl := New[int](WithConcurrency(2))
136+
ctl.Push(1, 2, 3)
137+
ctl.OnMessage = func(args int) error {
138+
return errors.Errorf("error %d", args)
139+
}
140+
err := ctl.Start()
141+
as.Error(err)
142+
// 验证错误被正确收集
143+
as.Contains(err.Error(), "error")
144+
})
145+
146+
t.Run("concurrent push", func(t *testing.T) {
147+
// 测试并发Push
148+
ctl := New[int](WithConcurrency(4))
149+
var wg sync.WaitGroup
150+
for i := 0; i < 10; i++ {
151+
wg.Add(1)
152+
go func(n int) {
153+
defer wg.Done()
154+
ctl.Push(n)
155+
}(i)
156+
}
157+
wg.Wait()
158+
159+
var processed = int64(0)
160+
ctl.OnMessage = func(args int) error {
161+
atomic.AddInt64(&processed, 1)
162+
return nil
163+
}
164+
err := ctl.Start()
165+
as.NoError(err)
166+
as.Equal(int64(10), processed)
167+
})
168+
169+
t.Run("zero concurrency", func(t *testing.T) {
170+
// 测试并发度为0的情况(应该使用默认值)
171+
ctl := New[int](WithConcurrency(0))
172+
ctl.Push(1, 2, 3)
173+
var processed = int64(0)
174+
ctl.OnMessage = func(args int) error {
175+
atomic.AddInt64(&processed, 1)
176+
return nil
177+
}
178+
err := ctl.Start()
179+
as.NoError(err)
180+
as.Equal(int64(3), processed)
181+
})
182+
183+
t.Run("cancel before start", func(t *testing.T) {
184+
// 测试在Start之前Cancel
185+
ctl := New[int]()
186+
ctl.Push(1, 2, 3)
187+
ctl.Cancel()
188+
var processed = int64(0)
189+
ctl.OnMessage = func(args int) error {
190+
atomic.AddInt64(&processed, 1)
191+
return nil
192+
}
193+
_ = ctl.Start()
194+
// Cancel后任务应该不会执行
195+
as.Equal(int64(0), processed)
196+
// 由于任务被取消,可能会超时或返回错误
197+
})
198+
199+
t.Run("update deadlock check", func(t *testing.T) {
200+
// 测试Update方法不会导致死锁
201+
ctl := New[int](WithConcurrency(2))
202+
ctl.Push(1, 2, 3, 4, 5)
203+
var processed = int64(0)
204+
ctl.OnMessage = func(args int) error {
205+
ctl.Update(func() {
206+
processed++
207+
})
208+
return nil
209+
}
210+
err := ctl.Start()
211+
as.NoError(err)
212+
as.Equal(int64(5), processed)
213+
})
214+
215+
t.Run("task count accuracy", func(t *testing.T) {
216+
// 测试任务计数准确性
217+
ctl := New[int](WithConcurrency(2))
218+
ctl.Push(1, 2, 3, 4, 5)
219+
var processed = int64(0)
220+
ctl.OnMessage = func(args int) error {
221+
atomic.AddInt64(&processed, 1)
222+
return nil
223+
}
224+
err := ctl.Start()
225+
as.NoError(err)
226+
as.Equal(int64(5), processed)
227+
})
228+
229+
t.Run("len accuracy", func(t *testing.T) {
230+
// 测试Len方法的准确性
231+
ctl := New[int]()
232+
as.Equal(0, ctl.Len())
233+
ctl.Push(1, 2, 3)
234+
as.Equal(3, ctl.Len())
235+
ctl.Start()
236+
// Start后队列应该为空
237+
as.Equal(0, ctl.Len())
238+
})
106239
}

queues/multiple_queue.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,8 @@ func (c *multipleQueue) Stop(ctx context.Context) error {
5353
}(c.qs[i])
5454
}
5555
wg.Wait()
56-
return err.Load().err
56+
if e := err.Load(); e != nil {
57+
return e.err
58+
}
59+
return nil
5760
}

queues/queue_test.go

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ package queues
22

33
import (
44
"context"
5-
"github.com/lxzan/concurrency/logs"
6-
"github.com/stretchr/testify/assert"
75
"sync"
86
"sync/atomic"
97
"testing"
108
"time"
9+
10+
"github.com/lxzan/concurrency/logs"
11+
"github.com/stretchr/testify/assert"
1112
)
1213

1314
func TestSingleQueue(t *testing.T) {
@@ -171,4 +172,120 @@ func TestMultiQueue(t *testing.T) {
171172
q.Stop(context.Background())
172173
q.Push(func() {})
173174
})
175+
176+
t.Run("push after stop", func(t *testing.T) {
177+
// 测试Stop后Push的行为
178+
q := New(WithConcurrency(2))
179+
var processed = int64(0)
180+
q.Push(func() {
181+
atomic.AddInt64(&processed, 1)
182+
})
183+
q.Stop(context.Background())
184+
as.Equal(int64(1), processed)
185+
186+
// Stop后Push应该被忽略
187+
q.Push(func() {
188+
atomic.AddInt64(&processed, 1)
189+
})
190+
time.Sleep(50 * time.Millisecond)
191+
as.Equal(int64(1), processed) // 应该还是1,因为新任务被忽略
192+
})
193+
194+
t.Run("concurrent push", func(t *testing.T) {
195+
// 测试并发Push
196+
q := New(WithConcurrency(4))
197+
var processed = int64(0)
198+
var wg sync.WaitGroup
199+
for i := 0; i < 100; i++ {
200+
wg.Add(1)
201+
go func() {
202+
defer wg.Done()
203+
q.Push(func() {
204+
atomic.AddInt64(&processed, 1)
205+
})
206+
}()
207+
}
208+
wg.Wait()
209+
q.Stop(context.Background())
210+
as.Equal(int64(100), processed)
211+
})
212+
213+
t.Run("zero concurrency", func(t *testing.T) {
214+
// 测试并发度为0的情况
215+
q := New(WithConcurrency(0))
216+
var processed = int64(0)
217+
q.Push(func() {
218+
atomic.AddInt64(&processed, 1)
219+
})
220+
q.Stop(context.Background())
221+
as.Equal(int64(1), processed)
222+
})
223+
224+
t.Run("len accuracy", func(t *testing.T) {
225+
// 测试Len方法的准确性
226+
q := New(WithConcurrency(2))
227+
as.Equal(0, q.Len())
228+
q.Push(func() {})
229+
q.Push(func() {})
230+
// 由于有并发执行,Len可能为0或更小
231+
time.Sleep(10 * time.Millisecond)
232+
q.Stop(context.Background())
233+
as.Equal(0, q.Len())
234+
})
235+
236+
t.Run("stop with empty queue", func(t *testing.T) {
237+
// 测试空队列Stop
238+
q := New()
239+
err := q.Stop(context.Background())
240+
as.NoError(err)
241+
})
242+
243+
t.Run("stop multiple times", func(t *testing.T) {
244+
// 测试多次Stop
245+
q := New()
246+
err1 := q.Stop(context.Background())
247+
err2 := q.Stop(context.Background())
248+
as.NoError(err1)
249+
as.NoError(err2)
250+
})
251+
252+
t.Run("context cancel during stop", func(t *testing.T) {
253+
// 测试Stop时context被取消
254+
q := New(WithConcurrency(1))
255+
q.Push(func() {
256+
time.Sleep(200 * time.Millisecond)
257+
})
258+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
259+
defer cancel()
260+
err := q.Stop(ctx)
261+
as.Error(err)
262+
})
263+
264+
t.Run("multiple queue error handling", func(t *testing.T) {
265+
// 测试多队列错误处理
266+
q := New(WithSharding(4), WithConcurrency(1), WithTimeout(50*time.Millisecond))
267+
// 添加一些会超时的任务
268+
for i := 0; i < 8; i++ {
269+
q.Push(func() {
270+
time.Sleep(200 * time.Millisecond)
271+
})
272+
}
273+
err := q.Stop(context.Background())
274+
// 应该返回超时错误
275+
as.Error(err)
276+
})
277+
278+
t.Run("multiple queue all success", func(t *testing.T) {
279+
// 测试所有队列都成功的情况
280+
q := New(WithSharding(4), WithConcurrency(2))
281+
var processed = int64(0)
282+
for i := 0; i < 20; i++ {
283+
q.Push(func() {
284+
atomic.AddInt64(&processed, 1)
285+
})
286+
}
287+
err := q.Stop(context.Background())
288+
as.NoError(err)
289+
as.Equal(int64(20), processed)
290+
})
174291
}

0 commit comments

Comments
 (0)