Skip to content

Commit b99b8ce

Browse files
committed
address comments and add unit tests
1 parent a982c04 commit b99b8ce

File tree

3 files changed

+153
-13
lines changed

3 files changed

+153
-13
lines changed

internal/internal_worker_base.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ type (
142142
logger *zap.Logger
143143
metricsScope tally.Scope
144144

145-
dynamic *worker.DynamicParams
145+
concurrency *worker.Concurrency
146146
pollerAutoScaler *pollerAutoScaler
147147
taskQueueCh chan interface{}
148148
sessionTokenBucket *sessionTokenBucket
@@ -168,18 +168,18 @@ func createPollRetryPolicy() backoff.RetryPolicy {
168168
func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope tally.Scope, sessionTokenBucket *sessionTokenBucket) *baseWorker {
169169
ctx, cancel := context.WithCancel(context.Background())
170170

171-
dynamic := &worker.DynamicParams{
171+
concurrency := &worker.Concurrency{
172172
PollerPermit: worker.NewPermit(options.pollerCount),
173173
TaskPermit: worker.NewPermit(options.maxConcurrentTask),
174174
}
175175

176176
var pollerAS *pollerAutoScaler
177177
if pollerOptions := options.pollerAutoScaler; pollerOptions.Enabled {
178-
dynamic.PollerPermit = worker.NewPermit(pollerOptions.InitCount)
178+
concurrency.PollerPermit = worker.NewPermit(pollerOptions.InitCount)
179179
pollerAS = newPollerScaler(
180180
pollerOptions,
181181
logger,
182-
dynamic.PollerPermit,
182+
concurrency.PollerPermit,
183183
)
184184
}
185185

@@ -190,7 +190,7 @@ func newBaseWorker(options baseWorkerOptions, logger *zap.Logger, metricsScope t
190190
retrier: backoff.NewConcurrentRetrier(pollOperationRetryPolicy),
191191
logger: logger.With(zapcore.Field{Key: tagWorkerType, Type: zapcore.StringType, String: options.workerType}),
192192
metricsScope: tagScope(metricsScope, tagWorkerType, options.workerType),
193-
dynamic: dynamic,
193+
concurrency: concurrency,
194194
pollerAutoScaler: pollerAS,
195195
taskQueueCh: make(chan interface{}), // no buffer, so poller only able to poll new task after previous is dispatched.
196196
limiterContext: ctx,
@@ -252,13 +252,13 @@ func (bw *baseWorker) runPoller() {
252252
select {
253253
case <-bw.shutdownCh:
254254
return
255-
case <-bw.dynamic.TaskPermit.AcquireChan(bw.limiterContext, &bw.shutdownWG): // don't poll unless there is a task permit
255+
case <-bw.concurrency.TaskPermit.AcquireChan(bw.limiterContext, &bw.shutdownWG): // don't poll unless there is a task permit
256256
// TODO move to a centralized place inside the worker
257257
// emit metrics on concurrent task permit quota and current task permit count
258258
// NOTE task permit doesn't mean there is a task running, it still needs to poll until it gets a task to process
259259
// thus the metrics is only an estimated value of how many tasks are running concurrently
260-
bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.dynamic.TaskPermit.Quota()))
261-
bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.dynamic.TaskPermit.Count()))
260+
bw.metricsScope.Gauge(metrics.ConcurrentTaskQuota).Update(float64(bw.concurrency.TaskPermit.Quota()))
261+
bw.metricsScope.Gauge(metrics.PollerRequestBufferUsage).Update(float64(bw.concurrency.TaskPermit.Count()))
262262
if bw.sessionTokenBucket != nil {
263263
bw.sessionTokenBucket.waitForAvailableToken()
264264
}
@@ -339,7 +339,7 @@ func (bw *baseWorker) pollTask() {
339339
case <-bw.shutdownCh:
340340
}
341341
} else {
342-
bw.dynamic.TaskPermit.Release(1) // poll failed, trigger a new poll by returning a task permit
342+
bw.concurrency.TaskPermit.Release(1) // poll failed, trigger a new poll by returning a task permit
343343
}
344344
}
345345

@@ -374,7 +374,7 @@ func (bw *baseWorker) processTask(task interface{}) {
374374
}
375375

376376
if isPolledTask {
377-
bw.dynamic.TaskPermit.Release(1) // task processed, trigger a new poll by returning a task permit
377+
bw.concurrency.TaskPermit.Release(1) // task processed, trigger a new poll by returning a task permit
378378
}
379379
}()
380380
err := bw.options.taskWorker.ProcessTask(task)
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ import (
3030

3131
var _ Permit = (*permit)(nil)
3232

33-
// Synchronization contains synchronization primitives for dynamic configuration.
34-
type DynamicParams struct {
33+
// Concurrency contains synchronization primitives for dynamically controlling the concurrencies in workers
34+
type Concurrency struct {
3535
PollerPermit Permit // controls concurrency of pollers
3636
TaskPermit Permit // controlls concurrency of task processings
3737
}
@@ -69,13 +69,14 @@ func (p *permit) AcquireChan(ctx context.Context, wg *sync.WaitGroup) <-chan str
6969
wg.Add(1)
7070
go func() {
7171
defer wg.Done()
72+
defer close(ch) // close channel when permit is acquired or expired
7273
if err := p.sem.Acquire(ctx, 1); err != nil {
73-
close(ch)
7474
return
7575
}
7676
select { // try to send to channel, but don't block if listener is gone
7777
case ch <- struct{}{}:
7878
default:
79+
p.sem.Release(1)
7980
}
8081
}()
8182
return ch
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Copyright (c) 2017-2021 Uber Technologies Inc.
2+
//
3+
// Permission is hereby granted, free of charge, to any person obtaining a copy
4+
// of this software and associated documentation files (the "Software"), to deal
5+
// in the Software without restriction, including without limitation the rights
6+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
// copies of the Software, and to permit persons to whom the Software is
8+
// furnished to do so, subject to the following conditions:
9+
//
10+
// The above copyright notice and this permission notice shall be included in
11+
// all copies or substantial portions of the Software.
12+
//
13+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19+
// THE SOFTWARE.
20+
21+
package worker
22+
23+
import (
24+
"context"
25+
"sync"
26+
"testing"
27+
"time"
28+
29+
"math/rand"
30+
31+
"github.com/stretchr/testify/assert"
32+
"go.uber.org/atomic"
33+
)
34+
35+
func TestPermit_Simulation(t *testing.T) {
36+
tests := []struct{
37+
name string
38+
capacity []int // update every 50ms
39+
goroutines int // each would block on acquiring 2-6 tokens for 100ms
40+
goroutinesAcquireChan int // each would block using AcquireChan for 100ms
41+
maxTestDuration time.Duration
42+
expectFailures int
43+
expectFailuresAtLeast int
44+
} {
45+
{
46+
name: "enough permit, no blocking",
47+
maxTestDuration: 200*time.Millisecond,
48+
capacity: []int{1000},
49+
goroutines: 100,
50+
goroutinesAcquireChan: 100,
51+
expectFailures: 0,
52+
},
53+
{
54+
name: "not enough permit, blocking but all acquire",
55+
maxTestDuration: 1*time.Second,
56+
capacity: []int{200},
57+
goroutines: 500,
58+
goroutinesAcquireChan: 500,
59+
expectFailures: 0,
60+
},
61+
{
62+
name: "not enough permit for some to acquire, fail some",
63+
maxTestDuration: 100*time.Millisecond,
64+
capacity: []int{100},
65+
goroutines: 500,
66+
goroutinesAcquireChan: 500,
67+
expectFailuresAtLeast: 1,
68+
},
69+
{
70+
name: "not enough permit at beginning but due to capacity change, blocking but all acquire",
71+
maxTestDuration: 100*time.Second,
72+
capacity: []int{100, 200, 300},
73+
goroutines: 500,
74+
goroutinesAcquireChan: 500,
75+
expectFailures: 0,
76+
},
77+
{
78+
name: "not enough permit for any acquire, fail all",
79+
maxTestDuration: 1*time.Second,
80+
capacity: []int{0},
81+
goroutines: 1000,
82+
expectFailures: 1000,
83+
},
84+
}
85+
86+
for _, tt := range tests {
87+
t.Run(tt.name, func(t *testing.T) {
88+
wg := &sync.WaitGroup{}
89+
permit := NewPermit(tt.capacity[0])
90+
wg.Add(1)
91+
go func() { // update quota every 50ms
92+
defer wg.Done()
93+
for i := 1; i < len(tt.capacity); i++ {
94+
time.Sleep(50*time.Millisecond)
95+
permit.SetQuota(tt.capacity[i])
96+
}
97+
}()
98+
failures := atomic.NewInt32(0)
99+
ctx, cancel := context.WithTimeout(context.Background(), tt.maxTestDuration)
100+
defer cancel()
101+
for i := 0; i < tt.goroutines; i++ {
102+
wg.Add(1)
103+
go func() {
104+
defer wg.Done()
105+
num := rand.Intn(2)+2
106+
// num := 1
107+
if err := permit.Acquire(ctx, num); err != nil {
108+
failures.Add(1)
109+
return
110+
}
111+
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
112+
permit.Release(num)
113+
}()
114+
}
115+
for i := 0; i < tt.goroutinesAcquireChan; i++ {
116+
wg.Add(1)
117+
go func() {
118+
defer wg.Done()
119+
select {
120+
case <-permit.AcquireChan(ctx, wg):
121+
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
122+
permit.Release(1)
123+
case <-ctx.Done():
124+
failures.Add(1)
125+
}
126+
}()
127+
}
128+
129+
wg.Wait()
130+
assert.Equal(t, 0, permit.Count())
131+
if tt.expectFailuresAtLeast >0 {
132+
assert.LessOrEqual(t, tt.expectFailuresAtLeast, int(failures.Load()))
133+
} else {
134+
assert.Equal(t, tt.expectFailures, int(failures.Load()))
135+
}
136+
assert.Equal(t, tt.capacity[len(tt.capacity)-1], permit.Quota())
137+
})
138+
}
139+
}

0 commit comments

Comments
 (0)