Skip to content

Commit ccea6a6

Browse files
authored
Add Context Propagation support for Local Activities (#986)
1 parent 61b42f1 commit ccea6a6

File tree

8 files changed

+151
-24
lines changed

8 files changed

+151
-24
lines changed

internal/internal_activity.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ type (
8888
DataConverter DataConverter
8989
Attempt int32
9090
ScheduledTime time.Time
91+
Header *shared.Header
9192
}
9293

9394
// asyncActivityClient for requesting activity execution

internal/internal_event_handlers.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ type (
132132
attempt int32 // attempt starting from 0
133133
retryPolicy *RetryPolicy
134134
expireTime time.Time
135+
header *shared.Header
135136
}
136137

137138
localActivityMarkerData struct {
@@ -503,6 +504,7 @@ func newLocalActivityTask(params executeLocalActivityParams, callback laResultHa
503504
callback: callback,
504505
retryPolicy: params.RetryPolicy,
505506
attempt: params.Attempt,
507+
header: params.Header,
506508
}
507509

508510
if params.RetryPolicy != nil && params.RetryPolicy.ExpirationInterval > 0 {

internal/internal_task_pollers.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,19 @@ func (lath *localActivityTaskHandler) executeLocalActivityTask(task *localActivi
482482
attempt: task.attempt,
483483
})
484484

485+
// propagate context information into the local activity activity context from the headers
486+
for _, ctxProp := range lath.contextPropagators {
487+
var err error
488+
if ctx, err = ctxProp.Extract(ctx, NewHeaderReader(task.header)); err != nil {
489+
result = &localActivityResult{
490+
task: task,
491+
result: nil,
492+
err: fmt.Errorf("unable to propagate context %v", err),
493+
}
494+
return result
495+
}
496+
}
497+
485498
// panic handler
486499
defer func() {
487500
if p := recover(); p != nil {

internal/workflow.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,13 +518,13 @@ func ExecuteLocalActivity(ctx Context, activity interface{}, args ...interface{}
518518
}
519519

520520
func (wc *workflowEnvironmentInterceptor) ExecuteLocalActivity(ctx Context, activityType string, args ...interface{}) Future {
521+
header := getHeadersFromContext(ctx)
521522
activityFn := ctx.Value(localActivityFnContextKey)
522523
if activityFn == nil {
523524
panic("ExecuteLocalActivity: Expected context key " + localActivityFnContextKey + " is missing")
524525
}
525526

526527
future, settable := newDecodeFuture(ctx, activityFn)
527-
528528
if err := validateFunctionArgs(activityFn, args, false); err != nil {
529529
settable.Set(nil, err)
530530
return future
@@ -542,6 +542,7 @@ func (wc *workflowEnvironmentInterceptor) ExecuteLocalActivity(ctx Context, acti
542542
WorkflowInfo: GetWorkflowInfo(ctx),
543543
DataConverter: getDataConverterFromWorkflowContext(ctx),
544544
ScheduledTime: Now(ctx), // initial scheduled time
545+
Header: header,
545546
}
546547

547548
Go(ctx, func(ctx Context) {

test/activity_test.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,12 @@ func (a *Activities) fail(_ context.Context) error {
8787
return errFailOnPurpose
8888
}
8989

90-
func (a *Activities) InspectActivityInfo(ctx context.Context, domain, taskList, wfType string) error {
91-
a.append("inspectActivityInfo")
92-
info := activity.GetInfo(ctx)
93-
if info.WorkflowDomain != domain {
94-
return fmt.Errorf("expected domainName %v but got %v", domain, info.WorkflowDomain)
95-
}
96-
if info.WorkflowType == nil || info.WorkflowType.Name != wfType {
97-
return fmt.Errorf("expected workflowType %v but got %v", wfType, info.WorkflowType)
98-
}
99-
if info.TaskList != taskList {
100-
return fmt.Errorf("expected taskList %v but got %v", taskList, info.TaskList)
90+
func (a *Activities) DuplicateStringInContext(ctx context.Context) (string, error) {
91+
originalString := ctx.Value(contextKey(testContextKey))
92+
if originalString == nil {
93+
return "", fmt.Errorf("context did not propagate to activity")
10194
}
102-
return nil
95+
return strings.Repeat(originalString.(string), 2), nil
10396
}
10497

10598
func (a *Activities) append(name string) {
@@ -140,6 +133,21 @@ func (a *Activities) GetMemoAndSearchAttr(_ context.Context, memo, searchAttr st
140133
return memo + ", " + searchAttr, nil
141134
}
142135

136+
func (a *Activities) InspectActivityInfo(ctx context.Context, domain, taskList, wfType string) error {
137+
a.append("inspectActivityInfo")
138+
info := activity.GetInfo(ctx)
139+
if info.WorkflowDomain != domain {
140+
return fmt.Errorf("expected domainName %v but got %v", domain, info.WorkflowDomain)
141+
}
142+
if info.WorkflowType == nil || info.WorkflowType.Name != wfType {
143+
return fmt.Errorf("expected workflowType %v but got %v", wfType, info.WorkflowType)
144+
}
145+
if info.TaskList != taskList {
146+
return fmt.Errorf("expected taskList %v but got %v", taskList, info.TaskList)
147+
}
148+
return nil
149+
}
150+
143151
func (a *Activities) register(worker worker.Worker) {
144152
// Kept to verify backward compatibility of activity registration.
145153
activity.RegisterWithOptions(a, activity.RegisterOptions{DisableAlreadyRegisteredCheck: true})

test/integration_test.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ const (
6161
ctxTimeout = 15 * time.Second
6262
domainName = "integration-test-domain"
6363
domainCacheRefreshInterval = 20 * time.Second
64+
testContextKey = "test-context-key"
6465
)
6566

6667
func TestIntegrationSuite(t *testing.T) {
@@ -97,7 +98,10 @@ func (ts *IntegrationTestSuite) SetupSuite() {
9798
rpcClient, err := newRPCClient(ts.config.ServiceName, ts.config.ServiceAddr)
9899
ts.NoError(err)
99100
ts.rpcClient = rpcClient
100-
ts.libClient = client.NewClient(ts.rpcClient.Interface, domainName, &client.Options{})
101+
ts.libClient = client.NewClient(ts.rpcClient.Interface, domainName,
102+
&client.Options{
103+
ContextPropagators: []workflow.ContextPropagator{NewStringMapPropagator([]string{testContextKey})},
104+
})
101105
ts.registerDomain()
102106
}
103107

@@ -138,12 +142,14 @@ func (ts *IntegrationTestSuite) SetupTest() {
138142
ts.worker = worker.New(ts.rpcClient.Interface, domainName, ts.taskListName, worker.Options{
139143
DisableStickyExecution: ts.config.IsStickyOff,
140144
Logger: logger,
145+
ContextPropagators: []workflow.ContextPropagator{NewStringMapPropagator([]string{testContextKey})},
141146
})
142147
ts.tracer = newtracingInterceptorFactory()
143148
options := worker.Options{
144149
DisableStickyExecution: ts.config.IsStickyOff,
145150
Logger: logger,
146151
WorkflowInterceptorChainFactories: []interceptors.WorkflowInterceptorFactory{ts.tracer},
152+
ContextPropagators: []workflow.ContextPropagator{NewStringMapPropagator([]string{testContextKey})},
147153
}
148154
ts.worker = worker.New(ts.rpcClient.Interface, domainName, ts.taskListName, options)
149155
ts.registerWorkflowsAndActivities(ts.worker)
@@ -405,6 +411,13 @@ func (ts *IntegrationTestSuite) TestActivityCancelRepro() {
405411
ts.EqualValues(expected, ts.activities.invoked())
406412
}
407413

414+
func (ts *IntegrationTestSuite) TestWorkflowWithLocalActivityCtxPropagation() {
415+
var expected string
416+
err := ts.executeWorkflow("test-wf-local-activity-ctx-prop", ts.workflows.WorkflowWithLocalActivityCtxPropagation, &expected)
417+
ts.NoError(err)
418+
ts.EqualValues(expected, "test-data-in-contexttest-data-in-context")
419+
}
420+
408421
func (ts *IntegrationTestSuite) TestLargeQueryResultError() {
409422
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
410423
defer cancel()

test/test_utils.go

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,30 @@
2222
package test
2323

2424
import (
25+
"context"
26+
"fmt"
2527
"os"
2628
"strings"
2729

28-
"go.uber.org/cadence/.gen/go/cadence/workflowserviceclient"
2930
"go.uber.org/yarpc"
3031
"go.uber.org/yarpc/transport/tchannel"
32+
33+
"go.uber.org/cadence/.gen/go/cadence/workflowserviceclient"
34+
"go.uber.org/cadence/workflow"
3135
)
3236

33-
// Config contains the integration test configuration
34-
type Config struct {
35-
ServiceAddr string
36-
ServiceName string
37-
IsStickyOff bool
38-
Debug bool
39-
}
37+
type (
38+
// Config contains the integration test configuration
39+
Config struct {
40+
ServiceAddr string
41+
ServiceName string
42+
IsStickyOff bool
43+
Debug bool
44+
}
45+
46+
// context.WithValue need this type instead of basic type string to avoid lint error
47+
contextKey string
48+
)
4049

4150
func newConfig() Config {
4251
cfg := Config{
@@ -107,3 +116,70 @@ func newRPCClient(
107116
client := workflowserviceclient.New(dispatcher.ClientConfig(serviceName))
108117
return &rpcClient{Interface: client, dispatcher: dispatcher}, nil
109118
}
119+
120+
// stringMapPropagator propagates the list of keys across a workflow,
121+
// interpreting the payloads as strings.
122+
// BORROWED FROM 'internal' PACKAGE TESTS.
123+
type stringMapPropagator struct {
124+
keys map[string]struct{}
125+
}
126+
127+
// NewStringMapPropagator returns a context propagator that propagates a set of
128+
// string key-value pairs across a workflow
129+
func NewStringMapPropagator(keys []string) workflow.ContextPropagator {
130+
keyMap := make(map[string]struct{}, len(keys))
131+
for _, key := range keys {
132+
keyMap[key] = struct{}{}
133+
}
134+
return &stringMapPropagator{keyMap}
135+
}
136+
137+
// Inject injects values from context into headers for propagation
138+
func (s *stringMapPropagator) Inject(ctx context.Context, writer workflow.HeaderWriter) error {
139+
for key := range s.keys {
140+
value, ok := ctx.Value(contextKey(key)).(string)
141+
if !ok {
142+
return fmt.Errorf("unable to extract key from context %v", key)
143+
}
144+
writer.Set(key, []byte(value))
145+
}
146+
return nil
147+
}
148+
149+
// InjectFromWorkflow injects values from context into headers for propagation
150+
func (s *stringMapPropagator) InjectFromWorkflow(ctx workflow.Context, writer workflow.HeaderWriter) error {
151+
for key := range s.keys {
152+
value, ok := ctx.Value(contextKey(key)).(string)
153+
if !ok {
154+
return fmt.Errorf("unable to extract key from context %v", key)
155+
}
156+
writer.Set(key, []byte(value))
157+
}
158+
return nil
159+
}
160+
161+
// Extract extracts values from headers and puts them into context
162+
func (s *stringMapPropagator) Extract(ctx context.Context, reader workflow.HeaderReader) (context.Context, error) {
163+
if err := reader.ForEachKey(func(key string, value []byte) error {
164+
if _, ok := s.keys[key]; ok {
165+
ctx = context.WithValue(ctx, contextKey(key), string(value))
166+
}
167+
return nil
168+
}); err != nil {
169+
return nil, err
170+
}
171+
return ctx, nil
172+
}
173+
174+
// ExtractToWorkflow extracts values from headers and puts them into context
175+
func (s *stringMapPropagator) ExtractToWorkflow(ctx workflow.Context, reader workflow.HeaderReader) (workflow.Context, error) {
176+
if err := reader.ForEachKey(func(key string, value []byte) error {
177+
if _, ok := s.keys[key]; ok {
178+
ctx = workflow.WithValue(ctx, contextKey(key), string(value))
179+
}
180+
return nil
181+
}); err != nil {
182+
return nil, err
183+
}
184+
return ctx, nil
185+
}

test/workflow_test.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,21 @@ func (w *Workflows) InspectLocalActivityInfo(ctx workflow.Context) error {
511511
wfType := info.WorkflowType.Name
512512
taskList := info.TaskListName
513513
ctx = workflow.WithLocalActivityOptions(ctx, w.defaultLocalActivityOptions())
514-
activites := Activities{}
514+
activities := Activities{}
515515
return workflow.ExecuteLocalActivity(
516-
ctx, activites.InspectActivityInfo, domain, taskList, wfType).Get(ctx, nil)
516+
ctx, activities.InspectActivityInfo, domain, taskList, wfType).Get(ctx, nil)
517+
}
518+
519+
func (w *Workflows) WorkflowWithLocalActivityCtxPropagation(ctx workflow.Context) (string, error) {
520+
ctx = workflow.WithLocalActivityOptions(ctx, w.defaultLocalActivityOptions())
521+
ctx = workflow.WithValue(ctx, contextKey(testContextKey), "test-data-in-context")
522+
activities := Activities{}
523+
var result string
524+
err := workflow.ExecuteLocalActivity(ctx, activities.DuplicateStringInContext).Get(ctx, &result)
525+
if err != nil {
526+
return "", err
527+
}
528+
return result, nil
517529
}
518530

519531
func (w *Workflows) register(worker worker.Worker) {
@@ -541,6 +553,7 @@ func (w *Workflows) register(worker worker.Worker) {
541553
worker.RegisterWorkflow(w.LargeQueryResultWorkflow)
542554
worker.RegisterWorkflow(w.RetryTimeoutStableErrorWorkflow)
543555
worker.RegisterWorkflow(w.ConsistentQueryWorkflow)
556+
worker.RegisterWorkflow(w.WorkflowWithLocalActivityCtxPropagation)
544557

545558
}
546559

0 commit comments

Comments
 (0)