Skip to content

Commit 83835dc

Browse files
XSAMyashrsharma44
andauthored
Refactor tracing interceptor (#450)
* Add tracing interceptor * Fix CI * Follow dependencies changes * Fix comments * Avoid redundant else and return early * Update interceptors/tracing/interceptors_test.go Co-authored-by: Yash Sharma <[email protected]> * Remove kv package * Remove keyvalue Co-authored-by: Yash Sharma <[email protected]>
1 parent 72478fa commit 83835dc

File tree

4 files changed

+499
-0
lines changed

4 files changed

+499
-0
lines changed

interceptors/tracing/interceptors.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
package tracing
5+
6+
import (
7+
"context"
8+
9+
"google.golang.org/grpc"
10+
11+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
12+
)
13+
14+
type SpanKind string
15+
16+
const (
17+
SpanKindServer SpanKind = "server"
18+
SpanKindClient SpanKind = "client"
19+
)
20+
21+
func reportable(tracer Tracer) interceptors.CommonReportableFunc {
22+
return func(ctx context.Context, c interceptors.CallMeta, isClient bool) (interceptors.Reporter, context.Context) {
23+
kind := SpanKindServer
24+
if isClient {
25+
kind = SpanKindClient
26+
}
27+
28+
newCtx, span := tracer.Start(ctx, c.FullMethod(), kind)
29+
return &reporter{ctx: newCtx, span: span}, newCtx
30+
}
31+
}
32+
33+
// UnaryClientInterceptor returns a new unary client interceptor that optionally traces the execution of external gRPC calls.
34+
// Tracer will use tags (from tags package) available in current context as fields.
35+
func UnaryClientInterceptor(tracer Tracer) grpc.UnaryClientInterceptor {
36+
return interceptors.UnaryClientInterceptor(reportable(tracer))
37+
}
38+
39+
// StreamClientInterceptor returns a new streaming client interceptor that optionally traces the execution of external gRPC calls.
40+
// Tracer will use tags (from tags package) available in current context as fields.
41+
func StreamClientInterceptor(tracer Tracer) grpc.StreamClientInterceptor {
42+
return interceptors.StreamClientInterceptor(reportable(tracer))
43+
}
44+
45+
// UnaryServerInterceptor returns a new unary server interceptors that optionally traces endpoint handling.
46+
// Tracer will use tags (from tags package) available in current context as fields.
47+
func UnaryServerInterceptor(tracer Tracer) grpc.UnaryServerInterceptor {
48+
return interceptors.UnaryServerInterceptor(reportable(tracer))
49+
}
50+
51+
// StreamServerInterceptor returns a new stream server interceptors that optionally traces endpoint handling.
52+
// Tracer will use tags (from tags package) available in current context as fields.
53+
func StreamServerInterceptor(tracer Tracer) grpc.StreamServerInterceptor {
54+
return interceptors.StreamServerInterceptor(reportable(tracer))
55+
}
Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
// Copyright (c) The go-grpc-middleware Authors.
2+
// Licensed under the Apache License 2.0.
3+
4+
package tracing_test
5+
6+
import (
7+
"context"
8+
"io"
9+
"strconv"
10+
"sync/atomic"
11+
"testing"
12+
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
"github.com/stretchr/testify/suite"
16+
"google.golang.org/grpc"
17+
"google.golang.org/grpc/codes"
18+
"google.golang.org/grpc/metadata"
19+
20+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/tracing"
21+
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
22+
)
23+
24+
var (
25+
id int64 = 0
26+
traceIDHeaderKey = "traceid"
27+
spanIDHeaderKey = "spanid"
28+
)
29+
30+
func extractFromContext(ctx context.Context, kind tracing.SpanKind) *mockSpan {
31+
var m metadata.MD
32+
if kind == tracing.SpanKindClient {
33+
m, _ = metadata.FromOutgoingContext(ctx)
34+
} else {
35+
m, _ = metadata.FromIncomingContext(ctx)
36+
}
37+
38+
traceIDValues := m.Get(traceIDHeaderKey)
39+
if len(traceIDValues) == 0 {
40+
return nil
41+
}
42+
spanIDValues := m.Get(spanIDHeaderKey)
43+
if len(spanIDValues) == 0 {
44+
return nil
45+
}
46+
47+
return &mockSpan{
48+
traceID: traceIDValues[0],
49+
spanID: spanIDValues[0],
50+
}
51+
}
52+
53+
func injectWithContext(ctx context.Context, span *mockSpan, kind tracing.SpanKind) context.Context {
54+
var m metadata.MD
55+
if kind == tracing.SpanKindClient {
56+
m, _ = metadata.FromOutgoingContext(ctx)
57+
} else {
58+
m, _ = metadata.FromIncomingContext(ctx)
59+
}
60+
m = m.Copy()
61+
62+
m.Set(traceIDHeaderKey, span.traceID)
63+
m.Set(spanIDHeaderKey, span.spanID)
64+
65+
ctx = metadata.NewOutgoingContext(ctx, m)
66+
return ctx
67+
}
68+
69+
func genID() string {
70+
return strconv.FormatInt(atomic.AddInt64(&id, 1), 10)
71+
}
72+
73+
// Implements Tracker
74+
type mockTracer struct {
75+
spanStore map[string]*mockSpan
76+
}
77+
78+
func (t *mockTracer) ListSpan(kind tracing.SpanKind) []*mockSpan {
79+
var spans []*mockSpan
80+
for _, v := range t.spanStore {
81+
if v.kind == kind {
82+
spans = append(spans, v)
83+
}
84+
}
85+
return spans
86+
}
87+
88+
func (t *mockTracer) Reset() {
89+
t.spanStore = make(map[string]*mockSpan)
90+
}
91+
92+
func newMockTracer() *mockTracer {
93+
return &mockTracer{
94+
spanStore: make(map[string]*mockSpan),
95+
}
96+
}
97+
98+
func (t *mockTracer) Start(ctx context.Context, spanName string, kind tracing.SpanKind) (context.Context, tracing.Span) {
99+
span := mockSpan{
100+
spanID: genID(),
101+
name: spanName,
102+
kind: kind,
103+
statusCode: codes.OK,
104+
}
105+
106+
parentSpan := extractFromContext(ctx, kind)
107+
if parentSpan != nil {
108+
// Fetch span from context as parent span
109+
span.traceID = parentSpan.traceID
110+
span.parentSpanID = parentSpan.spanID
111+
} else {
112+
span.traceID = genID()
113+
}
114+
115+
t.spanStore[span.spanID] = &span
116+
if kind == tracing.SpanKindClient {
117+
ctx = injectWithContext(ctx, &span, kind)
118+
}
119+
return ctx, &span
120+
}
121+
122+
// Implements Span
123+
type mockSpan struct {
124+
traceID string
125+
spanID string
126+
parentSpanID string
127+
128+
name string
129+
kind tracing.SpanKind
130+
end bool
131+
132+
statusCode codes.Code
133+
statusMessage string
134+
135+
msgSendCounter int
136+
msgReceivedCounter int
137+
eventNameList []string
138+
attributesList [][]interface{}
139+
}
140+
141+
func (s *mockSpan) SetAttributes(keyvals ...interface{}) {
142+
s.attributesList = append(s.attributesList, keyvals)
143+
}
144+
145+
func (s *mockSpan) End() {
146+
s.end = true
147+
}
148+
149+
func (s *mockSpan) SetStatus(code codes.Code, message string) {
150+
s.statusCode = code
151+
s.statusMessage = message
152+
}
153+
154+
func (s *mockSpan) AddEvent(name string, keyvals ...interface{}) {
155+
s.eventNameList = append(s.eventNameList, name)
156+
157+
if len(keyvals)%2 == 1 {
158+
keyvals = append(keyvals, nil)
159+
}
160+
161+
for i := 0; i < len(keyvals); i += 2 {
162+
k, keyOK := keyvals[i].(string)
163+
v, valueOK := keyvals[i+1].(string)
164+
165+
if keyOK && valueOK && k == "message.type" {
166+
switch v {
167+
case tracing.RPCMessageTypeSent:
168+
s.msgSendCounter++
169+
case tracing.RPCMessageTypeReceived:
170+
s.msgReceivedCounter++
171+
}
172+
}
173+
}
174+
}
175+
176+
type tracingSuite struct {
177+
*testpb.InterceptorTestSuite
178+
tracer *mockTracer
179+
}
180+
181+
func (s *tracingSuite) BeforeTest(suiteName, testName string) {
182+
s.tracer.Reset()
183+
}
184+
185+
func (s *tracingSuite) TestPing() {
186+
method := "/testing.testpb.v1.TestService/Ping"
187+
errorMethod := "/testing.testpb.v1.TestService/PingError"
188+
t := s.T()
189+
190+
testCases := []struct {
191+
name string
192+
error bool
193+
errorMessage string
194+
}{
195+
{
196+
name: "OK",
197+
error: false,
198+
},
199+
{
200+
name: "invalid argument error",
201+
error: true,
202+
errorMessage: "Userspace error.",
203+
},
204+
}
205+
206+
for _, tc := range testCases {
207+
t.Run(tc.name, func(t *testing.T) {
208+
s.tracer.Reset()
209+
210+
var err error
211+
if tc.error {
212+
req := &testpb.PingErrorRequest{ErrorCodeReturned: uint32(codes.InvalidArgument)}
213+
_, err = s.Client.PingError(s.SimpleCtx(), req)
214+
} else {
215+
req := &testpb.PingRequest{Value: "something"}
216+
_, err = s.Client.Ping(s.SimpleCtx(), req)
217+
}
218+
if tc.error {
219+
require.Error(t, err)
220+
} else {
221+
require.NoError(t, err)
222+
}
223+
224+
clientSpans := s.tracer.ListSpan(tracing.SpanKindClient)
225+
serverSpans := s.tracer.ListSpan(tracing.SpanKindServer)
226+
require.Len(t, clientSpans, 1)
227+
require.Len(t, serverSpans, 1)
228+
229+
clientSpan := clientSpans[0]
230+
assert.True(t, clientSpan.end)
231+
assert.Equal(t, 1, clientSpan.msgSendCounter)
232+
assert.Equal(t, 1, clientSpan.msgReceivedCounter)
233+
assert.Equal(t, []string{"message", "message"}, clientSpan.eventNameList)
234+
235+
serverSpan := serverSpans[0]
236+
assert.True(t, serverSpan.end)
237+
assert.Equal(t, 1, serverSpan.msgSendCounter)
238+
assert.Equal(t, 1, serverSpan.msgReceivedCounter)
239+
assert.Equal(t, []string{"message", "message"}, serverSpan.eventNameList)
240+
241+
assert.Equal(t, clientSpan.traceID, serverSpan.traceID)
242+
assert.Equal(t, clientSpan.spanID, serverSpan.parentSpanID)
243+
244+
if tc.error {
245+
assert.Equal(t, codes.InvalidArgument, clientSpan.statusCode)
246+
assert.Equal(t, tc.errorMessage, clientSpan.statusMessage)
247+
assert.Equal(t, errorMethod, clientSpan.name)
248+
assert.Equal(t, [][]interface{}{{[]interface{}{"rpc.grpc.status_code", int64(3)}}}, clientSpan.attributesList)
249+
250+
assert.Equal(t, errorMethod, serverSpan.name)
251+
assert.Equal(t, [][]interface{}{{[]interface{}{"rpc.grpc.status_code", int64(3)}}}, serverSpan.attributesList)
252+
} else {
253+
assert.Equal(t, codes.OK, clientSpan.statusCode)
254+
assert.Equal(t, method, clientSpan.name)
255+
assert.Equal(t, [][]interface{}{{[]interface{}{"rpc.grpc.status_code", int64(0)}}}, clientSpan.attributesList)
256+
257+
assert.Equal(t, method, serverSpan.name)
258+
assert.Equal(t, [][]interface{}{{[]interface{}{"rpc.grpc.status_code", int64(0)}}}, serverSpan.attributesList)
259+
}
260+
})
261+
}
262+
}
263+
264+
func (s *tracingSuite) TestPingList() {
265+
t := s.T()
266+
method := "/testing.testpb.v1.TestService/PingList"
267+
268+
stream, err := s.Client.PingList(s.SimpleCtx(), &testpb.PingListRequest{Value: "something"})
269+
require.NoError(t, err)
270+
271+
for {
272+
_, err := stream.Recv()
273+
if err == io.EOF {
274+
break
275+
}
276+
require.NoError(t, err)
277+
}
278+
279+
clientSpans := s.tracer.ListSpan(tracing.SpanKindClient)
280+
serverSpans := s.tracer.ListSpan(tracing.SpanKindServer)
281+
require.Len(t, clientSpans, 1)
282+
require.Len(t, serverSpans, 1)
283+
284+
clientSpan := clientSpans[0]
285+
assert.True(t, clientSpan.end)
286+
assert.Equal(t, 1, clientSpan.msgSendCounter)
287+
assert.Equal(t, testpb.ListResponseCount+1, clientSpan.msgReceivedCounter)
288+
assert.Equal(t, codes.OK, clientSpan.statusCode)
289+
assert.Equal(t, method, clientSpan.name)
290+
291+
serverSpan := serverSpans[0]
292+
assert.True(t, serverSpan.end)
293+
assert.Equal(t, testpb.ListResponseCount, serverSpan.msgSendCounter)
294+
assert.Equal(t, 1, serverSpan.msgReceivedCounter)
295+
assert.Equal(t, codes.OK, serverSpan.statusCode)
296+
assert.Equal(t, method, serverSpan.name)
297+
}
298+
299+
func TestSuite(t *testing.T) {
300+
tracer := newMockTracer()
301+
302+
s := tracingSuite{
303+
InterceptorTestSuite: &testpb.InterceptorTestSuite{
304+
TestService: &testpb.TestPingService{T: t},
305+
},
306+
tracer: tracer,
307+
}
308+
s.InterceptorTestSuite.ClientOpts = []grpc.DialOption{
309+
grpc.WithUnaryInterceptor(tracing.UnaryClientInterceptor(tracer)),
310+
grpc.WithStreamInterceptor(tracing.StreamClientInterceptor(tracer)),
311+
}
312+
s.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{
313+
grpc.ChainUnaryInterceptor(
314+
tracing.UnaryServerInterceptor(tracer),
315+
),
316+
grpc.ChainStreamInterceptor(
317+
tracing.StreamServerInterceptor(tracer),
318+
),
319+
}
320+
321+
suite.Run(t, &s)
322+
}

0 commit comments

Comments
 (0)