diff --git a/stage/stage.go b/stage/stage.go index f0ec519..3df3c8d 100644 --- a/stage/stage.go +++ b/stage/stage.go @@ -139,6 +139,10 @@ func (s *Stage) Run(ctx context.Context) int { // This initial size is just a good start, might not be enough. results := make([]*QueryResult, 0, len(s.Queries)+len(s.QueryFiles)) s.States.resultChan = make(chan *QueryResult, 16) + // This dummyQueryResult is used to push into resultChan once we know all the query results are done recording. + var dummyQueryResult = &QueryResult{ + StageId: "__EOF__SENTINEL__", + } timeToExit := make(chan os.Signal, 1) signal.Notify(timeToExit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) // Each goroutine we spawn will increment this wait group (count-down latch). We may start a goroutine for running @@ -150,13 +154,10 @@ func (s *Stage) Run(ctx context.Context) int { go func() { s.States.wgExitMainStage.Wait() - // wgExitMainStage goes down to 0 after all the goroutines finish. Then we exit the driver by - // closing the timeToExit channel, which will trigger the graceful shutdown process - - // (flushing the log file, writing the final time log summary, etc.). - // When SIGKILL and SIGINT are captured, we trigger this process by canceling the context, which will cause - // "context cancelled" errors in goroutines to let them exit. - close(timeToExit) + // wgExitMainStage goes down to 0 after all the goroutines finish. Then we will send dummyQueryResult + // to resultChan to symbolize all query results are done recording. + s.States.resultChan <- dummyQueryResult; }() ctx, s.States.AbortAll = context.WithCancelCause(ctx) @@ -175,21 +176,27 @@ func (s *Stage) Run(ctx context.Context) int { for { select { case result := <-s.States.resultChan: + // result == dummyQueryResult: all results received, finalize and exit + if result == dummyQueryResult { + s.States.RunFinishTime = time.Now() + for _, recorder := range s.States.runRecorders { + recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results) + } + // Explicitly close(timeToExit) here to trigger the the graceful shutdown process - + // (flushing the log file, writing the final time log summary, etc.). + close(timeToExit) + return int(s.States.exitCode.Load()) + } results = append(results, result) for _, recorder := range s.States.runRecorders { recorder.RecordQuery(utils.GetCtxWithTimeout(time.Second*5), s, result) } + + // When SIGKILL and SIGINT are captured, we trigger this process by canceling the context, which will cause + // "context cancelled" errors in goroutines to let them exit. case sig := <-timeToExit: - if sig != nil { - // Cancel the context and wait for the goroutines to exit. - s.States.AbortAll(fmt.Errorf(sig.String())) - continue - } - s.States.RunFinishTime = time.Now() - for _, recorder := range s.States.runRecorders { - recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results) - } - return int(s.States.exitCode.Load()) + // Cancel the context. + s.States.AbortAll(fmt.Errorf("%s", sig.String())) } } }