diff --git a/internal/internal_coroutines_test.go b/internal/internal_coroutines_test.go
index bee4c0a86..809904d1b 100644
--- a/internal/internal_coroutines_test.go
+++ b/internal/internal_coroutines_test.go
@@ -31,6 +31,11 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/goleak"
+	"go.uber.org/zap"
+	"go.uber.org/zap/zapcore"
+	"go.uber.org/zap/zaptest"
+	"go.uber.org/zap/zaptest/observer"
 )
 
 func createRootTestContext(t *testing.T) (ctx Context) {
@@ -38,6 +43,16 @@ func createRootTestContext(t *testing.T) (ctx Context) {
 	interceptors, envInterceptor := newWorkflowInterceptors(env.impl, env.impl.workflowInterceptors)
 	return newWorkflowContext(env.impl, interceptors, envInterceptor)
 }
+func createRootTestContextWithLogger(t *testing.T, logger *zap.Logger) (ctx Context) {
+	s := WorkflowTestSuite{}
+	s.SetLogger(logger)
+	// tally is not set since metrics are not noisy by default, and the test-instance
+	// is largely useless without access to the instance for snapshots.
+	env := s.NewTestWorkflowEnvironment()
+	env.Test(t)
+	interceptors, envInterceptor := newWorkflowInterceptors(env.impl, env.impl.workflowInterceptors)
+	return newWorkflowContext(env.impl, interceptors, envInterceptor)
+}
 
 func TestDispatcher(t *testing.T) {
 	value := "foo"
@@ -664,6 +679,7 @@ func TestDispatchClose(t *testing.T) {
 }
 
 func TestPanic(t *testing.T) {
+	defer goleak.VerifyNone(t)
 	var history []string
 	d, _ := newDispatcher(createRootTestContext(t), func(ctx Context) {
 		c := NewNamedChannel(ctx, "forever_blocked")
@@ -680,6 +696,8 @@
 		history = append(history, "root")
 		c.Receive(ctx, nil) // blocked forever
 	})
+	defer d.Close() // stop other coroutines, as only one panicked
+
 	require.EqualValues(t, 0, len(history))
 	err := d.ExecuteUntilAllBlocked()
 	require.Error(t, err)
@@ -1203,3 +1221,41 @@ func TestChainedFuture(t *testing.T) {
 	require.NoError(t, env.GetWorkflowResult(&out))
 	require.Equal(t, 5, out)
 }
+
+func TestShutdownTimeout(t *testing.T) {
+	unblock := make(chan struct{})
+	// custom test context so logs can be checked
+	core, obs := observer.New(zap.InfoLevel)
+	logger := zap.New(core, zap.WrapCore(func(zapcore.Core) zapcore.Core {
+		return zapcore.NewTee(core, zaptest.NewLogger(t).Core()) // tee to test logs
+	}))
+
+	ctx := createRootTestContextWithLogger(t, logger)
+	d, _ := newDispatcher(ctx, func(ctx Context) {
+		defer func() {
+			<-unblock // block until timeout, then unblock
+		}()
+		c := NewChannel(ctx)
+		c.Receive(ctx, nil) // block forever
+	})
+
+	before := time.Now()
+	require.NoError(t, d.ExecuteUntilAllBlocked())
+	go func() {
+		<-time.After(2 * time.Second) // current timeout is 1s
+		close(unblock)
+	}()
+	d.Close() // should hang until close(unblock) allows it to continue
+	blocked := time.Since(before)
+
+	// make sure it didn't give up after the internal 1s timeout,
+	// as this likely implies parallel coroutine closing, which risks crashes
+	// due to racy memory corruption.
+	assert.Greater(t, blocked, 2*time.Second, "d.Close should have waited for unblock to occur, but it returned too early")
+
+	logs := obs.FilterMessageSnippet("workflow failed to shut down").All()
+	assert.Len(t, logs, 1, "expected only one log")
+	if len(logs) > 0 { // check so it doesn't panic if wrong
+		assert.Contains(t, logs[0].ContextMap()["stacks"], t.Name(), "stacktrace field should mention this test func")
+	}
+}
diff --git a/internal/internal_workflow.go b/internal/internal_workflow.go
index c7c127e00..6e77213f1 100644
--- a/internal/internal_workflow.go
+++ b/internal/internal_workflow.go
@@ -154,6 +154,7 @@ type (
 		unblock      chan unblockFunc // used to notify coroutine that it should continue executing.
 		keptBlocked  bool             // true indicates that coroutine didn't make any progress since the last yield unblocking
 		closed       bool             // indicates that owning coroutine has finished execution
+		stopped      chan struct{}    // close(stopped) when finished closing
 		blocked      atomic.Bool
 		panicError   *workflowPanicError // non nil if coroutine had unhandled panic
 	}
@@ -166,6 +167,9 @@ type (
 		executing        bool       // currently running ExecuteUntilAllBlocked. Used to avoid recursive calls to it.
 		mutex            sync.Mutex // used to synchronize executing
 		closed           bool
+		// callback to report if stopping the dispatcher times out (>1s per coroutine).
+		// nearly every case should take less than a millisecond, so this only fires for fairly significant mistakes.
+		shutdownTimeout func(idx int, stacks string)
 	}
 
 	// The current timeout resolution implementation is in seconds and uses math.Ceil() as the duration. But is
@@ -533,7 +537,23 @@ func (d *syncWorkflowDefinition) Close() {
 // Context passed to the root function is child of the passed rootCtx.
 // This way rootCtx can be used to pass values to the coroutine code.
 func newDispatcher(rootCtx Context, root func(ctx Context)) (*dispatcherImpl, Context) {
-	result := &dispatcherImpl{}
+	env := getWorkflowEnvironment(rootCtx)
+	met := env.GetMetricsScope()
+	log := env.GetLogger()
+	result := &dispatcherImpl{
+		shutdownTimeout: func(idx int, stacks string) {
+			// dispatcher/coroutine shutdown should be nearly instant.
+			// this is called if it is not.
+			met.Counter("cadence-workflow-shutdown-timeout").Inc(1)
+			log.Error(
+				"workflow failed to shut down within ~1s. "+
+					"generally this means a significant problem with user code, "+
+					"e.g. mutexes, time.Sleep, or infinite loops in defers",
+				zap.String("stacks", stacks),
+				zap.Int("coroutine", idx),
+			)
+		},
+	}
 	ctxWithState := result.newCoroutine(rootCtx, root)
 	return result, ctxWithState
 }
@@ -807,6 +827,14 @@ func (s *coroutineState) initialYield(stackDepth int, status string) {
 // yield indicates that coroutine cannot make progress and should sleep
 // this call blocks
 func (s *coroutineState) yield(status string) {
+	// make sure we're the running coroutine before writing to the aboutToBlock
+	// channel, as it's not safe to undo that write if the caller used the
+	// wrong context to make a blocking call.
+	if s.blocked.Load() {
+		// same as initialYield
+		panic("trying to block on coroutine which is already blocked, most likely a wrong Context is used to do blocking" +
+			" call (like Future.Get() or Channel.Receive())")
+	}
 	s.aboutToBlock <- true
 	s.initialYield(3, status) // omit three levels of stack. To adjust change to 0 and count the lines to remove.
 	s.keptBlocked = true
@@ -849,6 +877,10 @@ func (s *coroutineState) close() {
 	s.aboutToBlock <- true
 }
 
+func (s *coroutineState) wait() <-chan struct{} {
+	return s.stopped
+}
+
 func (s *coroutineState) exit() {
 	if !s.closed {
 		s.unblock <- func(status string, stackDepth int) bool {
@@ -878,6 +910,7 @@ func (d *dispatcherImpl) newNamedCoroutine(ctx Context, name string, f func(ctx
 	state := d.newState(name)
 	spawned := WithValue(ctx, coroutinesContextKey, state)
 	go func(crt *coroutineState) {
+		defer close(crt.stopped)
 		defer crt.close()
 		defer func() {
 			if r := recover(); r != nil {
@@ -897,6 +930,7 @@ func (d *dispatcherImpl) newState(name string) *coroutineState {
 		dispatcher:   d,
 		aboutToBlock: make(chan bool, 1),
 		unblock:      make(chan unblockFunc),
+		stopped:      make(chan struct{}),
 	}
 	d.sequence++
 	d.coroutines = append(d.coroutines, c)
@@ -963,10 +997,35 @@ func (d *dispatcherImpl) Close() {
 	}
 	d.closed = true
 	d.mutex.Unlock()
+	// collect a stacktrace before stopping things, so it can be reported if shutdown
+	// does not finish cleanly (by that point it is too late to collect one).
+	stacktrace := d.StackTrace()
+	t := time.NewTimer(time.Second)
+	reported := false
 	for i := 0; i < len(d.coroutines); i++ {
 		c := d.coroutines[i]
 		if !c.closed {
 			c.exit()
+			select {
+			case <-c.wait():
+				// clean shutdown
+				t.Stop()
+			case <-t.C:
+				// timeout, emit an error log and metric.
+				// this only needs to be done once because we report all stacks
+				// the first time.
+				if !reported {
+					d.shutdownTimeout(i, stacktrace)
+					reported = true
+				}
+				// and continue waiting, it's not safe to ignore it and cause
+				// other goroutines to exit concurrently.
+				//
+				// since this is a full second after shutdown began, there's a
+				// good chance this will never finish, and will become a leaked
+				// goroutine. these can at least be seen in pprof.
+				<-c.wait()
+			}
 		}
 	}
 }
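
For reference (not part of the patch): a minimal, self-contained sketch of the same wait-with-timeout shutdown pattern that the dispatcher's Close() adopts above, reduced to plain goroutines and channels. All names here (worker, newWorker, reportTimeout, collectStacks) are hypothetical and exist only for illustration.

// Sketch of the shutdown pattern used in the patch, under assumed names.
package main

import (
	"fmt"
	"runtime"
	"time"
)

type worker struct {
	exit    chan struct{} // closed by Close to ask the goroutine to stop
	stopped chan struct{} // closed by the goroutine once it has fully returned
}

func newWorker(run func(exit <-chan struct{})) *worker {
	w := &worker{
		exit:    make(chan struct{}),
		stopped: make(chan struct{}),
	}
	go func() {
		defer close(w.stopped) // always signal completion, even if run panics
		run(w.exit)
	}()
	return w
}

// Close asks the goroutine to stop and waits for it. If it does not stop
// within a second, report the pre-collected stacks once and keep waiting,
// because continuing teardown concurrently with a live goroutine is racy.
func (w *worker) Close(reportTimeout func(stacks string)) {
	stacks := collectStacks() // capture before shutdown; too late afterwards
	close(w.exit)
	t := time.NewTimer(time.Second)
	defer t.Stop()
	select {
	case <-w.stopped:
		// clean shutdown
	case <-t.C:
		reportTimeout(stacks)
		<-w.stopped // keep waiting; may leak if the goroutine never returns
	}
}

func collectStacks() string {
	buf := make([]byte, 64*1024)
	return string(buf[:runtime.Stack(buf, true)])
}

func main() {
	w := newWorker(func(exit <-chan struct{}) {
		<-exit
		time.Sleep(2 * time.Second) // simulate a slow defer / cleanup
	})
	w.Close(func(stacks string) {
		fmt.Println("shutdown timed out; goroutine stacks:\n" + stacks)
	})
	fmt.Println("closed")
}

As in the patch, the timeout only reports the problem; it never abandons the goroutine, since tearing down other state while it still runs would risk racy memory corruption.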