Skip to content

Commit 6ed3f90

Browse files
authored
flags: Allow configuring arbitrary gRPC headers (#3113)
2 parents fe2bffd + bbc869e commit 6ed3f90

File tree

2 files changed

+53
-18
lines changed

2 files changed

+53
-18
lines changed

flags/flags.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,11 +334,12 @@ type FlagsRemoteStore struct {
334334
RPCLoggingEnable bool `default:"false" help:"[deprecated] Enable gRPC logging."`
335335
RPCUnaryTimeout time.Duration `default:"5m" help:"[deprecated] Maximum timeout window for unary gRPC requests including retries."`
336336

337-
GRPCMaxCallRecvMsgSize int `default:"33554432" help:"The maximum message size the client can receive."`
338-
GRPCMaxCallSendMsgSize int `default:"33554432" help:"The maximum message size the client can send."`
339-
GRPCStartupBackoffTime time.Duration `default:"1m" help:"The time between failed gRPC requests during startup phase."`
340-
GRPCConnectionTimeout time.Duration `default:"3s" help:"The timeout duration for gRPC connection establishment."`
341-
GRPCMaxConnectionRetries uint32 `default:"5" help:"The maximum number of retries to establish a gRPC connection."`
337+
GRPCMaxCallRecvMsgSize int `default:"33554432" help:"The maximum message size the client can receive."`
338+
GRPCMaxCallSendMsgSize int `default:"33554432" help:"The maximum message size the client can send."`
339+
GRPCStartupBackoffTime time.Duration `default:"1m" help:"The time between failed gRPC requests during startup phase."`
340+
GRPCConnectionTimeout time.Duration `default:"3s" help:"The timeout duration for gRPC connection establishment."`
341+
GRPCMaxConnectionRetries uint32 `default:"5" help:"The maximum number of retries to establish a gRPC connection."`
342+
GRPCHeaders map[string]string `help:"Additional gRPC headers to send with each request (key=value pairs)."`
342343
}
343344

344345
// FlagsDebuginfo contains flags to configure debuginfo.

flags/grpc.go

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"google.golang.org/grpc/credentials"
2323
"google.golang.org/grpc/credentials/insecure"
2424
"google.golang.org/grpc/encoding"
25+
"google.golang.org/grpc/metadata"
2526
)
2627

2728
// WaitGrpcEndpoint waits until the gRPC connection is established.
@@ -123,10 +124,10 @@ func (f FlagsRemoteStore) setupGrpcConnection(parent context.Context, metrics *g
123124
}
124125
propagators := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})
125126

126-
opts = append(opts,
127-
grpc.WithChainUnaryInterceptor(
128-
timeout.UnaryClientInterceptor(f.RPCUnaryTimeout), // 5m by default.
129-
retry.UnaryClientInterceptor(
127+
// Build interceptor chain
128+
unaryInterceptors := []grpc.UnaryClientInterceptor{
129+
timeout.UnaryClientInterceptor(f.RPCUnaryTimeout), // 5m by default.
130+
retry.UnaryClientInterceptor(
130131
// Back-off with Jitter: scalar: 1s, jitterFraction: 0,1, 10 runs
131132
// i: 1 t:969.91774ms total:969.91774ms
132133
// i: 2 t:1.914221005s total:2.884138745s
@@ -146,17 +147,28 @@ func (f FlagsRemoteStore) setupGrpcConnection(parent context.Context, metrics *g
146147
// `WithPerRetryTimeout` allows you to shorten the deadline of each retry call, allowing you to fit multiple retries in the single parent deadline.
147148
retry.WithPerRetryTimeout(2*time.Minute),
148149
),
149-
metrics.UnaryClientInterceptor(
150-
grpc_prometheus.WithExemplarFromContext(exemplarFromContext),
151-
),
152-
logging.UnaryClientInterceptor(interceptorLogger(), logging.WithFieldsFromContext(logTraceID)),
150+
metrics.UnaryClientInterceptor(
151+
grpc_prometheus.WithExemplarFromContext(exemplarFromContext),
153152
),
154-
grpc.WithChainStreamInterceptor(
155-
metrics.StreamClientInterceptor(
156-
grpc_prometheus.WithExemplarFromContext(exemplarFromContext),
157-
),
158-
logging.StreamClientInterceptor(interceptorLogger(), logging.WithFieldsFromContext(logTraceID)),
153+
logging.UnaryClientInterceptor(interceptorLogger(), logging.WithFieldsFromContext(logTraceID)),
154+
}
155+
156+
streamInterceptors := []grpc.StreamClientInterceptor{
157+
metrics.StreamClientInterceptor(
158+
grpc_prometheus.WithExemplarFromContext(exemplarFromContext),
159159
),
160+
logging.StreamClientInterceptor(interceptorLogger(), logging.WithFieldsFromContext(logTraceID)),
161+
}
162+
163+
// Add custom headers interceptor if headers are configured
164+
if len(f.GRPCHeaders) > 0 {
165+
unaryInterceptors = append([]grpc.UnaryClientInterceptor{customHeadersUnaryInterceptor(f.GRPCHeaders)}, unaryInterceptors...)
166+
streamInterceptors = append([]grpc.StreamClientInterceptor{customHeadersStreamInterceptor(f.GRPCHeaders)}, streamInterceptors...)
167+
}
168+
169+
opts = append(opts,
170+
grpc.WithChainUnaryInterceptor(unaryInterceptors...),
171+
grpc.WithChainStreamInterceptor(streamInterceptors...),
160172
grpc.WithStatsHandler(tracing.NewClientHandler(
161173
tracing.WithTracerProvider(tp),
162174
tracing.WithPropagators(propagators),
@@ -191,6 +203,28 @@ func (t *perRequestBearerToken) RequireTransportSecurity() bool {
191203
return !t.insecure
192204
}
193205

206+
// customHeadersUnaryInterceptor adds custom headers to all unary RPC calls.
207+
func customHeadersUnaryInterceptor(headers map[string]string) grpc.UnaryClientInterceptor {
208+
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
209+
// Add headers to outgoing context
210+
for key, value := range headers {
211+
ctx = metadata.AppendToOutgoingContext(ctx, key, value)
212+
}
213+
return invoker(ctx, method, req, reply, cc, opts...)
214+
}
215+
}
216+
217+
// customHeadersStreamInterceptor adds custom headers to all streaming RPC calls.
218+
func customHeadersStreamInterceptor(headers map[string]string) grpc.StreamClientInterceptor {
219+
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
220+
// Add headers to outgoing context
221+
for key, value := range headers {
222+
ctx = metadata.AppendToOutgoingContext(ctx, key, value)
223+
}
224+
return streamer(ctx, desc, cc, method, opts...)
225+
}
226+
}
227+
194228
// interceptorLogger adapts go-kit logger to interceptor logger.
195229
func interceptorLogger() logging.Logger {
196230
return logging.LoggerFunc(func(_ context.Context, lvl logging.Level, msg string, fields ...any) {

0 commit comments

Comments
 (0)