Skip to content

Commit 530b431

Browse files
committed
Prototypes propagation with multiple values. Adds MultiTextMapCarrier, extending TextMapCarrier.
Gives example extracting requests with multiple 'baggage' headers set.
1 parent d428313 commit 530b431

File tree

3 files changed

+93
-4
lines changed

3 files changed

+93
-4
lines changed

propagation/baggage.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ func (b Baggage) Inject(ctx context.Context, carrier TextMapCarrier) {
2929

3030
// Extract returns a copy of parent with the baggage from the carrier added.
3131
func (b Baggage) Extract(parent context.Context, carrier TextMapCarrier) context.Context {
32+
multiCarrier, isMultiCarrier := carrier.(MultiTextMapCarrier)
33+
if isMultiCarrier {
34+
return extractMultiBaggage(parent, multiCarrier)
35+
}
36+
return extractSingleBaggage(parent, carrier)
37+
}
38+
39+
// Fields returns the keys who's values are set with Inject.
40+
func (b Baggage) Fields() []string {
41+
return []string{baggageHeader}
42+
}
43+
44+
func extractSingleBaggage(parent context.Context, carrier TextMapCarrier) context.Context {
3245
bStr := carrier.Get(baggageHeader)
3346
if bStr == "" {
3447
return parent
@@ -41,7 +54,20 @@ func (b Baggage) Extract(parent context.Context, carrier TextMapCarrier) context
4154
return baggage.ContextWithBaggage(parent, bag)
4255
}
4356

44-
// Fields returns the keys who's values are set with Inject.
45-
func (b Baggage) Fields() []string {
46-
return []string{baggageHeader}
57+
func extractMultiBaggage(parent context.Context, carrier MultiTextMapCarrier) context.Context {
58+
bVals := carrier.GetAll(baggageHeader)
59+
members := make([]baggage.Member, 0)
60+
for _, bStr := range bVals {
61+
currBag, err := baggage.Parse(bStr)
62+
if err != nil {
63+
continue
64+
}
65+
members = append(members, currBag.Members()...)
66+
}
67+
68+
b, err := baggage.New(members...)
69+
if err != nil || b.Len() == 0 {
70+
return parent
71+
}
72+
return baggage.ContextWithBaggage(parent, b)
4773
}

propagation/baggage_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,55 @@ func TestExtractValidBaggageFromHTTPReq(t *testing.T) {
128128
}
129129
}
130130

131+
func TestExtractValidMultipleBaggageHeaders(t *testing.T) {
132+
prop := propagation.TextMapPropagator(propagation.Baggage{})
133+
tests := []struct {
134+
name string
135+
headers []string
136+
want members
137+
}{
138+
{
139+
name: "non conflicting headers",
140+
headers: []string{"key1=val1", "key2=val2"},
141+
want: members{
142+
{Key: "key1", Value: "val1"},
143+
{Key: "key2", Value: "val2"},
144+
},
145+
},
146+
{
147+
name: "conflicting keys, uses last val",
148+
headers: []string{"key1=val1", "key1=val2"},
149+
want: members{
150+
{Key: "key1", Value: "val2"},
151+
},
152+
},
153+
{
154+
name: "single empty",
155+
headers: []string{"", "key1=val1"},
156+
want: members{
157+
{Key: "key1", Value: "val1"},
158+
},
159+
},
160+
{
161+
name: "all empty",
162+
headers: []string{"", ""},
163+
want: members{},
164+
},
165+
}
166+
167+
for _, tt := range tests {
168+
t.Run(tt.name, func(t *testing.T) {
169+
req, _ := http.NewRequest("GET", "http://example.com", nil)
170+
req.Header["Baggage"] = tt.headers
171+
172+
ctx := context.Background()
173+
ctx = prop.Extract(ctx, propagation.HeaderCarrier(req.Header))
174+
expected := tt.want.Baggage(t)
175+
assert.Equal(t, expected, baggage.FromContext(ctx))
176+
})
177+
}
178+
}
179+
131180
func TestExtractInvalidDistributedContextFromHTTPReq(t *testing.T) {
132181
prop := propagation.TextMapPropagator(propagation.Baggage{})
133182
tests := []struct {

propagation/propagation.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ type TextMapCarrier interface {
2929
// must never be done outside of a new major release.
3030
}
3131

32+
// MultiTextMapCarrier is a TextMapCarrier that can return multiple values for a single key.
33+
type MultiTextMapCarrier interface {
34+
TextMapCarrier
35+
// GetAll returns all values associated with the passed key.
36+
GetAll(key string) []string
37+
// DO NOT CHANGE: any modification will not be backwards compatible and
38+
// must never be done outside of a new major release.
39+
}
40+
3241
// MapCarrier is a TextMapCarrier that uses a map held in memory as a storage
3342
// medium for propagated key-value pairs.
3443
type MapCarrier map[string]string
@@ -58,11 +67,16 @@ func (c MapCarrier) Keys() []string {
5867
// HeaderCarrier adapts http.Header to satisfy the TextMapCarrier interface.
5968
type HeaderCarrier http.Header
6069

61-
// Get returns the value associated with the passed key.
70+
// Get returns the first value associated with the passed key.
6271
func (hc HeaderCarrier) Get(key string) string {
6372
return http.Header(hc).Get(key)
6473
}
6574

75+
// GetAll returns all values associated with the passed key.
76+
func (hc HeaderCarrier) GetAll(key string) []string {
77+
return http.Header(hc).Values(key)
78+
}
79+
6680
// Set stores the key-value pair.
6781
func (hc HeaderCarrier) Set(key string, value string) {
6882
http.Header(hc).Set(key, value)

0 commit comments

Comments
 (0)