Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions stage/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ func (s *Stage) Run(ctx context.Context) int {
// This initial size is just a good start, might not be enough.
results := make([]*QueryResult, 0, len(s.Queries)+len(s.QueryFiles))
s.States.resultChan = make(chan *QueryResult, 16)
// This dummyQueryResult is used to push into resultChan once we know all the query results are done recording.
var dummyQueryResult = &QueryResult{
StageId: "__EOF__SENTINEL__",
}
timeToExit := make(chan os.Signal, 1)
signal.Notify(timeToExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
// Each goroutine we spawn will increment this wait group (count-down latch). We may start a goroutine for running
Expand All @@ -150,13 +154,10 @@ func (s *Stage) Run(ctx context.Context) int {

go func() {
s.States.wgExitMainStage.Wait()
// wgExitMainStage goes down to 0 after all the goroutines finish. Then we exit the driver by
// closing the timeToExit channel, which will trigger the graceful shutdown process -
// (flushing the log file, writing the final time log summary, etc.).

// When SIGKILL and SIGINT are captured, we trigger this process by canceling the context, which will cause
// "context cancelled" errors in goroutines to let them exit.
close(timeToExit)
// wgExitMainStage goes down to 0 after all the goroutines finish. Then we will send dummyQueryResult
// to resultChan to symbolize all query results are done recording.
s.States.resultChan <- dummyQueryResult;
}()

ctx, s.States.AbortAll = context.WithCancelCause(ctx)
Expand All @@ -175,21 +176,27 @@ func (s *Stage) Run(ctx context.Context) int {
for {
select {
case result := <-s.States.resultChan:
// result == dummyQueryResult: all results received, finalize and exit
if result == dummyQueryResult {
s.States.RunFinishTime = time.Now()
for _, recorder := range s.States.runRecorders {
recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
}
// Explicitly close(timeToExit) here to trigger the the graceful shutdown process -
// (flushing the log file, writing the final time log summary, etc.).
close(timeToExit)
return int(s.States.exitCode.Load())
}
results = append(results, result)
for _, recorder := range s.States.runRecorders {
recorder.RecordQuery(utils.GetCtxWithTimeout(time.Second*5), s, result)
}

// When SIGKILL and SIGINT are captured, we trigger this process by canceling the context, which will cause
// "context cancelled" errors in goroutines to let them exit.
case sig := <-timeToExit:
if sig != nil {
// Cancel the context and wait for the goroutines to exit.
s.States.AbortAll(fmt.Errorf(sig.String()))
continue
}
s.States.RunFinishTime = time.Now()
for _, recorder := range s.States.runRecorders {
recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
}
return int(s.States.exitCode.Load())
// Cancel the context.
s.States.AbortAll(fmt.Errorf("%s", sig.String()))
}
}
}
Expand Down