Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions common/task/fifo_task_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package task

import (
"context"
"sync"
"sync/atomic"
"time"
Expand All @@ -38,7 +39,8 @@ type fifoTaskSchedulerImpl struct {
options *FIFOTaskSchedulerOptions
dispatcherWG sync.WaitGroup
taskCh chan PriorityTask
shutdownCh chan struct{}
ctx context.Context
cancel context.CancelFunc

processor Processor
}
Expand All @@ -52,13 +54,15 @@ func NewFIFOTaskScheduler(
metricsClient metrics.Client,
options *FIFOTaskSchedulerOptions,
) Scheduler {
ctx, cancel := context.WithCancel(context.Background())
return &fifoTaskSchedulerImpl{
status: common.DaemonStatusInitialized,
logger: logger,
metricsScope: metricsClient.Scope(metrics.TaskSchedulerScope),
options: options,
taskCh: make(chan PriorityTask, options.QueueSize),
shutdownCh: make(chan struct{}),
ctx: ctx,
cancel: cancel,
processor: NewParallelTaskProcessor(
logger,
metricsClient,
Expand Down Expand Up @@ -91,7 +95,7 @@ func (f *fifoTaskSchedulerImpl) Stop() {
return
}

close(f.shutdownCh)
f.cancel()

f.processor.Stop()

Expand Down Expand Up @@ -119,7 +123,7 @@ func (f *fifoTaskSchedulerImpl) Submit(task PriorityTask) error {
f.drainAndNackTasks()
}
return nil
case <-f.shutdownCh:
case <-f.ctx.Done():
return ErrTaskSchedulerClosed
}
}
Expand All @@ -136,7 +140,7 @@ func (f *fifoTaskSchedulerImpl) TrySubmit(task PriorityTask) (bool, error) {
f.drainAndNackTasks()
}
return true, nil
case <-f.shutdownCh:
case <-f.ctx.Done():
return false, ErrTaskSchedulerClosed
default:
return false, nil
Expand All @@ -153,7 +157,7 @@ func (f *fifoTaskSchedulerImpl) dispatcher() {
f.logger.Error("failed to submit task to processor", tag.Error(err))
task.Nack(err)
}
case <-f.shutdownCh:
case <-f.ctx.Done():
return
}
}
Expand Down
18 changes: 11 additions & 7 deletions common/task/weighted_round_robin_task_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package task

import (
"context"
"errors"
"sync"
"sync/atomic"
Expand All @@ -38,7 +39,8 @@ type weightedRoundRobinTaskSchedulerImpl[K comparable] struct {

status int32
pool *WeightedRoundRobinChannelPool[K, PriorityTask]
shutdownCh chan struct{}
ctx context.Context
cancel context.CancelFunc
notifyCh chan struct{}
dispatcherWG sync.WaitGroup
logger log.Logger
Expand Down Expand Up @@ -67,6 +69,7 @@ func NewWeightedRoundRobinTaskScheduler[K comparable](
options *WeightedRoundRobinTaskSchedulerOptions[K],
) (Scheduler, error) {
metricsScope := metricsClient.Scope(metrics.TaskSchedulerScope)
ctx, cancel := context.WithCancel(context.Background())
scheduler := &weightedRoundRobinTaskSchedulerImpl[K]{
status: common.DaemonStatusInitialized,
pool: NewWeightedRoundRobinChannelPool[K, PriorityTask](
Expand All @@ -77,7 +80,8 @@ func NewWeightedRoundRobinTaskScheduler[K comparable](
BufferSize: options.QueueSize,
IdleChannelTTLInSeconds: defaultIdleChannelTTLInSeconds,
}),
shutdownCh: make(chan struct{}),
ctx: ctx,
cancel: cancel,
notifyCh: make(chan struct{}, 1),
logger: logger,
metricsScope: metricsScope,
Expand Down Expand Up @@ -105,7 +109,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) Stop() {
return
}

close(w.shutdownCh)
w.cancel()

taskChs := w.pool.GetAllChannels()
for _, taskCh := range taskChs {
Expand Down Expand Up @@ -139,7 +143,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) Submit(task PriorityTask) error
drainAndNackPriorityTask(taskCh)
}
return nil
case <-w.shutdownCh:
case <-w.ctx.Done():
return ErrTaskSchedulerClosed
}
}
Expand All @@ -165,7 +169,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) TrySubmit(
w.notifyDispatcher()
}
return true, nil
case <-w.shutdownCh:
case <-w.ctx.Done():
return false, ErrTaskSchedulerClosed
default:
return false, nil
Expand All @@ -179,7 +183,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) dispatcher() {
select {
case <-w.notifyCh:
w.dispatchTasks()
case <-w.shutdownCh:
case <-w.ctx.Done():
return
}
}
Expand All @@ -198,7 +202,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl[K]) dispatchTasks() {
w.logger.Error("fail to submit task to processor", tag.Error(err))
task.Nack(err)
}
case <-w.shutdownCh:
case <-w.ctx.Done():
return
default:
}
Expand Down
8 changes: 4 additions & 4 deletions common/task/weighted_round_robin_task_scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func (s *weightedRoundRobinTaskSchedulerSuite) TestDispatcher_SubmitWithNoError(
}()

taskWG.Wait()
close(s.scheduler.shutdownCh)
s.scheduler.cancel()

<-doneCh
}
Expand Down Expand Up @@ -248,7 +248,7 @@ func (s *weightedRoundRobinTaskSchedulerSuite) TestDispatcher_FailToSubmit() {
}()

taskWG.Wait()
close(s.scheduler.shutdownCh)
s.scheduler.cancel()

<-doneCh
}
Expand Down Expand Up @@ -372,9 +372,9 @@ func testSchedulerContract(
s.True(common.AwaitWaitGroup(&taskWG, 10*time.Second))
switch schedulerImpl := scheduler.(type) {
case *fifoTaskSchedulerImpl:
<-schedulerImpl.shutdownCh
<-schedulerImpl.ctx.Done()
case *weightedRoundRobinTaskSchedulerImpl[int]:
<-schedulerImpl.shutdownCh
<-schedulerImpl.ctx.Done()
default:
s.Fail("unknown task scheduler type")
}
Expand Down