diff --git a/common/task/interface.go b/common/task/interface.go index 820d6236774..db855e96e84 100644 --- a/common/task/interface.go +++ b/common/task/interface.go @@ -40,6 +40,25 @@ type ( TrySubmit(task PriorityTask) (bool, error) } + // TaskPool manages task storage and determines scheduling order. + // Different implementations provide different scheduling algorithms. + TaskPool interface { + common.Daemon + + // Submit adds a task to the pool, blocks if pool is full + Submit(task PriorityTask) error + + // TrySubmit attempts to add a task, returns immediately if pool is full + TrySubmit(task PriorityTask) (bool, error) + + // GetNextTask retrieves the next task according to the pool's scheduling algorithm + // Returns (task, true) if a task is available, (nil, false) if no task is ready + GetNextTask() (PriorityTask, bool) + + // Len returns the number of tasks currently in the pool + Len() int + } + // SchedulerType respresents the type of the task scheduler implementation SchedulerType int diff --git a/common/task/interface_mock.go b/common/task/interface_mock.go index 3dee2913af9..598f0181e95 100644 --- a/common/task/interface_mock.go +++ b/common/task/interface_mock.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: interface.go +// Source: common/task/interface.go // // Generated by this command: // -// mockgen -package task -source interface.go -destination interface_mock.go -self_package github.com/uber/cadence/common/task +// mockgen -package task -source common/task/interface.go -destination common/task/interface_mock.go -self_package github.com/uber/cadence/common/task // // Package task is a generated GoMock package. @@ -154,6 +154,112 @@ func (mr *MockSchedulerMockRecorder) TrySubmit(task any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrySubmit", reflect.TypeOf((*MockScheduler)(nil).TrySubmit), task) } +// MockTaskPool is a mock of TaskPool interface. 
+type MockTaskPool struct { + ctrl *gomock.Controller + recorder *MockTaskPoolMockRecorder + isgomock struct{} +} + +// MockTaskPoolMockRecorder is the mock recorder for MockTaskPool. +type MockTaskPoolMockRecorder struct { + mock *MockTaskPool +} + +// NewMockTaskPool creates a new mock instance. +func NewMockTaskPool(ctrl *gomock.Controller) *MockTaskPool { + mock := &MockTaskPool{ctrl: ctrl} + mock.recorder = &MockTaskPoolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTaskPool) EXPECT() *MockTaskPoolMockRecorder { + return m.recorder +} + +// GetNextTask mocks base method. +func (m *MockTaskPool) GetNextTask() (PriorityTask, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNextTask") + ret0, _ := ret[0].(PriorityTask) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetNextTask indicates an expected call of GetNextTask. +func (mr *MockTaskPoolMockRecorder) GetNextTask() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNextTask", reflect.TypeOf((*MockTaskPool)(nil).GetNextTask)) +} + +// Len mocks base method. +func (m *MockTaskPool) Len() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Len") + ret0, _ := ret[0].(int) + return ret0 +} + +// Len indicates an expected call of Len. +func (mr *MockTaskPoolMockRecorder) Len() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockTaskPool)(nil).Len)) +} + +// Start mocks base method. +func (m *MockTaskPool) Start() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Start") +} + +// Start indicates an expected call of Start. +func (mr *MockTaskPoolMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockTaskPool)(nil).Start)) +} + +// Stop mocks base method. 
+func (m *MockTaskPool) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. +func (mr *MockTaskPoolMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockTaskPool)(nil).Stop)) +} + +// Submit mocks base method. +func (m *MockTaskPool) Submit(task PriorityTask) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Submit", task) + ret0, _ := ret[0].(error) + return ret0 +} + +// Submit indicates an expected call of Submit. +func (mr *MockTaskPoolMockRecorder) Submit(task any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Submit", reflect.TypeOf((*MockTaskPool)(nil).Submit), task) +} + +// TrySubmit mocks base method. +func (m *MockTaskPool) TrySubmit(task PriorityTask) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TrySubmit", task) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TrySubmit indicates an expected call of TrySubmit. +func (mr *MockTaskPoolMockRecorder) TrySubmit(task any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TrySubmit", reflect.TypeOf((*MockTaskPool)(nil).TrySubmit), task) +} + // MockTask is a mock of Task interface. 
type MockTask struct { ctrl *gomock.Controller diff --git a/common/task/weighted_channel_pool.go b/common/task/weighted_channel_pool.go index 243e65a4cf9..bb38f41f0d1 100644 --- a/common/task/weighted_channel_pool.go +++ b/common/task/weighted_channel_pool.go @@ -83,7 +83,7 @@ func NewWeightedRoundRobinChannelPool[K comparable, V any]( timeSource clock.TimeSource, options WeightedRoundRobinChannelPoolOptions, ) *WeightedRoundRobinChannelPool[K, V] { - return &WeightedRoundRobinChannelPool[K, V]{ + wrr := &WeightedRoundRobinChannelPool[K, V]{ bufferSize: options.BufferSize, idleChannelTTLInSeconds: options.IdleChannelTTLInSeconds, logger: logger, @@ -92,6 +92,8 @@ func NewWeightedRoundRobinChannelPool[K comparable, V any]( channelMap: make(map[K]*weightedChannel[V]), shutdownCh: make(chan struct{}), } + wrr.iwrrSchedule.Store(make([]chan V, 0)) + return wrr } func (p *WeightedRoundRobinChannelPool[K, V]) Start() { diff --git a/common/task/weighted_round_robin_task_pool.go b/common/task/weighted_round_robin_task_pool.go new file mode 100644 index 00000000000..fa4a7b94978 --- /dev/null +++ b/common/task/weighted_round_robin_task_pool.go @@ -0,0 +1,222 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package task + +import ( + "context" + "sync" + "sync/atomic" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/metrics" +) + +type weightedRoundRobinTaskPool[K comparable] struct { + sync.Mutex + status int32 + pool *WeightedRoundRobinChannelPool[K, PriorityTask] + ctx context.Context + cancel context.CancelFunc + options *WeightedRoundRobinTaskPoolOptions[K] + logger log.Logger + taskCount atomic.Int64 // O(1) task count tracking + schedule []chan PriorityTask // Current schedule + scheduleIndex int // Current position in schedule +} + +// NewWeightedRoundRobinTaskPool creates a new WRR task pool +func NewWeightedRoundRobinTaskPool[K comparable]( + logger log.Logger, + metricsClient metrics.Client, + timeSource clock.TimeSource, + options *WeightedRoundRobinTaskPoolOptions[K], +) TaskPool { + metricsScope := metricsClient.Scope(metrics.TaskSchedulerScope) + ctx, cancel := context.WithCancel(context.Background()) + + pool := &weightedRoundRobinTaskPool[K]{ + status: common.DaemonStatusInitialized, + pool: NewWeightedRoundRobinChannelPool[K, PriorityTask]( + logger, + metricsScope, + timeSource, + WeightedRoundRobinChannelPoolOptions{ + BufferSize: options.QueueSize, + IdleChannelTTLInSeconds: defaultIdleChannelTTLInSeconds, + }), + ctx: ctx, + cancel: cancel, + options: options, + logger: logger, + } + + return pool +} + +func (p *weightedRoundRobinTaskPool[K]) Start() { + if 
!atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusInitialized, common.DaemonStatusStarted) { + return + } + + p.pool.Start() + + p.logger.Info("Weighted round robin task pool started.") +} + +func (p *weightedRoundRobinTaskPool[K]) Stop() { + if !atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusStarted, common.DaemonStatusStopped) { + return + } + + p.cancel() + p.pool.Stop() + + // Drain all tasks and nack them, updating the counter + taskChs := p.pool.GetAllChannels() + for _, taskCh := range taskChs { + p.drainAndNackTasks(taskCh) + } + + p.logger.Info("Weighted round robin task pool stopped.") +} + +func (p *weightedRoundRobinTaskPool[K]) Submit(task PriorityTask) error { + if p.isStopped() { + return ErrTaskSchedulerClosed + } + + key := p.options.TaskToChannelKeyFn(task) + weight := p.options.ChannelKeyToWeightFn(key) + + // Reject tasks with weight 0 - they won't be in the schedule + if weight <= 0 { + return ErrTaskWeightZero + } + + taskCh, releaseFn := p.pool.GetOrCreateChannel(key, weight) + defer releaseFn() + + select { + case taskCh <- task: + p.taskCount.Add(1) + if p.isStopped() { + p.drainAndNackTasks(taskCh) + } + return nil + case <-p.ctx.Done(): + return ErrTaskSchedulerClosed + } +} + +func (p *weightedRoundRobinTaskPool[K]) TrySubmit(task PriorityTask) (bool, error) { + if p.isStopped() { + return false, ErrTaskSchedulerClosed + } + + key := p.options.TaskToChannelKeyFn(task) + weight := p.options.ChannelKeyToWeightFn(key) + + // Reject tasks with weight 0 - they won't be in the schedule + if weight <= 0 { + return false, ErrTaskWeightZero + } + + taskCh, releaseFn := p.pool.GetOrCreateChannel(key, weight) + defer releaseFn() + + select { + case taskCh <- task: + p.taskCount.Add(1) + if p.isStopped() { + p.drainAndNackTasks(taskCh) + } + return true, nil + case <-p.ctx.Done(): + return false, ErrTaskSchedulerClosed + default: + return false, nil + } +} + +func (p *weightedRoundRobinTaskPool[K]) GetNextTask() (PriorityTask, 
bool) { + if p.isStopped() || p.Len() == 0 { + return nil, false + } + + p.Lock() + defer p.Unlock() + + // Get a fresh schedule if we don't have one or if we've reached the end + if p.schedule == nil || p.scheduleIndex >= len(p.schedule) { + p.schedule = p.pool.GetSchedule() + p.scheduleIndex = 0 + if len(p.schedule) == 0 { + return nil, false + } + } + + for p.Len() > 0 { + select { + case task := <-p.schedule[p.scheduleIndex]: + // Found a task, advance index and return + p.scheduleIndex++ + p.taskCount.Add(-1) + return task, true + case <-p.ctx.Done(): + return nil, false + default: + // No task in this channel, try next + p.scheduleIndex++ + + // If we've reached the end, get a fresh schedule and continue + if p.scheduleIndex >= len(p.schedule) { + p.schedule = p.pool.GetSchedule() + p.scheduleIndex = 0 + if len(p.schedule) == 0 { + return nil, false + } + } + } + } + return nil, false +} + +func (p *weightedRoundRobinTaskPool[K]) Len() int { + return int(p.taskCount.Load()) +} + +func (p *weightedRoundRobinTaskPool[K]) drainAndNackTasks(taskCh <-chan PriorityTask) { + for { + select { + case task := <-taskCh: + p.taskCount.Add(-1) + task.Nack(nil) + default: + return + } + } +} + +func (p *weightedRoundRobinTaskPool[K]) isStopped() bool { + return atomic.LoadInt32(&p.status) == common.DaemonStatusStopped +} diff --git a/common/task/weighted_round_robin_task_pool_options.go b/common/task/weighted_round_robin_task_pool_options.go new file mode 100644 index 00000000000..5e475a82636 --- /dev/null +++ b/common/task/weighted_round_robin_task_pool_options.go @@ -0,0 +1,36 @@ +// Copyright (c) 2020 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package task + +import ( + "fmt" +) + +// WeightedRoundRobinTaskPoolOptions configs WRR task pool +type WeightedRoundRobinTaskPoolOptions[K comparable] struct { + QueueSize int + TaskToChannelKeyFn func(PriorityTask) K + ChannelKeyToWeightFn func(K) int +} + +func (o *WeightedRoundRobinTaskPoolOptions[K]) String() string { + return fmt.Sprintf("{QueueSize: %v}", o.QueueSize) +} diff --git a/common/task/weighted_round_robin_task_pool_test.go b/common/task/weighted_round_robin_task_pool_test.go new file mode 100644 index 00000000000..4c8d75b9b74 --- /dev/null +++ b/common/task/weighted_round_robin_task_pool_test.go @@ -0,0 +1,602 @@ +// Copyright (c) 2020 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package task + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber-go/tally" + "go.uber.org/mock/gomock" + + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/metrics" +) + +func TestWeightedRoundRobinTaskPool_Submit(t *testing.T) { + tests := []struct { + name string + queueSize int + numTasks int + wantErr bool + errContains string + }{ + { + name: "submit single task", + queueSize: 10, + numTasks: 1, + wantErr: false, + }, + { + name: "submit multiple tasks", + queueSize: 10, + numTasks: 5, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: tt.queueSize, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + ChannelKeyToWeightFn: func(key int) int { + return key + }, + }, + ) + pool.Start() + defer pool.Stop() + + for i := 0; i < tt.numTasks; i++ { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(1).AnyTimes() + task.EXPECT().SetPriority(1).AnyTimes() + task.EXPECT().Nack(gomock.Any()).AnyTimes() // Called during shutdown + + err := pool.Submit(task) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + } + }) + } +} + +func TestWeightedRoundRobinTaskPool_TrySubmit(t *testing.T) { + tests := []struct { + name string + queueSize int + numTasks int + wantSuccess []bool + }{ + { + name: "try submit with available space", + queueSize: 10, + numTasks: 
3, + wantSuccess: []bool{true, true, true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: tt.queueSize, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + ChannelKeyToWeightFn: func(key int) int { + return key + }, + }, + ) + pool.Start() + defer pool.Stop() + + for i := 0; i < tt.numTasks; i++ { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(1).AnyTimes() + task.EXPECT().SetPriority(1).AnyTimes() + task.EXPECT().Nack(gomock.Any()).AnyTimes() // Called during shutdown + + ok, err := pool.TrySubmit(task) + require.NoError(t, err) + assert.Equal(t, tt.wantSuccess[i], ok) + } + }) + } +} + +func TestWeightedRoundRobinTaskPool_GetNextTask(t *testing.T) { + tests := []struct { + name string + submitPrios []int + expectTasks bool + expectedCount int + }{ + { + name: "get tasks from non-empty pool", + submitPrios: []int{1, 2, 3}, + expectTasks: true, + expectedCount: 3, + }, + { + name: "get task from empty pool", + submitPrios: []int{}, + expectTasks: false, + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: 10, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + 
ChannelKeyToWeightFn: func(key int) int { + return key + }, + }, + ) + pool.Start() + defer pool.Stop() + + // Submit tasks + for _, prio := range tt.submitPrios { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(prio).AnyTimes() + task.EXPECT().SetPriority(prio).AnyTimes() + task.EXPECT().Nack(gomock.Any()).AnyTimes() // Called during shutdown + err := pool.Submit(task) + require.NoError(t, err) + } + + // Get tasks + count := 0 + for i := 0; i < len(tt.submitPrios)+1; i++ { + task, ok := pool.GetNextTask() + if ok { + assert.NotNil(t, task) + count++ + } else { + break + } + } + + assert.Equal(t, tt.expectedCount, count) + }) + } +} + +func TestWeightedRoundRobinTaskPool_WeightedScheduling(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: 100, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + ChannelKeyToWeightFn: func(key int) int { + return key + }, + }, + ) + pool.Start() + defer pool.Stop() + + // Submit tasks with different priorities + // Priority 3 should appear 3 times more than priority 1 + taskCounts := map[int]int{ + 1: 10, + 2: 10, + 3: 10, + } + + for prio, count := range taskCounts { + for i := 0; i < count; i++ { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(prio).AnyTimes() + task.EXPECT().SetPriority(prio).AnyTimes() + task.EXPECT().Nack(gomock.Any()).AnyTimes() // Called during shutdown or get + err := pool.Submit(task) + require.NoError(t, err) + } + } + + // Retrieve all tasks and count by priority + retrievedCounts := map[int]int{} + for { + task, ok := pool.GetNextTask() + if !ok { + break + } + prio := task.Priority() + 
retrievedCounts[prio]++ + } + + // Verify all tasks were retrieved + assert.Equal(t, taskCounts[1], retrievedCounts[1]) + assert.Equal(t, taskCounts[2], retrievedCounts[2]) + assert.Equal(t, taskCounts[3], retrievedCounts[3]) +} + +func TestWeightedRoundRobinTaskPool_Lifecycle(t *testing.T) { + tests := []struct { + name string + startPool bool + stopPool bool + submitAfterStop bool + expectError bool + }{ + { + name: "normal lifecycle", + startPool: true, + stopPool: true, + submitAfterStop: false, + expectError: false, + }, + { + name: "submit after stop", + startPool: true, + stopPool: true, + submitAfterStop: true, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: 10, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + ChannelKeyToWeightFn: func(key int) int { + return key + }, + }, + ) + + if tt.startPool { + pool.Start() + } + + if tt.stopPool { + pool.Stop() + } + + if tt.submitAfterStop { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(1).AnyTimes() + + err := pool.Submit(task) + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, ErrTaskSchedulerClosed, err) + } else { + assert.NoError(t, err) + } + } + }) + } +} + +func TestWeightedRoundRobinTaskPool_Len(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + 
timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: 100, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + ChannelKeyToWeightFn: func(key int) int { + return key + }, + }, + ) + pool.Start() + defer pool.Stop() + + // Initially empty + assert.Equal(t, 0, pool.Len()) + + // Submit some tasks + for i := 0; i < 10; i++ { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(1).AnyTimes() + task.EXPECT().Nack(gomock.Any()).AnyTimes() + err := pool.Submit(task) + require.NoError(t, err) + } + + // Should have 10 tasks + assert.Equal(t, 10, pool.Len()) + + // Get some tasks + for i := 0; i < 5; i++ { + task, ok := pool.GetNextTask() + assert.True(t, ok) + assert.NotNil(t, task) + } + + // Should have 5 tasks remaining + assert.Equal(t, 5, pool.Len()) + + // Get remaining tasks + for i := 0; i < 5; i++ { + task, ok := pool.GetNextTask() + assert.True(t, ok) + assert.NotNil(t, task) + } + + // Should be empty + assert.Equal(t, 0, pool.Len()) +} + +func TestWeightedRoundRobinTaskPool_WeightedOrderingWithMultipleKeys(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + // Setup: 3 priorities with weights [3, 2, 1] + // Priority 0 -> weight 3 + // Priority 1 -> weight 2 + // Priority 2 -> weight 1 + // Expected IWRR schedule: [0, 0, 1, 0, 1, 2] (interleaved weighted round-robin) + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: 100, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + ChannelKeyToWeightFn: func(key int) int { + weights := map[int]int{ + 0: 3, + 1: 2, + 2: 1, + } + return weights[key] + }, + }, + ) + pool.Start() + defer pool.Stop() + + // Submit 2 complete cycles worth of tasks 
(12 tasks total) + // 6 tasks for priority 0, 4 tasks for priority 1, 2 tasks for priority 2 + tasksPerPriority := map[int]int{ + 0: 6, // weight 3, appears 3 times per cycle, 2 cycles = 6 tasks + 1: 4, // weight 2, appears 2 times per cycle, 2 cycles = 4 tasks + 2: 2, // weight 1, appears 1 time per cycle, 2 cycles = 2 tasks + } + + // Create and submit tasks + for priority, count := range tasksPerPriority { + for i := 0; i < count; i++ { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(priority).AnyTimes() + task.EXPECT().Nack(gomock.Any()).AnyTimes() + err := pool.Submit(task) + require.NoError(t, err) + } + } + + // Verify total count + totalTasks := tasksPerPriority[0] + tasksPerPriority[1] + tasksPerPriority[2] + assert.Equal(t, totalTasks, pool.Len()) + + // Retrieve all tasks and track their priorities + retrievedPriorities := []int{} + for i := 0; i < totalTasks; i++ { + task, ok := pool.GetNextTask() + require.True(t, ok, "should get task %d of %d", i+1, totalTasks) + require.NotNil(t, task) + retrievedPriorities = append(retrievedPriorities, task.Priority()) + } + + // Verify the weighted round-robin pattern + // Expected pattern repeats: [0, 0, 1, 0, 1, 2] + // Cycle 1: [0, 0, 1, 0, 1, 2] - consumes 3 from p0, 2 from p1, 1 from p2 + // Cycle 2: [0, 0, 1, 0, 1, 2] - consumes 3 from p0, 2 from p1, 1 from p2 + expectedPattern := []int{ + 0, 0, 1, 0, 1, 2, // First complete cycle + 0, 0, 1, 0, 1, 2, // Second complete cycle + } + + assert.Equal(t, expectedPattern, retrievedPriorities, + "retrieved tasks should follow weighted round-robin pattern [0,0,1,0,1,2] repeated twice") + + // Verify pool is now empty + task, ok := pool.GetNextTask() + assert.False(t, ok, "pool should be empty after retrieving all tasks") + assert.Nil(t, task) + assert.Equal(t, 0, pool.Len()) +} + +func TestWeightedRoundRobinTaskPool_WeightedOrderingAcrossMultipleCalls(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + logger := 
testlogger.New(t) + metricsClient := metrics.NewClient(tally.NoopScope, metrics.Common, metrics.HistogramMigration{}) + timeSource := clock.NewRealTimeSource() + + // Setup with different weights: [5, 3, 1] + pool := NewWeightedRoundRobinTaskPool[int]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[int]{ + QueueSize: 100, + TaskToChannelKeyFn: func(task PriorityTask) int { + return task.Priority() + }, + ChannelKeyToWeightFn: func(key int) int { + weights := map[int]int{ + 0: 5, + 1: 3, + 2: 1, + } + return weights[key] + }, + }, + ) + pool.Start() + defer pool.Stop() + + // Submit enough tasks for one complete schedule cycle + // Schedule length = 5 + 3 + 1 = 9 + // IWRR schedule: [0, 0, 0, 1, 0, 1, 0, 1, 2] + tasksPerPriority := map[int]int{ + 0: 5, + 1: 3, + 2: 1, + } + + for priority, count := range tasksPerPriority { + for i := 0; i < count; i++ { + task := NewMockPriorityTask(ctrl) + task.EXPECT().Priority().Return(priority).AnyTimes() + task.EXPECT().Nack(gomock.Any()).AnyTimes() + err := pool.Submit(task) + require.NoError(t, err) + } + } + + // Retrieve tasks in small batches to verify state is maintained across calls + retrievedPriorities := []int{} + + // Batch 1: Get 3 tasks + for i := 0; i < 3; i++ { + task, ok := pool.GetNextTask() + require.True(t, ok) + retrievedPriorities = append(retrievedPriorities, task.Priority()) + } + + // Batch 2: Get 4 tasks + for i := 0; i < 4; i++ { + task, ok := pool.GetNextTask() + require.True(t, ok) + retrievedPriorities = append(retrievedPriorities, task.Priority()) + } + + // Batch 3: Get remaining 2 tasks + for i := 0; i < 2; i++ { + task, ok := pool.GetNextTask() + require.True(t, ok) + retrievedPriorities = append(retrievedPriorities, task.Priority()) + } + + // Expected IWRR pattern for weights [5, 3, 1] + // The interleaved weighted round-robin algorithm produces: + // Round 4: [0] (only weight 5 > 4) + // Round 3: [0] (only weight 5 > 3) + // Round 2: [0, 1] (weights 5,3 > 2) + 
// Round 1: [0, 1] (weights 5,3 > 1) + // Round 0: [0, 1, 2] (weights 5,3,1 > 0) + // Result: [0, 0, 0, 1, 0, 1, 0, 1, 2] + expectedPattern := []int{0, 0, 0, 1, 0, 1, 0, 1, 2} + + assert.Equal(t, expectedPattern, retrievedPriorities, + "should maintain WRR order across multiple GetNextTask calls") + + // Verify empty + assert.Equal(t, 0, pool.Len()) +} diff --git a/common/task/weighted_round_robin_task_scheduler.go b/common/task/weighted_round_robin_task_scheduler.go index 91ac630bacd..706cda48d08 100644 --- a/common/task/weighted_round_robin_task_scheduler.go +++ b/common/task/weighted_round_robin_task_scheduler.go @@ -38,7 +38,7 @@ type weightedRoundRobinTaskSchedulerImpl[K comparable] struct { sync.RWMutex status int32 - pool *WeightedRoundRobinChannelPool[K, PriorityTask] + taskPool TaskPool ctx context.Context cancel context.CancelFunc notifyCh chan struct{} @@ -58,6 +58,8 @@ const ( var ( // ErrTaskSchedulerClosed is the error returned when submitting task to a stopped scheduler ErrTaskSchedulerClosed = errors.New("task scheduler has already shutdown") + // ErrTaskWeightZero is the error returned when submitting task with weight 0 + ErrTaskWeightZero = errors.New("task weight must be greater than zero") ) // NewWeightedRoundRobinTaskScheduler creates a new WRR task scheduler @@ -70,16 +72,21 @@ func NewWeightedRoundRobinTaskScheduler[K comparable]( ) (Scheduler, error) { metricsScope := metricsClient.Scope(metrics.TaskSchedulerScope) ctx, cancel := context.WithCancel(context.Background()) + + taskPool := NewWeightedRoundRobinTaskPool[K]( + logger, + metricsClient, + timeSource, + &WeightedRoundRobinTaskPoolOptions[K]{ + QueueSize: options.QueueSize, + TaskToChannelKeyFn: options.TaskToChannelKeyFn, + ChannelKeyToWeightFn: options.ChannelKeyToWeightFn, + }, + ) + scheduler := &weightedRoundRobinTaskSchedulerImpl[K]{ - status: common.DaemonStatusInitialized, - pool: NewWeightedRoundRobinChannelPool[K, PriorityTask]( - logger, - metricsScope, - timeSource, - 
WeightedRoundRobinChannelPoolOptions{ - BufferSize: options.QueueSize, - IdleChannelTTLInSeconds: defaultIdleChannelTTLInSeconds, - }), + status: common.DaemonStatusInitialized, + taskPool: taskPool, ctx: ctx, cancel: cancel, notifyCh: make(chan struct{}, 1), @@ -97,6 +104,8 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) Start() { return } + w.taskPool.Start() + w.dispatcherWG.Add(w.options.DispatcherCount) for i := 0; i != w.options.DispatcherCount; i++ { go w.dispatcher() @@ -110,11 +119,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) Stop() { } w.cancel() - - taskChs := w.pool.GetAllChannels() - for _, taskCh := range taskChs { - drainAndNackPriorityTask(taskCh) - } + w.taskPool.Stop() if success := common.AwaitWaitGroup(&w.dispatcherWG, time.Minute); !success { w.logger.Warn("Weighted round robin task scheduler timedout on shutdown.") @@ -128,52 +133,22 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) Submit(task PriorityTask) error sw := w.metricsScope.StartTimer(metrics.PriorityTaskSubmitLatency) defer sw.Stop() - if w.isStopped() { - return ErrTaskSchedulerClosed - } - - key := w.options.TaskToChannelKeyFn(task) - weight := w.options.ChannelKeyToWeightFn(key) - taskCh, releaseFn := w.pool.GetOrCreateChannel(key, weight) - defer releaseFn() - select { - case taskCh <- task: + err := w.taskPool.Submit(task) + if err == nil { w.notifyDispatcher() - if w.isStopped() { - drainAndNackPriorityTask(taskCh) - } - return nil - case <-w.ctx.Done(): - return ErrTaskSchedulerClosed } + return err } func (w *weightedRoundRobinTaskSchedulerImpl[K]) TrySubmit( task PriorityTask, ) (bool, error) { - if w.isStopped() { - return false, ErrTaskSchedulerClosed - } - - key := w.options.TaskToChannelKeyFn(task) - weight := w.options.ChannelKeyToWeightFn(key) - taskCh, releaseFn := w.pool.GetOrCreateChannel(key, weight) - defer releaseFn() - - select { - case taskCh <- task: + submitted, err := w.taskPool.TrySubmit(task) + if submitted { 
w.metricsScope.IncCounter(metrics.PriorityTaskSubmitRequest) - if w.isStopped() { - drainAndNackPriorityTask(taskCh) - } else { - w.notifyDispatcher() - } - return true, nil - case <-w.ctx.Done(): - return false, ErrTaskSchedulerClosed - default: - return false, nil + w.notifyDispatcher() } + return submitted, err } func (w *weightedRoundRobinTaskSchedulerImpl[K]) dispatcher() { @@ -190,22 +165,22 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) dispatcher() { } func (w *weightedRoundRobinTaskSchedulerImpl[K]) dispatchTasks() { - hasTask := true - for hasTask { - hasTask = false - schedule := w.pool.GetSchedule() - for _, taskCh := range schedule { - select { - case task := <-taskCh: - hasTask = true - if err := w.processor.Submit(task); err != nil { - w.logger.Error("fail to submit task to processor", tag.Error(err)) - task.Nack(err) - } - case <-w.ctx.Done(): - return - default: - } + for w.taskPool.Len() > 0 { + select { + case <-w.ctx.Done(): + return + default: + } + + task, ok := w.taskPool.GetNextTask() + if !ok { + // No more tasks available + return + } + + if err := w.processor.Submit(task); err != nil { + w.logger.Error("fail to submit task to processor", tag.Error(err)) + task.Nack(err) } } } diff --git a/common/task/weighted_round_robin_task_scheduler_test.go b/common/task/weighted_round_robin_task_scheduler_test.go index 3e27de4fc72..a6d7e952173 100644 --- a/common/task/weighted_round_robin_task_scheduler_test.go +++ b/common/task/weighted_round_robin_task_scheduler_test.go @@ -117,12 +117,14 @@ func (s *weightedRoundRobinTaskSchedulerSuite) TestSubmit_Success() { err := s.scheduler.Submit(mockTask) s.NoError(err) + // Cast taskPool to concrete type to access internal pool for testing + pool := s.scheduler.taskPool.(*weightedRoundRobinTaskPool[int]).pool weight := s.scheduler.options.ChannelKeyToWeightFn(taskPriority) - taskCh, releaseFn := s.scheduler.pool.GetOrCreateChannel(taskPriority, weight) + taskCh, releaseFn := 
pool.GetOrCreateChannel(taskPriority, weight) defer releaseFn() task := <-taskCh s.Equal(mockTask, task) - taskChs := s.scheduler.pool.GetAllChannels() + taskChs := pool.GetAllChannels() for _, taskCh := range taskChs { s.Empty(taskCh) } @@ -188,7 +190,9 @@ func (s *weightedRoundRobinTaskSchedulerSuite) TestDispatcher_SubmitWithNoError( if expectedRemainingTasksNum < 0 { expectedRemainingTasksNum = 0 } - taskCh, releaseFn := s.scheduler.pool.GetOrCreateChannel(priority, weight) + // Cast taskPool to concrete type to access internal pool for testing + pool := s.scheduler.taskPool.(*weightedRoundRobinTaskPool[int]).pool + taskCh, releaseFn := pool.GetOrCreateChannel(priority, weight) s.Equal(expectedRemainingTasksNum, len(taskCh)) releaseFn() } @@ -218,6 +222,9 @@ func (s *weightedRoundRobinTaskSchedulerSuite) TestDispatcher_SubmitWithNoError( close(doneCh) }() + // Manually trigger notification since tasks were submitted before dispatcher started + s.scheduler.notifyDispatcher() + taskWG.Wait() s.scheduler.cancel()