Skip to content

Commit 1799798

Browse files
committed
refactor: Separate TaskPool abstraction from Scheduler
1 parent 63d68a2 commit 1799798

8 files changed

+1132
-74
lines changed

common/task/interface.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ type (
4040
TrySubmit(task PriorityTask) (bool, error)
4141
}
4242

43+
// TaskPool manages task storage and determines scheduling order.
44+
// Different implementations provide different scheduling algorithms.
45+
TaskPool interface {
46+
common.Daemon
47+
48+
// Submit adds a task to the pool, blocks if pool is full
49+
Submit(task PriorityTask) error
50+
51+
// TrySubmit attempts to add a task, returns immediately if pool is full
52+
TrySubmit(task PriorityTask) (bool, error)
53+
54+
// GetNextTask retrieves the next task according to the pool's scheduling algorithm
55+
// Returns (task, true) if a task is available, (nil, false) if no task is ready
56+
GetNextTask() (PriorityTask, bool)
57+
58+
// Len returns the number of tasks currently in the pool
59+
Len() int
60+
}
61+
4362
// SchedulerType respresents the type of the task scheduler implementation
4463
SchedulerType int
4564

common/task/interface_mock.go

Lines changed: 108 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

common/task/weighted_channel_pool.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func NewWeightedRoundRobinChannelPool[K comparable, V any](
8383
timeSource clock.TimeSource,
8484
options WeightedRoundRobinChannelPoolOptions,
8585
) *WeightedRoundRobinChannelPool[K, V] {
86-
return &WeightedRoundRobinChannelPool[K, V]{
86+
wrr := &WeightedRoundRobinChannelPool[K, V]{
8787
bufferSize: options.BufferSize,
8888
idleChannelTTLInSeconds: options.IdleChannelTTLInSeconds,
8989
logger: logger,
@@ -92,6 +92,8 @@ func NewWeightedRoundRobinChannelPool[K comparable, V any](
9292
channelMap: make(map[K]*weightedChannel[V]),
9393
shutdownCh: make(chan struct{}),
9494
}
95+
wrr.iwrrSchedule.Store(make([]chan V, 0))
96+
return wrr
9597
}
9698

9799
func (p *WeightedRoundRobinChannelPool[K, V]) Start() {
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// Copyright (c) 2020 Uber Technologies, Inc.
2+
//
3+
// Permission is hereby granted, free of charge, to any person obtaining a copy
4+
// of this software and associated documentation files (the "Software"), to deal
5+
// in the Software without restriction, including without limitation the rights
6+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
// copies of the Software, and to permit persons to whom the Software is
8+
// furnished to do so, subject to the following conditions:
9+
//
10+
// The above copyright notice and this permission notice shall be included in
11+
// all copies or substantial portions of the Software.
12+
//
13+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19+
// THE SOFTWARE.
20+
21+
package task
22+
23+
import (
24+
"context"
25+
"sync"
26+
"sync/atomic"
27+
28+
"github.com/uber/cadence/common"
29+
"github.com/uber/cadence/common/clock"
30+
"github.com/uber/cadence/common/log"
31+
"github.com/uber/cadence/common/metrics"
32+
)
33+
34+
type weightedRoundRobinTaskPool[K comparable] struct {
35+
sync.Mutex
36+
status int32
37+
pool *WeightedRoundRobinChannelPool[K, PriorityTask]
38+
ctx context.Context
39+
cancel context.CancelFunc
40+
options *WeightedRoundRobinTaskPoolOptions[K]
41+
logger log.Logger
42+
taskCount atomic.Int64 // O(1) task count tracking
43+
schedule []chan PriorityTask // Current schedule
44+
scheduleIndex int // Current position in schedule
45+
}
46+
47+
// NewWeightedRoundRobinTaskPool creates a new WRR task pool
48+
func NewWeightedRoundRobinTaskPool[K comparable](
49+
logger log.Logger,
50+
metricsClient metrics.Client,
51+
timeSource clock.TimeSource,
52+
options *WeightedRoundRobinTaskPoolOptions[K],
53+
) TaskPool {
54+
metricsScope := metricsClient.Scope(metrics.TaskSchedulerScope)
55+
ctx, cancel := context.WithCancel(context.Background())
56+
57+
pool := &weightedRoundRobinTaskPool[K]{
58+
status: common.DaemonStatusInitialized,
59+
pool: NewWeightedRoundRobinChannelPool[K, PriorityTask](
60+
logger,
61+
metricsScope,
62+
timeSource,
63+
WeightedRoundRobinChannelPoolOptions{
64+
BufferSize: options.QueueSize,
65+
IdleChannelTTLInSeconds: defaultIdleChannelTTLInSeconds,
66+
}),
67+
ctx: ctx,
68+
cancel: cancel,
69+
options: options,
70+
logger: logger,
71+
}
72+
73+
return pool
74+
}
75+
76+
func (p *weightedRoundRobinTaskPool[K]) Start() {
77+
if !atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusInitialized, common.DaemonStatusStarted) {
78+
return
79+
}
80+
81+
p.pool.Start()
82+
83+
p.logger.Info("Weighted round robin task pool started.")
84+
}
85+
86+
func (p *weightedRoundRobinTaskPool[K]) Stop() {
87+
if !atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusStarted, common.DaemonStatusStopped) {
88+
return
89+
}
90+
91+
p.cancel()
92+
p.pool.Stop()
93+
94+
// Drain all tasks and nack them, updating the counter
95+
taskChs := p.pool.GetAllChannels()
96+
for _, taskCh := range taskChs {
97+
p.drainAndNackTasks(taskCh)
98+
}
99+
100+
p.logger.Info("Weighted round robin task pool stopped.")
101+
}
102+
103+
func (p *weightedRoundRobinTaskPool[K]) Submit(task PriorityTask) error {
104+
if p.isStopped() {
105+
return ErrTaskSchedulerClosed
106+
}
107+
108+
key := p.options.TaskToChannelKeyFn(task)
109+
weight := p.options.ChannelKeyToWeightFn(key)
110+
taskCh, releaseFn := p.pool.GetOrCreateChannel(key, weight)
111+
defer releaseFn()
112+
113+
select {
114+
case taskCh <- task:
115+
p.taskCount.Add(1)
116+
if p.isStopped() {
117+
p.drainAndNackTasks(taskCh)
118+
}
119+
return nil
120+
case <-p.ctx.Done():
121+
return ErrTaskSchedulerClosed
122+
}
123+
}
124+
125+
func (p *weightedRoundRobinTaskPool[K]) TrySubmit(task PriorityTask) (bool, error) {
126+
if p.isStopped() {
127+
return false, ErrTaskSchedulerClosed
128+
}
129+
130+
key := p.options.TaskToChannelKeyFn(task)
131+
weight := p.options.ChannelKeyToWeightFn(key)
132+
taskCh, releaseFn := p.pool.GetOrCreateChannel(key, weight)
133+
defer releaseFn()
134+
135+
select {
136+
case taskCh <- task:
137+
p.taskCount.Add(1)
138+
if p.isStopped() {
139+
p.drainAndNackTasks(taskCh)
140+
}
141+
return true, nil
142+
case <-p.ctx.Done():
143+
return false, ErrTaskSchedulerClosed
144+
default:
145+
return false, nil
146+
}
147+
}
148+
149+
func (p *weightedRoundRobinTaskPool[K]) GetNextTask() (PriorityTask, bool) {
150+
if p.isStopped() {
151+
return nil, false
152+
}
153+
154+
p.Lock()
155+
defer p.Unlock()
156+
157+
// Get a fresh schedule if we don't have one or if we've reached the end
158+
if p.schedule == nil || p.scheduleIndex >= len(p.schedule) {
159+
p.schedule = p.pool.GetSchedule()
160+
p.scheduleIndex = 0
161+
162+
if len(p.schedule) == 0 {
163+
return nil, false
164+
}
165+
}
166+
167+
// Try to get a task starting from the current index
168+
startIndex := p.scheduleIndex
169+
for {
170+
select {
171+
case task := <-p.schedule[p.scheduleIndex]:
172+
// Found a task, advance index and return
173+
p.scheduleIndex++
174+
p.taskCount.Add(-1)
175+
return task, true
176+
case <-p.ctx.Done():
177+
return nil, false
178+
default:
179+
// No task in this channel, try next
180+
p.scheduleIndex++
181+
182+
// If we've reached the end, get a fresh schedule and continue
183+
if p.scheduleIndex >= len(p.schedule) {
184+
p.schedule = p.pool.GetSchedule()
185+
p.scheduleIndex = 0
186+
187+
if len(p.schedule) == 0 {
188+
return nil, false
189+
}
190+
}
191+
192+
// If we've made a full loop without finding a task, return false
193+
if p.scheduleIndex == startIndex {
194+
return nil, false
195+
}
196+
}
197+
}
198+
}
199+
200+
func (p *weightedRoundRobinTaskPool[K]) Len() int {
201+
return int(p.taskCount.Load())
202+
}
203+
204+
func (p *weightedRoundRobinTaskPool[K]) drainAndNackTasks(taskCh <-chan PriorityTask) {
205+
for {
206+
select {
207+
case task := <-taskCh:
208+
p.taskCount.Add(-1)
209+
task.Nack(nil)
210+
default:
211+
return
212+
}
213+
}
214+
}
215+
216+
func (p *weightedRoundRobinTaskPool[K]) isStopped() bool {
217+
return atomic.LoadInt32(&p.status) == common.DaemonStatusStopped
218+
}

0 commit comments

Comments
 (0)