Skip to content

Commit 815dd18

Browse files
authored
adding context as first parameter to all methods on our client (#239)
1 parent ab0bf9b commit 815dd18

9 files changed

+91
-83
lines changed

client.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package cadence
2222

2323
import (
24+
"context"
2425
"time"
2526

2627
"github.com/uber-go/tally"
@@ -40,14 +41,14 @@ type (
4041
// StartWorkflow starts a workflow execution
4142
// The user can use this to start using a function or workflow type name.
4243
// Either by
43-
// StartWorkflow(options, "workflowTypeName", input)
44+
// StartWorkflow(ctx, options, "workflowTypeName", input)
4445
// or
45-
// StartWorkflow(options, workflowExecuteFn, arg1, arg2, arg3)
46+
// StartWorkflow(ctx, options, workflowExecuteFn, arg1, arg2, arg3)
4647
// The errors it can return:
4748
// - EntityNotExistsError
4849
// - BadRequestError
4950
// - WorkflowExecutionAlreadyStartedError
50-
StartWorkflow(options StartWorkflowOptions, workflow interface{}, args ...interface{}) (*WorkflowExecution, error)
51+
StartWorkflow(ctx context.Context, options StartWorkflowOptions, workflow interface{}, args ...interface{}) (*WorkflowExecution, error)
5152

5253
// SignalWorkflow sends a signals to a workflow in execution
5354
// - workflow ID of the workflow.
@@ -56,7 +57,7 @@ type (
5657
// The errors it can return:
5758
// - EntityNotExistsError
5859
// - InternalServiceError
59-
SignalWorkflow(workflowID string, runID string, signalName string, arg interface{}) error
60+
SignalWorkflow(ctx context.Context, workflowID string, runID string, signalName string, arg interface{}) error
6061

6162
// CancelWorkflow cancels a workflow in execution
6263
// - workflow ID of the workflow.
@@ -65,7 +66,7 @@ type (
6566
// - EntityNotExistsError
6667
// - BadRequestError
6768
// - InternalServiceError
68-
CancelWorkflow(workflowID string, runID string) error
69+
CancelWorkflow(ctx context.Context, workflowID string, runID string) error
6970

7071
// TerminateWorkflow terminates a workflow execution.
7172
// workflowID is required, other parameters are optional.
@@ -75,7 +76,7 @@ type (
7576
// - EntityNotExistsError
7677
// - BadRequestError
7778
// - InternalServiceError
78-
TerminateWorkflow(workflowID string, runID string, reason string, details []byte) error
79+
TerminateWorkflow(ctx context.Context, workflowID string, runID string, reason string, details []byte) error
7980

8081
// GetWorkflowHistory gets history of a particular workflow.
8182
// - workflow ID of the workflow.
@@ -84,7 +85,7 @@ type (
8485
// - EntityNotExistsError
8586
// - BadRequestError
8687
// - InternalServiceError
87-
GetWorkflowHistory(workflowID string, runID string) (*s.History, error)
88+
GetWorkflowHistory(ctx context.Context, workflowID string, runID string) (*s.History, error)
8889

8990
// GetWorkflowStackTrace gets a stack trace of all goroutines of a particular workflow.
9091
// atDecisionTaskCompletedEventID is the eventID of the CompleteDecisionTask event at which stack trace should be taken.
@@ -94,7 +95,7 @@ type (
9495
// - EntityNotExistsError
9596
// - BadRequestError
9697
// - InternalServiceError
97-
GetWorkflowStackTrace(workflowID string, runID string, atDecisionTaskCompletedEventID int64) (string, error)
98+
GetWorkflowStackTrace(ctx context.Context, workflowID string, runID string, atDecisionTaskCompletedEventID int64) (string, error)
9899

99100
// CompleteActivity reports activity completed.
100101
// activity Execute method can return cadence.ErrActivityResultPending to
@@ -109,28 +110,28 @@ type (
109110
// To fail the activity with an error.
110111
// CompleteActivity(token, nil, NewErrorWithDetails("reason", details)
111112
// The activity can fail with below errors ErrorWithDetails, TimeoutError, CanceledError.
112-
CompleteActivity(taskToken []byte, result interface{}, err error) error
113+
CompleteActivity(ctx context.Context, taskToken []byte, result interface{}, err error) error
113114

114115
// RecordActivityHeartbeat records heartbeat for an activity.
115116
// details - is the progress you want to record along with heart beat for this activity.
116117
// The errors it can return:
117118
// - EntityNotExistsError
118119
// - InternalServiceError
119-
RecordActivityHeartbeat(taskToken []byte, details ...interface{}) error
120+
RecordActivityHeartbeat(ctx context.Context, taskToken []byte, details ...interface{}) error
120121

121122
// ListClosedWorkflow gets closed workflow executions based on request filters
122123
// The errors it can return:
123124
// - BadRequestError
124125
// - InternalServiceError
125126
// - EntityNotExistError
126-
ListClosedWorkflow(request *s.ListClosedWorkflowExecutionsRequest) (*s.ListClosedWorkflowExecutionsResponse, error)
127+
ListClosedWorkflow(ctx context.Context, request *s.ListClosedWorkflowExecutionsRequest) (*s.ListClosedWorkflowExecutionsResponse, error)
127128

128129
// ListClosedWorkflow gets open workflow executions based on request filters
129130
// The errors it can return:
130131
// - BadRequestError
131132
// - InternalServiceError
132133
// - EntityNotExistError
133-
ListOpenWorkflow(request *s.ListOpenWorkflowExecutionsRequest) (*s.ListOpenWorkflowExecutionsResponse, error)
134+
ListOpenWorkflow(ctx context.Context, request *s.ListOpenWorkflowExecutionsRequest) (*s.ListOpenWorkflowExecutionsResponse, error)
134135

135136
// QueryWorkflow queries a given workflow execution and returns the query result synchronously. Parameter workflowID
136137
// and queryType are required, other parameters are optional. The workflowID and runID (optional) identify the
@@ -150,7 +151,7 @@ type (
150151
// - InternalServiceError
151152
// - EntityNotExistError
152153
// - QueryFailError
153-
QueryWorkflow(workflowID string, runID string, queryType string, args ...interface{}) (EncodedValue, error)
154+
QueryWorkflow(ctx context.Context, workflowID string, runID string, queryType string, args ...interface{}) (EncodedValue, error)
154155
}
155156

156157
// ClientOptions are optional parameters for Client creation.
@@ -191,7 +192,7 @@ type (
191192
// - DomainAlreadyExistsError
192193
// - BadRequestError
193194
// - InternalServiceError
194-
Register(request *s.RegisterDomainRequest) error
195+
Register(ctx context.Context, request *s.RegisterDomainRequest) error
195196

196197
// Describe a domain. The domain has two part of information.
197198
// DomainInfo - Which has Name, Status, Description, Owner Email.
@@ -200,7 +201,7 @@ type (
200201
// - EntityNotExistsError
201202
// - BadRequestError
202203
// - InternalServiceError
203-
Describe(name string) (*s.DomainInfo, *s.DomainConfiguration, error)
204+
Describe(ctx context.Context, name string) (*s.DomainInfo, *s.DomainConfiguration, error)
204205

205206
// Update a domain. The domain has two part of information.
206207
// UpdateDomainInfo - To update domain Description and Owner Email.
@@ -209,7 +210,7 @@ type (
209210
// - EntityNotExistsError
210211
// - BadRequestError
211212
// - InternalServiceError
212-
Update(name string, domainInfo *s.UpdateDomainInfo, domainConfig *s.DomainConfiguration) error
213+
Update(ctx context.Context, name string, domainInfo *s.UpdateDomainInfo, domainConfig *s.DomainConfiguration) error
213214
}
214215
)
215216

internal_task_handlers.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ func (i *cadenceInvoker) Heartbeat(details []byte) error {
904904

905905
func (i *cadenceInvoker) internalHeartBeat(details []byte) (bool, error) {
906906
isActivityCancelled := false
907-
err := recordActivityHeartbeat(i.service, i.identity, i.taskToken, details, i.retryPolicy)
907+
err := recordActivityHeartbeat(context.Background(), i.service, i.identity, i.taskToken, details, i.retryPolicy)
908908

909909
switch err.(type) {
910910
case *CanceledError:
@@ -1030,6 +1030,7 @@ func createNewDecision(decisionType s.DecisionType) *s.Decision {
10301030
}
10311031

10321032
func recordActivityHeartbeat(
1033+
ctx context.Context,
10331034
service m.TChanWorkflowService,
10341035
identity string,
10351036
taskToken, details []byte,
@@ -1043,11 +1044,11 @@ func recordActivityHeartbeat(
10431044
var heartbeatResponse *s.RecordActivityTaskHeartbeatResponse
10441045
heartbeatErr := backoff.Retry(
10451046
func() error {
1046-
ctx, cancel := newTChannelContext()
1047+
tchCtx, cancel := newTChannelContext(ctx)
10471048
defer cancel()
10481049

10491050
var err error
1050-
heartbeatResponse, err = service.RecordActivityTaskHeartbeat(ctx, request)
1051+
heartbeatResponse, err = service.RecordActivityTaskHeartbeat(tchCtx, request)
10511052
return err
10521053
}, retryPolicy, isServiceTransientError)
10531054

internal_task_handlers_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ func (t *TaskHandlersTestSuite) TestGetWorkflowStackTraceByID() {
565565
domain := "testDomain"
566566
workflowClient := NewClient(service, domain, nil)
567567

568-
dump, err := workflowClient.GetWorkflowStackTrace("id1", "runId1", 0)
568+
dump, err := workflowClient.GetWorkflowStackTrace(context.Background(), "id1", "runId1", 0)
569569
t.NoError(err)
570570
t.NotNil(dump)
571571
t.True(strings.Contains(dump, ".Receive]"))
@@ -643,7 +643,7 @@ func (t *TaskHandlersTestSuite) TestGetWorkflowStackTraceByIDAndDecisionTaskComp
643643
domain := "testDomain"
644644
workflowClient := NewClient(service, domain, nil)
645645

646-
dump, err := workflowClient.GetWorkflowStackTrace("id1", "runId1", 5)
646+
dump, err := workflowClient.GetWorkflowStackTrace(context.Background(), "id1", "runId1", 5)
647647
t.NoError(err)
648648
t.NotNil(dump)
649649
t.True(strings.Contains(dump, ".Receive]"))

internal_task_pollers.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -165,19 +165,19 @@ func (wtp *workflowTaskPoller) ProcessTask(task interface{}) error {
165165
// Respond task completion.
166166
err = backoff.Retry(
167167
func() error {
168-
ctx, cancel := newTChannelContext()
168+
tchCtx, cancel := newTChannelContext(context.Background())
169169
defer cancel()
170170
var err1 error
171171
switch request := completedRequest.(type) {
172172
case *s.RespondDecisionTaskCompletedRequest:
173-
err1 = wtp.service.RespondDecisionTaskCompleted(ctx, request)
173+
err1 = wtp.service.RespondDecisionTaskCompleted(tchCtx, request)
174174
if err1 != nil {
175175
traceLog(func() {
176176
wtp.logger.Debug("RespondDecisionTaskCompleted failed.", zap.Error(err1))
177177
})
178178
}
179179
case *s.RespondQueryTaskCompletedRequest:
180-
err1 = wtp.service.RespondQueryTaskCompleted(ctx, request)
180+
err1 = wtp.service.RespondQueryTaskCompleted(tchCtx, request)
181181
if err1 != nil {
182182
traceLog(func() {
183183
wtp.logger.Debug("RespondQueryTaskCompleted failed.", zap.Error(err1))
@@ -217,10 +217,10 @@ func (wtp *workflowTaskPoller) poll() (*workflowTask, error) {
217217
Identity: common.StringPtr(wtp.identity),
218218
}
219219

220-
ctx, cancel := newTChannelContext(tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
220+
tchCtx, cancel := newTChannelContext(context.Background(), tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
221221
defer cancel()
222222

223-
response, err := wtp.service.PollForDecisionTask(ctx, request)
223+
response, err := wtp.service.PollForDecisionTask(tchCtx, request)
224224
if err != nil {
225225
if isServiceTransientError(err) {
226226
wtp.metricsScope.Counter(metrics.DecisionPollTransientFailedCounter).Inc(1)
@@ -236,14 +236,15 @@ func (wtp *workflowTaskPoller) poll() (*workflowTask, error) {
236236
}
237237

238238
execution := response.GetWorkflowExecution()
239-
iterator := newGetHistoryPageFunc(wtp.service, wtp.domain, execution, math.MaxInt64, wtp.metricsScope)
239+
iterator := newGetHistoryPageFunc(context.Background(), wtp.service, wtp.domain, execution, math.MaxInt64, wtp.metricsScope)
240240
task := &workflowTask{task: response, getHistoryPageFunc: iterator, pollStartTime: startTime}
241241
wtp.metricsScope.Counter(metrics.DecisionPollSucceedCounter).Inc(1)
242242
wtp.metricsScope.Timer(metrics.DecisionPollLatency).Record(time.Now().Sub(startTime))
243243
return task, nil
244244
}
245245

246246
func newGetHistoryPageFunc(
247+
ctx context.Context,
247248
service m.TChanWorkflowService,
248249
domain string,
249250
execution *s.WorkflowExecution,
@@ -256,11 +257,11 @@ func newGetHistoryPageFunc(
256257
var resp *s.GetWorkflowExecutionHistoryResponse
257258
err := backoff.Retry(
258259
func() error {
259-
ctx, cancel := newTChannelContext()
260+
tchCtx, cancel := newTChannelContext(ctx)
260261
defer cancel()
261262

262263
var err1 error
263-
resp, err1 = service.GetWorkflowExecutionHistory(ctx, &s.GetWorkflowExecutionHistoryRequest{
264+
resp, err1 = service.GetWorkflowExecutionHistory(tchCtx, &s.GetWorkflowExecutionHistoryRequest{
264265
Domain: common.StringPtr(domain),
265266
Execution: execution,
266267
NextPageToken: nextPageToken,
@@ -317,10 +318,10 @@ func (atp *activityTaskPoller) poll() (*activityTask, error) {
317318
Identity: common.StringPtr(atp.identity),
318319
}
319320

320-
ctx, cancel := newTChannelContext(tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
321+
tchCtx, cancel := newTChannelContext(context.Background(), tchanTimeout(pollTaskServiceTimeOut), tchanRetryOption(retryNeverOptions))
321322
defer cancel()
322323

323-
response, err := atp.service.PollForActivityTask(ctx, request)
324+
response, err := atp.service.PollForActivityTask(tchCtx, request)
324325
if err != nil {
325326
if isServiceTransientError(err) {
326327
atp.metricsScope.Counter(metrics.ActivityPollTransientFailedCounter).Inc(1)
@@ -375,7 +376,7 @@ func (atp *activityTaskPoller) ProcessTask(task interface{}) error {
375376
}
376377

377378
responseStartTime := time.Now()
378-
reportErr := reportActivityComplete(atp.service, request, atp.metricsScope)
379+
reportErr := reportActivityComplete(context.Background(), atp.service, request, atp.metricsScope)
379380
if reportErr != nil {
380381
atp.metricsScope.Counter(metrics.ActivityResponseFailedCounter).Inc(1)
381382
traceLog(func() {
@@ -389,30 +390,30 @@ func (atp *activityTaskPoller) ProcessTask(task interface{}) error {
389390
return nil
390391
}
391392

392-
func reportActivityComplete(service m.TChanWorkflowService, request interface{}, metricsScope tally.Scope) error {
393+
func reportActivityComplete(ctx context.Context, service m.TChanWorkflowService, request interface{}, metricsScope tally.Scope) error {
393394
if request == nil {
394395
// nothing to report
395396
return nil
396397
}
397398

398-
ctx, cancel := newTChannelContext()
399+
tchCtx, cancel := newTChannelContext(ctx)
399400
defer cancel()
400401
var reportErr error
401402
switch request := request.(type) {
402403
case *s.RespondActivityTaskCanceledRequest:
403404
reportErr = backoff.Retry(
404405
func() error {
405-
return service.RespondActivityTaskCanceled(ctx, request)
406+
return service.RespondActivityTaskCanceled(tchCtx, request)
406407
}, serviceOperationRetryPolicy, isServiceTransientError)
407408
case *s.RespondActivityTaskFailedRequest:
408409
reportErr = backoff.Retry(
409410
func() error {
410-
return service.RespondActivityTaskFailed(ctx, request)
411+
return service.RespondActivityTaskFailed(tchCtx, request)
411412
}, serviceOperationRetryPolicy, isServiceTransientError)
412413
case *s.RespondActivityTaskCompletedRequest:
413414
reportErr = backoff.Retry(
414415
func() error {
415-
return service.RespondActivityTaskCompleted(ctx, request)
416+
return service.RespondActivityTaskCompleted(tchCtx, request)
416417
}, serviceOperationRetryPolicy, isServiceTransientError)
417418
}
418419
if reportErr == nil {

internal_utils.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,11 @@ func tchanRetryOption(retryOpt *tchannel.RetryOptions) func(builder *tchannel.Co
6969
}
7070

7171
// newTChannelContext - Get a tchannel context
72-
func newTChannelContext(options ...func(builder *tchannel.ContextBuilder)) (tchannel.ContextWithHeaders, context.CancelFunc) {
72+
func newTChannelContext(ctx context.Context, options ...func(builder *tchannel.ContextBuilder)) (tchannel.ContextWithHeaders, context.CancelFunc) {
7373
builder := tchannel.NewContextBuilder(defaultRPCTimeout)
74+
if ctx != nil {
75+
builder.SetParentContext(ctx)
76+
}
7477
builder.SetRetryOptions(retryDefaultOptions)
7578
builder.AddHeader(versionHeaderName, LibraryVersion)
7679
for _, opt := range options {

internal_worker.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ func ensureRequiredParams(params *workerExecutionParameters) {
164164
func verifyDomainExist(client m.TChanWorkflowService, domain string, logger *zap.Logger) error {
165165

166166
descDomainOp := func() error {
167-
ctx, cancel := newTChannelContext()
167+
tchCtx, cancel := newTChannelContext(context.Background())
168168
defer cancel()
169-
_, err := client.DescribeDomain(ctx, &shared.DescribeDomainRequest{Name: &domain})
169+
_, err := client.DescribeDomain(tchCtx, &shared.DescribeDomainRequest{Name: &domain})
170170
if err != nil {
171171
if _, ok := err.(*shared.EntityNotExistsError); ok {
172172
logger.Error("domain does not exist", zap.String("domain", domain), zap.Error(err))

internal_worker_interfaces_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func (s *InterfacesTestSuite) TestInterface() {
179179
DecisionTaskStartToCloseTimeout: 10 * time.Second,
180180
}
181181
workflowClient := NewClient(service, domain, nil)
182-
wfExecution, err := workflowClient.StartWorkflow(workflowOptions, "workflowType")
182+
wfExecution, err := workflowClient.StartWorkflow(context.Background(), workflowOptions, "workflowType")
183183
s.NoError(err)
184184
fmt.Printf("Started workflow: %v \n", wfExecution)
185185
}

internal_worker_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,13 @@ func TestCompleteActivity(t *testing.T) {
369369
failedRequest = args.Get(1).(*s.RespondActivityTaskFailedRequest)
370370
})
371371

372-
wfClient.CompleteActivity([]byte("task-token"), nil, nil)
372+
wfClient.CompleteActivity(context.Background(), []byte("task-token"), nil, nil)
373373
require.NotNil(t, completedRequest)
374374

375-
wfClient.CompleteActivity([]byte("task-token"), nil, NewCanceledError())
375+
wfClient.CompleteActivity(context.Background(), []byte("task-token"), nil, NewCanceledError())
376376
require.NotNil(t, canceledRequest)
377377

378-
wfClient.CompleteActivity([]byte("task-token"), nil, errors.New(""))
378+
wfClient.CompleteActivity(context.Background(), []byte("task-token"), nil, errors.New(""))
379379
require.NotNil(t, failedRequest)
380380
}
381381

@@ -391,8 +391,8 @@ func TestRecordActivityHeartbeat(t *testing.T) {
391391
heartbeatRequest = args.Get(1).(*s.RecordActivityTaskHeartbeatRequest)
392392
})
393393

394-
wfClient.RecordActivityHeartbeat(nil)
395-
wfClient.RecordActivityHeartbeat(nil, "testStack", "customerObjects", 4)
394+
wfClient.RecordActivityHeartbeat(context.Background(), nil)
395+
wfClient.RecordActivityHeartbeat(context.Background(), nil, "testStack", "customerObjects", 4)
396396
require.NotNil(t, heartbeatRequest)
397397
}
398398

0 commit comments

Comments
 (0)