Skip to content

Commit 39b3df2

Browse files
joeybloggsjoeybloggs
authored andcommitted
fix race condition when using consumer hook
1 parent 860c8de commit 39b3df2

File tree

1 file changed

+48
-41
lines changed

1 file changed

+48
-41
lines changed

pool.go

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ type Pool struct {
3636
cancelled bool
3737
cancelLock sync.RWMutex
3838
consumerHook ConsumerHook
39+
once sync.Once
40+
consumers int
3941
}
4042

4143
// JobFunc is the consumable function/job you wish to run
@@ -75,49 +77,13 @@ func (j *Job) Return(result interface{}) {
7577
func NewPool(consumers int, jobs int) *Pool {
7678

7779
p := &Pool{
78-
wg: new(sync.WaitGroup),
79-
jobs: make(chan *Job, jobs),
80-
results: make(chan interface{}, jobs),
81-
cancel: make(chan struct{}),
80+
wg: new(sync.WaitGroup),
81+
jobs: make(chan *Job, jobs),
82+
results: make(chan interface{}, jobs),
83+
cancel: make(chan struct{}),
84+
consumers: consumers,
8285
}
8386

84-
for i := 0; i < consumers; i++ {
85-
go func(p *Pool) {
86-
defer func(p *Pool) {
87-
if err := recover(); err != nil {
88-
trace := make([]byte, 1<<16)
89-
n := runtime.Stack(trace, true)
90-
rerr := &ErrRecovery{
91-
s: fmt.Sprintf(errRecoveryString, err, trace[:n]),
92-
}
93-
p.results <- rerr
94-
p.Cancel()
95-
p.wg.Done()
96-
}
97-
}(p)
98-
99-
var consumerParm interface{}
100-
101-
if p.consumerHook != nil {
102-
consumerParm = p.consumerHook()
103-
}
104-
105-
for {
106-
select {
107-
case j := <-p.jobs:
108-
if reflect.ValueOf(j).IsNil() {
109-
return
110-
}
111-
112-
j.hookParam = consumerParm
113-
j.fn(j)
114-
p.wg.Done()
115-
case <-p.cancel:
116-
return
117-
}
118-
}
119-
}(p)
120-
}
12187
return p
12288
}
12389

@@ -137,6 +103,46 @@ func (p *Pool) cancelJobs() {
137103
// Queue adds a job to be processed and the params to be passed to it.
138104
func (p *Pool) Queue(fn JobFunc, params ...interface{}) {
139105

106+
p.once.Do(func() {
107+
for i := 0; i < p.consumers; i++ {
108+
go func(p *Pool) {
109+
defer func(p *Pool) {
110+
if err := recover(); err != nil {
111+
trace := make([]byte, 1<<16)
112+
n := runtime.Stack(trace, true)
113+
rerr := &ErrRecovery{
114+
s: fmt.Sprintf(errRecoveryString, err, trace[:n]),
115+
}
116+
p.results <- rerr
117+
p.Cancel()
118+
p.wg.Done()
119+
}
120+
}(p)
121+
122+
var consumerParm interface{}
123+
124+
if p.consumerHook != nil {
125+
consumerParm = p.consumerHook()
126+
}
127+
128+
for {
129+
select {
130+
case j := <-p.jobs:
131+
if reflect.ValueOf(j).IsNil() {
132+
return
133+
}
134+
135+
j.hookParam = consumerParm
136+
j.fn(j)
137+
p.wg.Done()
138+
case <-p.cancel:
139+
return
140+
}
141+
}
142+
}(p)
143+
}
144+
})
145+
140146
p.cancelLock.Lock()
141147
defer p.cancelLock.Unlock()
142148

@@ -152,6 +158,7 @@ func (p *Pool) Queue(fn JobFunc, params ...interface{}) {
152158

153159
p.wg.Add(1)
154160
p.jobs <- job
161+
155162
}
156163

157164
// Cancel cancels all jobs not already running.

0 commit comments

Comments
 (0)