Skip to content

Commit c5ec189

Browse files
committed
Merge branch 'dev'
2 parents f529662 + dcfbb25 commit c5ec189

File tree

3 files changed

+58
-41
lines changed

3 files changed

+58
-41
lines changed

groups/group.go

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"github.com/lxzan/concurrency/internal"
77
"sync"
8+
"sync/atomic"
89
"time"
910
)
1011

@@ -13,48 +14,51 @@ const (
1314
defaultWaitTimeout = 60 * time.Second // 默认线程同步等待超时
1415
)
1516

17+
var defaultCaller Caller = func(args any, f func(any) error) error { return f(args) }
18+
1619
type (
1720
Caller func(args any, f func(any) error) error
1821

1922
Group[T any] struct {
20-
options *options
21-
mu *sync.Mutex // 锁
22-
errs []error // 错误
23-
done chan bool // 信号
24-
q []T // 任务队列
25-
taskDone int64 // 已完成任务数量
26-
taskTotal int64 // 总任务数量
27-
OnMessage func(args T) error // 任务处理
28-
OnError func(err error) // 错误处理
23+
options *options // 配置
24+
mu sync.Mutex // 锁
25+
ctx context.Context // 上下文
26+
cancelFunc context.CancelFunc // 取消函数
27+
canceled atomic.Uint32 // 是否已取消
28+
errs []error // 错误
29+
done chan bool // 完成信号
30+
q []T // 任务队列
31+
taskDone int64 // 已完成任务数量
32+
taskTotal int64 // 总任务数量
33+
OnMessage func(args T) error // 任务处理
34+
OnError func(err error) // 错误处理
2935
}
3036
)
3137

3238
// New 新建一个任务集
3339
func New[T any](opts ...Option) *Group[T] {
34-
o := &options{
35-
timeout: defaultWaitTimeout,
36-
concurrency: defaultConcurrency,
37-
caller: func(args any, f func(any) error) error { return f(args) },
38-
}
40+
o := new(options)
41+
opts = append(opts, withInitialize())
3942
for _, f := range opts {
4043
f(o)
4144
}
4245

4346
c := &Group[T]{
4447
options: o,
45-
mu: &sync.Mutex{},
4648
q: make([]T, 0),
4749
taskDone: 0,
4850
done: make(chan bool),
4951
}
52+
c.ctx, c.cancelFunc = context.WithTimeout(context.Background(), o.timeout)
5053
c.OnMessage = func(args T) error {
5154
return nil
5255
}
5356
c.OnError = func(err error) {}
57+
5458
return c
5559
}
5660

57-
func (c *Group[T]) clear() {
61+
func (c *Group[T]) clearJob() {
5862
c.mu.Lock()
5963
c.q = c.q[:0]
6064
c.mu.Unlock()
@@ -82,19 +86,21 @@ func (c *Group[T]) incrAndIsDone() bool {
8286
return ok
8387
}
8488

85-
func (c *Group[T]) hasError() bool {
89+
func (c *Group[T]) getError() error {
8690
c.mu.Lock()
8791
defer c.mu.Unlock()
88-
return len(c.errs) > 0
92+
return errors.Join(c.errs...)
93+
}
94+
95+
func (c *Group[T]) jobFunc(v any) error {
96+
if c.canceled.Load() == 1 {
97+
return nil
98+
}
99+
return c.OnMessage(v.(T))
89100
}
90101

91102
func (c *Group[T]) do(args T) {
92-
if err := c.options.caller(args, func(v any) error {
93-
if c.options.cancel && c.hasError() {
94-
return nil
95-
}
96-
return c.OnMessage(v.(T))
97-
}); err != nil {
103+
if err := c.options.caller(args, c.jobFunc); err != nil {
98104
c.mu.Lock()
99105
c.errs = append(c.errs, err)
100106
c.mu.Unlock()
@@ -119,6 +125,13 @@ func (c *Group[T]) Len() int {
119125
return x
120126
}
121127

128+
// Cancel 取消队列中剩余任务的执行
129+
func (c *Group[T]) Cancel() {
130+
if c.canceled.CompareAndSwap(0, 1) {
131+
c.cancelFunc()
132+
}
133+
}
134+
122135
// Push 往任务队列中追加任务
123136
func (c *Group[T]) Push(eles ...T) {
124137
c.mu.Lock()
@@ -148,13 +161,13 @@ func (c *Group[T]) Start() error {
148161
}
149162
}
150163

151-
ctx, cancel := context.WithTimeout(context.Background(), c.options.timeout)
152-
defer cancel()
164+
defer c.cancelFunc()
165+
153166
select {
154167
case <-c.done:
155-
return errors.Join(c.errs...)
156-
case <-ctx.Done():
157-
c.clear()
158-
return ctx.Err()
168+
return c.getError()
169+
case <-c.ctx.Done():
170+
c.clearJob()
171+
return c.ctx.Err()
159172
}
160173
}

groups/group_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func TestNewTaskGroup(t *testing.T) {
8282
})
8383

8484
t.Run("cancel", func(t *testing.T) {
85-
ctl := New[int](WithCancel(), WithConcurrency(1))
85+
ctl := New[int](WithConcurrency(1))
8686
ctl.Push(1, 3, 5)
8787
arr := make([]int, 0)
8888
ctl.OnMessage = func(args int) error {
@@ -96,6 +96,9 @@ func TestNewTaskGroup(t *testing.T) {
9696
return nil
9797
}
9898
}
99+
ctl.OnError = func(err error) {
100+
ctl.Cancel()
101+
}
99102
err := ctl.Start()
100103
as.Error(err)
101104
as.ElementsMatch(arr, []int{1, 3})

groups/options.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package groups
22

33
import (
4+
"github.com/lxzan/concurrency/internal"
45
"github.com/pkg/errors"
56
"runtime"
67
"time"
@@ -11,7 +12,6 @@ type options struct {
1112
timeout time.Duration
1213
concurrency int64
1314
caller Caller
14-
cancel bool
1515
}
1616

1717
type Option func(o *options)
@@ -24,16 +24,9 @@ func WithTimeout(t time.Duration) Option {
2424
}
2525

2626
// WithConcurrency 设置最大并发
27-
func WithConcurrency(n int64) Option {
27+
func WithConcurrency(n uint32) Option {
2828
return func(o *options) {
29-
o.concurrency = n
30-
}
31-
}
32-
33-
// WithCancel 设置遇到错误放弃执行剩余任务
34-
func WithCancel() Option {
35-
return func(o *options) {
36-
o.cancel = true
29+
o.concurrency = int64(n)
3730
}
3831
}
3932

@@ -55,3 +48,11 @@ func WithRecovery() Option {
5548
}
5649
}
5750
}
51+
52+
func withInitialize() Option {
53+
return func(o *options) {
54+
o.timeout = internal.SelectValue(o.timeout <= 0, defaultWaitTimeout, o.timeout)
55+
o.concurrency = internal.SelectValue(o.concurrency <= 0, defaultConcurrency, o.concurrency)
56+
o.caller = internal.SelectValue(o.caller == nil, defaultCaller, o.caller)
57+
}
58+
}

0 commit comments

Comments
 (0)