diff --git a/internal/internal_task_handlers_test.go b/internal/internal_task_handlers_test.go index 63b813ed9..07da0ccb7 100644 --- a/internal/internal_task_handlers_test.go +++ b/internal/internal_task_handlers_test.go @@ -1897,3 +1897,189 @@ func Test_IsSearchAttributesMatched(t *testing.T) { }) } } + +func Test__GetWorkflowStartedEvent(t *testing.T) { + wfStartedEvent := createTestEventWorkflowExecutionStarted(1, &s.WorkflowExecutionStartedEventAttributes{TaskList: &s.TaskList{Name: common.StringPtr("tl1")}}) + h := &history{workflowTask: &workflowTask{task: &s.PollForDecisionTaskResponse{History: &s.History{Events: []*s.HistoryEvent{wfStartedEvent}}}}} + result, err := h.GetWorkflowStartedEvent() + require.NoError(t, err) + require.Equal(t, wfStartedEvent, result) + + emptyHistory := &history{workflowTask: &workflowTask{task: &s.PollForDecisionTaskResponse{History: &s.History{}}}} + result, err = emptyHistory.GetWorkflowStartedEvent() + require.ErrorContains(t, err, "unable to find WorkflowExecutionStartedEventAttributes") + require.Nil(t, result) +} + +func Test__verifyAllEventsProcessed(t *testing.T) { + testCases := []struct { + name string + lastEventID int64 + nextEventID int64 + Message string + }{ + { + name: "error", + lastEventID: 1, + nextEventID: 1, + Message: "history_events: premature end of stream", + }, + { + name: "warn", + lastEventID: 1, + nextEventID: 3, + Message: "history_events: processed events past the expected lastEventID", + }, + { + name: "success", + lastEventID: 1, + nextEventID: 2, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + obs, logs := observer.New(zap.WarnLevel) + logger := zap.New(obs) + h := &history{ + lastEventID: testCase.lastEventID, + nextEventID: testCase.nextEventID, + eventsHandler: &workflowExecutionEventHandlerImpl{workflowEnvironmentImpl: &workflowEnvironmentImpl{logger: logger}}} + err := h.verifyAllEventsProcessed() + if testCase.name == "error" { + require.ErrorContains(t, err, testCase.Message) + } else if testCase.name == "warn" { + warnLogs := logs.FilterMessage(testCase.Message) + require.Len(t, warnLogs.All(), 1) + } else { + require.NoError(t, err) + } + }) + } + +} + +func Test__workflowCategorizedByTimeout(t *testing.T) { + testCases := []struct { + timeout int32 + expectedCategory string + }{ + { + timeout: 1, + expectedCategory: "instant", + }, + { + timeout: 1000, + expectedCategory: "short", + }, + { + timeout: 2000, + expectedCategory: "intermediate", + }, + { + timeout: 30000, + expectedCategory: "long", + }, + } + for _, tt := range testCases { + t.Run(tt.expectedCategory, func(t *testing.T) { + wfContext := &workflowExecutionContextImpl{workflowInfo: &WorkflowInfo{ExecutionStartToCloseTimeoutSeconds: tt.timeout}} + require.Equal(t, tt.expectedCategory, workflowCategorizedByTimeout(wfContext)) + }) + } +} + +func Test__SignalWorkflow(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockService := workflowservicetest.NewMockClient(mockCtrl) + mockService.EXPECT().SignalWorkflowExecution(gomock.Any(), gomock.Any(), callOptions()...).Return(nil) + cadenceInvoker := &cadenceInvoker{ + identity: "Test_Cadence_Invoker", + service: mockService, + taskToken: nil, + } + err := cadenceInvoker.SignalWorkflow(context.Background(), "test-domain", "test-workflow-id", "test-run-id", "test-signal-name", nil) + require.NoError(t, err) +} + +func Test__getRetryBackoffWithNowTime(t *testing.T) { + now := time.Now() + testCases := []struct { + name string + maxAttempts int32 + ExpInterval time.Duration + result time.Duration + attempt int32 + errReason string + expireTime time.Time + initialInterval time.Duration + maxInterval time.Duration + }{ + { + name: "no max attempts or expiration interval set", + maxAttempts: 0, + ExpInterval: 0, + result: noRetryBackoff, + }, + { + name: "max attempts done", + maxAttempts: 5, + attempt: 5, + result: noRetryBackoff, + }, + { + name: "non retryable error", + maxAttempts: 5, + attempt: 2, + errReason: "bad request", + initialInterval: time.Minute, + maxInterval: time.Minute, + result: noRetryBackoff, + }, + { + name: "fallback to max interval when calculated backoff is 0", + maxAttempts: 5, + attempt: 2, + initialInterval: 0, + maxInterval: time.Minute, + result: time.Minute, + }, + { + name: "fallback to no retry backoff when calculated backoff is 0 and max interval is not set", + maxAttempts: 5, + attempt: 2, + initialInterval: 0, + result: noRetryBackoff, + }, + { + name: "expiry time reached", + maxAttempts: 5, + attempt: 2, + expireTime: now.Add(time.Second), + initialInterval: time.Minute, + maxInterval: time.Minute, + result: noRetryBackoff, + }, + { + name: "retry after backoff", + maxAttempts: 5, + attempt: 2, + errReason: "timeout", + initialInterval: time.Minute, + maxInterval: time.Minute, + result: time.Minute, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + policy := &RetryPolicy{ + MaximumAttempts: tt.maxAttempts, + ExpirationInterval: tt.ExpInterval, + BackoffCoefficient: 2, + NonRetriableErrorReasons: []string{"bad request"}, + MaximumInterval: tt.maxInterval, + InitialInterval: tt.initialInterval, + } + require.Equal(t, tt.result, getRetryBackoffWithNowTime(policy, tt.attempt, tt.errReason, now, tt.expireTime)) + }) + } +}