Skip to content

Commit 6f63617

Browse files
committed
refactor: Separate TaskPool abstraction from Scheduler
1 parent 63d68a2 commit 6f63617

8 files changed

+1137
-74
lines changed

common/task/interface.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ type (
4040
TrySubmit(task PriorityTask) (bool, error)
4141
}
4242

43+
// TaskPool manages task storage and determines scheduling order.
44+
// Different implementations provide different scheduling algorithms.
45+
TaskPool interface {
46+
common.Daemon
47+
48+
// Submit adds a task to the pool, blocks if pool is full
49+
Submit(task PriorityTask) error
50+
51+
// TrySubmit attempts to add a task, returns immediately if pool is full
52+
TrySubmit(task PriorityTask) (bool, error)
53+
54+
// GetNextTask retrieves the next task according to the pool's scheduling algorithm
55+
// Returns (task, true) if a task is available, (nil, false) if no task is ready
56+
GetNextTask() (PriorityTask, bool)
57+
58+
// Len returns the number of tasks currently in the pool
59+
Len() int
60+
}
61+
4362
// SchedulerType respresents the type of the task scheduler implementation
4463
SchedulerType int
4564

common/task/interface_mock.go

Lines changed: 108 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

common/task/weighted_channel_pool.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func NewWeightedRoundRobinChannelPool[K comparable, V any](
8383
timeSource clock.TimeSource,
8484
options WeightedRoundRobinChannelPoolOptions,
8585
) *WeightedRoundRobinChannelPool[K, V] {
86-
return &WeightedRoundRobinChannelPool[K, V]{
86+
wrr := &WeightedRoundRobinChannelPool[K, V]{
8787
bufferSize: options.BufferSize,
8888
idleChannelTTLInSeconds: options.IdleChannelTTLInSeconds,
8989
logger: logger,
@@ -92,6 +92,8 @@ func NewWeightedRoundRobinChannelPool[K comparable, V any](
9292
channelMap: make(map[K]*weightedChannel[V]),
9393
shutdownCh: make(chan struct{}),
9494
}
95+
wrr.iwrrSchedule.Store(make([]chan V, 0))
96+
return wrr
9597
}
9698

9799
func (p *WeightedRoundRobinChannelPool[K, V]) Start() {
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
// Copyright (c) 2020 Uber Technologies, Inc.
2+
//
3+
// Permission is hereby granted, free of charge, to any person obtaining a copy
4+
// of this software and associated documentation files (the "Software"), to deal
5+
// in the Software without restriction, including without limitation the rights
6+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
// copies of the Software, and to permit persons to whom the Software is
8+
// furnished to do so, subject to the following conditions:
9+
//
10+
// The above copyright notice and this permission notice shall be included in
11+
// all copies or substantial portions of the Software.
12+
//
13+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19+
// THE SOFTWARE.
20+
21+
package task
22+
23+
import (
24+
"context"
25+
"sync"
26+
"sync/atomic"
27+
28+
"github.com/uber/cadence/common"
29+
"github.com/uber/cadence/common/clock"
30+
"github.com/uber/cadence/common/log"
31+
"github.com/uber/cadence/common/metrics"
32+
)
33+
34+
type weightedRoundRobinTaskPool[K comparable] struct {
35+
sync.Mutex
36+
status int32
37+
pool *WeightedRoundRobinChannelPool[K, PriorityTask]
38+
ctx context.Context
39+
cancel context.CancelFunc
40+
options *WeightedRoundRobinTaskPoolOptions[K]
41+
logger log.Logger
42+
taskCount atomic.Int64 // O(1) task count tracking
43+
schedule []chan PriorityTask // Current schedule
44+
scheduleIndex int // Current position in schedule
45+
}
46+
47+
// NewWeightedRoundRobinTaskPool creates a new WRR task pool
48+
func NewWeightedRoundRobinTaskPool[K comparable](
49+
logger log.Logger,
50+
metricsClient metrics.Client,
51+
timeSource clock.TimeSource,
52+
options *WeightedRoundRobinTaskPoolOptions[K],
53+
) TaskPool {
54+
metricsScope := metricsClient.Scope(metrics.TaskSchedulerScope)
55+
ctx, cancel := context.WithCancel(context.Background())
56+
57+
pool := &weightedRoundRobinTaskPool[K]{
58+
status: common.DaemonStatusInitialized,
59+
pool: NewWeightedRoundRobinChannelPool[K, PriorityTask](
60+
logger,
61+
metricsScope,
62+
timeSource,
63+
WeightedRoundRobinChannelPoolOptions{
64+
BufferSize: options.QueueSize,
65+
IdleChannelTTLInSeconds: defaultIdleChannelTTLInSeconds,
66+
}),
67+
ctx: ctx,
68+
cancel: cancel,
69+
options: options,
70+
logger: logger,
71+
}
72+
73+
return pool
74+
}
75+
76+
func (p *weightedRoundRobinTaskPool[K]) Start() {
77+
if !atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusInitialized, common.DaemonStatusStarted) {
78+
return
79+
}
80+
81+
p.pool.Start()
82+
83+
p.Lock()
84+
p.schedule = nil
85+
p.scheduleIndex = 0
86+
p.Unlock()
87+
88+
p.logger.Info("Weighted round robin task pool started.")
89+
}
90+
91+
func (p *weightedRoundRobinTaskPool[K]) Stop() {
92+
if !atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusStarted, common.DaemonStatusStopped) {
93+
return
94+
}
95+
96+
p.cancel()
97+
p.pool.Stop()
98+
99+
// Drain all tasks and nack them, updating the counter
100+
taskChs := p.pool.GetAllChannels()
101+
for _, taskCh := range taskChs {
102+
p.drainAndNackTasks(taskCh)
103+
}
104+
105+
p.logger.Info("Weighted round robin task pool stopped.")
106+
}
107+
108+
func (p *weightedRoundRobinTaskPool[K]) Submit(task PriorityTask) error {
109+
if p.isStopped() {
110+
return ErrTaskSchedulerClosed
111+
}
112+
113+
key := p.options.TaskToChannelKeyFn(task)
114+
weight := p.options.ChannelKeyToWeightFn(key)
115+
taskCh, releaseFn := p.pool.GetOrCreateChannel(key, weight)
116+
defer releaseFn()
117+
118+
select {
119+
case taskCh <- task:
120+
p.taskCount.Add(1)
121+
if p.isStopped() {
122+
p.drainAndNackTasks(taskCh)
123+
}
124+
return nil
125+
case <-p.ctx.Done():
126+
return ErrTaskSchedulerClosed
127+
}
128+
}
129+
130+
func (p *weightedRoundRobinTaskPool[K]) TrySubmit(task PriorityTask) (bool, error) {
131+
if p.isStopped() {
132+
return false, ErrTaskSchedulerClosed
133+
}
134+
135+
key := p.options.TaskToChannelKeyFn(task)
136+
weight := p.options.ChannelKeyToWeightFn(key)
137+
taskCh, releaseFn := p.pool.GetOrCreateChannel(key, weight)
138+
defer releaseFn()
139+
140+
select {
141+
case taskCh <- task:
142+
p.taskCount.Add(1)
143+
if p.isStopped() {
144+
p.drainAndNackTasks(taskCh)
145+
}
146+
return true, nil
147+
case <-p.ctx.Done():
148+
return false, ErrTaskSchedulerClosed
149+
default:
150+
return false, nil
151+
}
152+
}
153+
154+
func (p *weightedRoundRobinTaskPool[K]) GetNextTask() (PriorityTask, bool) {
155+
if p.isStopped() {
156+
return nil, false
157+
}
158+
159+
p.Lock()
160+
defer p.Unlock()
161+
162+
// Get a fresh schedule if we don't have one or if we've reached the end
163+
if p.schedule == nil || p.scheduleIndex >= len(p.schedule) {
164+
p.schedule = p.pool.GetSchedule()
165+
p.scheduleIndex = 0
166+
167+
if len(p.schedule) == 0 {
168+
return nil, false
169+
}
170+
}
171+
172+
// Try to get a task starting from the current index
173+
startIndex := p.scheduleIndex
174+
for {
175+
select {
176+
case task := <-p.schedule[p.scheduleIndex]:
177+
// Found a task, advance index and return
178+
p.scheduleIndex++
179+
p.taskCount.Add(-1)
180+
return task, true
181+
case <-p.ctx.Done():
182+
return nil, false
183+
default:
184+
// No task in this channel, try next
185+
p.scheduleIndex++
186+
187+
// If we've reached the end, get a fresh schedule and continue
188+
if p.scheduleIndex >= len(p.schedule) {
189+
p.schedule = p.pool.GetSchedule()
190+
p.scheduleIndex = 0
191+
192+
if len(p.schedule) == 0 {
193+
return nil, false
194+
}
195+
}
196+
197+
// If we've made a full loop without finding a task, return false
198+
if p.scheduleIndex == startIndex {
199+
return nil, false
200+
}
201+
}
202+
}
203+
}
204+
205+
func (p *weightedRoundRobinTaskPool[K]) Len() int {
206+
return int(p.taskCount.Load())
207+
}
208+
209+
func (p *weightedRoundRobinTaskPool[K]) drainAndNackTasks(taskCh <-chan PriorityTask) {
210+
for {
211+
select {
212+
case task := <-taskCh:
213+
p.taskCount.Add(-1)
214+
task.Nack(nil)
215+
default:
216+
return
217+
}
218+
}
219+
}
220+
221+
func (p *weightedRoundRobinTaskPool[K]) isStopped() bool {
222+
return atomic.LoadInt32(&p.status) == common.DaemonStatusStopped
223+
}

0 commit comments

Comments
 (0)