Skip to content

Commit ea545dc

Browse files
authored
interceptors: Update logging interceptor Reporter to re-extract fields from context before logging (#702)
When using logging.WithFieldsFromContext, if the value being extracted as a log field is modified after the logging interceptor initializes the Reporter before the underlying handler is called, then the updated value will not be reflected in the log message. To fix this, re-extract fields from the context before logging them in PostCall, PostMsgSend and PostMsgReceive, ensuring the updated values in the context are logged. Signed-off-by: Chance Zibolski <[email protected]>
1 parent 3834477 commit ea545dc

File tree

4 files changed

+67
-5
lines changed

4 files changed

+67
-5
lines changed

interceptors/logging/interceptors.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ func (c *reporter) PostCall(err error, duration time.Duration) {
4141
if err != nil {
4242
fields = fields.AppendUnique(Fields{"grpc.error", fmt.Sprintf("%v", err)})
4343
}
44+
if c.opts.fieldsFromCtxCallMetaFn != nil {
45+
// fieldsFromCtxFn dups override the existing fields.
46+
fields = c.opts.fieldsFromCtxCallMetaFn(c.ctx, c.CallMeta).AppendUnique(fields)
47+
}
4448
c.logger.Log(c.ctx, c.opts.levelFunc(code), "finished call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...)
4549
}
4650

@@ -50,6 +54,10 @@ func (c *reporter) PostMsgSend(payload any, err error, duration time.Duration) {
5054
if err != nil {
5155
fields = fields.AppendUnique(Fields{"grpc.error", fmt.Sprintf("%v", err)})
5256
}
57+
if c.opts.fieldsFromCtxCallMetaFn != nil {
58+
// fieldsFromCtxFn dups override the existing fields.
59+
fields = c.opts.fieldsFromCtxCallMetaFn(c.ctx, c.CallMeta).AppendUnique(fields)
60+
}
5361
if !c.startCallLogged && has(c.opts.loggableEvents, StartCall) {
5462
c.startCallLogged = true
5563
c.logger.Log(c.ctx, logLvl, "started call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...)
@@ -97,6 +105,10 @@ func (c *reporter) PostMsgReceive(payload any, err error, duration time.Duration
97105
if err != nil {
98106
fields = fields.AppendUnique(Fields{"grpc.error", fmt.Sprintf("%v", err)})
99107
}
108+
if c.opts.fieldsFromCtxCallMetaFn != nil {
109+
// fieldsFromCtxFn dups override the existing fields.
110+
fields = c.opts.fieldsFromCtxCallMetaFn(c.ctx, c.CallMeta).AppendUnique(fields)
111+
}
100112
if !c.startCallLogged && has(c.opts.loggableEvents, StartCall) {
101113
c.startCallLogged = true
102114
c.logger.Log(c.ctx, logLvl, "started call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...)

interceptors/logging/interceptors_test.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"io"
1212
"runtime"
1313
"sort"
14+
"strconv"
1415
"strings"
1516
"sync"
1617
"testing"
@@ -172,9 +173,14 @@ type loggingClientServerSuite struct {
172173
*baseLoggingSuite
173174
}
174175

175-
func customFields(_ context.Context) logging.Fields {
176+
func customFields(ctx context.Context) logging.Fields {
177+
var val string
178+
n := testpb.ExtractCtxTestNumber(ctx)
179+
if n != nil {
180+
val = strconv.Itoa(*n)
181+
}
176182
// Add custom fields. The second one overrides the first one.
177-
return logging.Fields{"custom-field", "foo", "custom-field", "yolo"}
183+
return logging.Fields{"custom-field", "foo", "custom-field", "yolo", "custom-ctx-field", val}
178184
}
179185

180186
func TestSuite(t *testing.T) {
@@ -232,13 +238,17 @@ func (s *loggingClientServerSuite) TestPing() {
232238
assert.Equal(s.T(), logging.LevelDebug, serverStartCallLogLine.lvl)
233239
assert.Equal(s.T(), "started call", serverStartCallLogLine.msg)
234240
_ = assertStandardFields(s.T(), logging.KindServerFieldValue, serverStartCallLogLine.fields, "Ping", interceptors.Unary)
241+
// This field is zero initially, but will be updated by the service, which we should see after the call is finished
242+
serverStartCallLogLine.fields.AssertField(s.T(), "custom-ctx-field", "0")
235243

236244
serverFinishCallLogLine := lines[2]
237245
assert.Equal(s.T(), logging.LevelDebug, serverFinishCallLogLine.lvl)
238246
assert.Equal(s.T(), "finished call", serverFinishCallLogLine.msg)
239247
serverFinishCallFields := assertStandardFields(s.T(), logging.KindServerFieldValue, serverFinishCallLogLine.fields, "Ping", interceptors.Unary)
240248
serverFinishCallFields.AssertFieldNotEmpty(s.T(), "peer.address").
241249
AssertField(s.T(), "custom-field", "yolo").
250+
// should be updated from 0 to 42
251+
AssertField(s.T(), "custom-ctx-field", "42").
242252
AssertFieldNotEmpty(s.T(), "grpc.start_time").
243253
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
244254
AssertField(s.T(), "grpc.code", "OK").
@@ -249,6 +259,8 @@ func (s *loggingClientServerSuite) TestPing() {
249259
assert.Equal(s.T(), "finished call", clientFinishCallLogLine.msg)
250260
clientFinishCallFields := assertStandardFields(s.T(), logging.KindClientFieldValue, clientFinishCallLogLine.fields, "Ping", interceptors.Unary)
251261
clientFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
262+
// should be updated from 0 to 42
263+
AssertField(s.T(), "custom-ctx-field", "42").
252264
AssertField(s.T(), "grpc.request.value", "something").
253265
AssertFieldNotEmpty(s.T(), "grpc.start_time").
254266
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
@@ -285,6 +297,8 @@ func (s *loggingClientServerSuite) TestPingList() {
285297
assert.Equal(s.T(), "finished call", serverFinishCallLogLine.msg)
286298
serverFinishCallFields := assertStandardFields(s.T(), logging.KindServerFieldValue, serverFinishCallLogLine.fields, "PingList", interceptors.ServerStream)
287299
serverFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
300+
// should be updated from 0 to 42
301+
AssertField(s.T(), "custom-ctx-field", "42").
288302
AssertFieldNotEmpty(s.T(), "peer.address").
289303
AssertFieldNotEmpty(s.T(), "grpc.start_time").
290304
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
@@ -297,6 +311,8 @@ func (s *loggingClientServerSuite) TestPingList() {
297311
clientFinishCallFields := assertStandardFields(s.T(), logging.KindClientFieldValue, clientFinishCallLogLine.fields, "PingList", interceptors.ServerStream)
298312
clientFinishCallFields.AssertFieldNotEmpty(s.T(), "grpc.start_time").
299313
AssertField(s.T(), "custom-field", "yolo").
314+
// should be updated from 0 to 42
315+
AssertField(s.T(), "custom-ctx-field", "42").
300316
AssertFieldNotEmpty(s.T(), "grpc.request.deadline").
301317
AssertField(s.T(), "grpc.code", "OK").
302318
AssertFieldNotEmpty(s.T(), "grpc.time_ms").AssertNoMoreTags(s.T())
@@ -344,23 +360,27 @@ func (s *loggingClientServerSuite) TestPingError_WithCustomLevels() {
344360
assert.Equal(t, "finished call", serverFinishCallLogLine.msg)
345361
serverFinishCallFields := assertStandardFields(t, logging.KindServerFieldValue, serverFinishCallLogLine.fields, "PingError", interceptors.Unary)
346362
serverFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
363+
// should be updated from 0 to 42
364+
AssertField(s.T(), "custom-ctx-field", "42").
347365
AssertFieldNotEmpty(t, "peer.address").
348366
AssertFieldNotEmpty(t, "grpc.start_time").
349367
AssertFieldNotEmpty(t, "grpc.request.deadline").
350368
AssertField(t, "grpc.code", tcase.code.String()).
351369
AssertField(t, "grpc.error", fmt.Sprintf("rpc error: code = %s desc = Userspace error", tcase.code.String())).
352-
AssertFieldNotEmpty(t, "grpc.time_ms").AssertNoMoreTags(t)
370+
AssertFieldNotEmpty(s.T(), "grpc.time_ms").AssertNoMoreTags(s.T())
353371

354372
clientFinishCallLogLine := lines[0]
355373
assert.Equal(t, tcase.level, clientFinishCallLogLine.lvl)
356374
assert.Equal(t, "finished call", clientFinishCallLogLine.msg)
357375
clientFinishCallFields := assertStandardFields(t, logging.KindClientFieldValue, clientFinishCallLogLine.fields, "PingError", interceptors.Unary)
358376
clientFinishCallFields.AssertField(s.T(), "custom-field", "yolo").
377+
// should be updated from 0 to 42
378+
AssertField(s.T(), "custom-ctx-field", "42").
359379
AssertFieldNotEmpty(t, "grpc.start_time").
360380
AssertFieldNotEmpty(t, "grpc.request.deadline").
361381
AssertField(t, "grpc.code", tcase.code.String()).
362382
AssertField(t, "grpc.error", fmt.Sprintf("rpc error: code = %s desc = Userspace error", tcase.code.String())).
363-
AssertFieldNotEmpty(t, "grpc.time_ms").AssertNoMoreTags(t)
383+
AssertFieldNotEmpty(s.T(), "grpc.time_ms").AssertNoMoreTags(s.T())
364384
})
365385
}
366386
}

testing/testpb/interceptor_suite.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,33 @@ func (s *InterceptorTestSuite) ServerAddr() string {
136136
return s.serverAddr
137137
}
138138

139+
type ctxTestNumber struct{}
140+
141+
var (
142+
ctxTestNumberKey = &ctxTestNumber{}
143+
zero = 0
144+
)
145+
146+
func ExtractCtxTestNumber(ctx context.Context) *int {
147+
if v, ok := ctx.Value(ctxTestNumberKey).(*int); ok {
148+
return v
149+
}
150+
return &zero
151+
}
152+
153+
// UnaryServerInterceptor returns a new unary server interceptors that adds query information logging.
154+
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
155+
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
156+
// newCtx := newContext(ctx, log, opts)
157+
newCtx := ctx
158+
resp, err := handler(newCtx, req)
159+
return resp, err
160+
}
161+
}
162+
139163
func (s *InterceptorTestSuite) SimpleCtx() context.Context {
140164
ctx, cancel := context.WithTimeout(context.TODO(), 2*time.Second)
165+
ctx = context.WithValue(ctx, ctxTestNumberKey, 1)
141166
s.cancels = append(s.cancels, cancel)
142167
return ctx
143168
}

testing/testpb/pingservice.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ func (s *TestPingService) PingEmpty(_ context.Context, _ *PingEmptyRequest) (*Pi
3333
return &PingEmptyResponse{}, nil
3434
}
3535

36-
func (s *TestPingService) Ping(_ context.Context, ping *PingRequest) (*PingResponse, error) {
36+
func (s *TestPingService) Ping(ctx context.Context, ping *PingRequest) (*PingResponse, error) {
37+
// Modify the ctx value to verify the logger sees the value updated from the initial value
38+
n := ExtractCtxTestNumber(ctx)
39+
if n != nil {
40+
*n = 42
41+
}
3742
// Send user trailers and headers.
3843
return &PingResponse{Value: ping.Value, Counter: 0}, nil
3944
}

0 commit comments

Comments
 (0)