diff --git a/experimental/experimental.go b/experimental/experimental.go index 719692636505..87db588c9844 100644 --- a/experimental/experimental.go +++ b/experimental/experimental.go @@ -62,3 +62,13 @@ func WithBufferPool(bufferPool mem.BufferPool) grpc.DialOption { func BufferPool(bufferPool mem.BufferPool) grpc.ServerOption { return internal.BufferPool.(func(mem.BufferPool) grpc.ServerOption)(bufferPool) } + +// AcceptedCompressionNames returns a CallOption that limits the values +// advertised in the grpc-accept-encoding header for the provided RPC. The +// supplied names must correspond to compressors registered via +// encoding.RegisterCompressor. Passing no names advertises identity only. +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later release. +func AcceptedCompressionNames(names ...string) grpc.CallOption { + return internal.AcceptedCompressionNames.(func(...string) grpc.CallOption)(names...) +} diff --git a/internal/experimental.go b/internal/experimental.go index 7617be215895..3482abacdc5e 100644 --- a/internal/experimental.go +++ b/internal/experimental.go @@ -25,4 +25,8 @@ var ( // BufferPool is implemented by the grpc package and returns a server // option to configure a shared buffer pool for a grpc.Server. BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption + + // AcceptedCompressionNames is implemented by the grpc package and returns + // a call option that restricts the grpc-accept-encoding header for a call. + AcceptedCompressionNames any // func(...string) grpc.CallOption ) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 65b4ab2439e2..556b5192ea1e 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -321,12 +321,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts maxHeaderListSize = *opts.MaxHeaderListSize } + registeredCompressors := grpcutil.RegisteredCompressors() + t := &http2Client{ ctx: ctx, ctxDone: ctx.Done(), // Cache Done chan. cancel: cancel, userAgent: opts.UserAgent, - registeredCompressors: grpcutil.RegisteredCompressors(), + registeredCompressors: registeredCompressors, address: addr, conn: conn, remoteAddr: conn.RemoteAddr(), @@ -551,6 +553,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te hfLen += len(authData) + len(callAuthData) registeredCompressors := t.registeredCompressors + if callHdr.AcceptedCompressors != nil { + registeredCompressors = *callHdr.AcceptedCompressors + } if callHdr.PreviousAttempts > 0 { hfLen++ } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 5ff83a7d7d74..e1e466698e34 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -553,6 +553,12 @@ type CallHdr struct { // outbound message. SendCompress string + // AcceptedCompressors overrides the grpc-accept-encoding header for this + // call. When nil, the transport advertises the default set of registered + // compressors. A non-nil pointer overrides that value (including the empty + // string to advertise none). + AcceptedCompressors *string + // Creds specifies credentials.PerRPCCredentials for a call. Creds credentials.PerRPCCredentials diff --git a/rpc_util.go b/rpc_util.go index 6b04c9e87357..32a5c0f55bda 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -33,6 +33,8 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" @@ -41,6 +43,10 @@ import ( "google.golang.org/grpc/status" ) +func init() { + internal.AcceptedCompressionNames = acceptedCompressionNames +} + // Compressor defines the interface gRPC uses to compress a message. // // Deprecated: use package encoding. @@ -151,16 +157,33 @@ func (d *gzipDecompressor) Type() string { // callInfo contains all related configuration and information about an RPC. type callInfo struct { - compressorName string - failFast bool - maxReceiveMessageSize *int - maxSendMessageSize *int - creds credentials.PerRPCCredentials - contentSubtype string - codec baseCodec - maxRetryRPCBufferSize int - onFinish []func(err error) - authority string + compressorName string + failFast bool + maxReceiveMessageSize *int + maxSendMessageSize *int + creds credentials.PerRPCCredentials + contentSubtype string + codec baseCodec + maxRetryRPCBufferSize int + onFinish []func(err error) + authority string + acceptedResponseCompressors *acceptedCompressionConfig +} + +type acceptedCompressionConfig struct { + headerValue string + allowed map[string]struct{} +} + +func (cfg *acceptedCompressionConfig) allows(name string) bool { + if cfg == nil { + return true + } + if name == "" || name == encoding.Identity { + return true + } + _, ok := cfg.allowed[name] + return ok } func defaultCallInfo() *callInfo { @@ -170,6 +193,35 @@ func defaultCallInfo() *callInfo { } } +func newAcceptedCompressionConfig(names []string) (*acceptedCompressionConfig, error) { + cfg := &acceptedCompressionConfig{ + allowed: make(map[string]struct{}, len(names)), + } + if len(names) == 0 { + return cfg, nil + } + var ordered []string + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" || name == encoding.Identity { + continue + } + if !grpcutil.IsCompressorNameRegistered(name) { + return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name) + } + if _, dup := cfg.allowed[name]; dup { + continue + } + cfg.allowed[name] = struct{}{} + ordered = append(ordered, name) + } + if len(ordered) == 0 { + return nil, status.Error(codes.InvalidArgument, "grpc: no valid compressor names provided") + } + cfg.headerValue = strings.Join(ordered, ",") + return cfg, nil +} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -471,6 +523,26 @@ func (o CompressorCallOption) before(c *callInfo) error { } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} +func acceptedCompressionNames(names ...string) CallOption { + cp := append([]string(nil), names...) + return acceptedCompressionNamesCallOption{names: cp} +} + +type acceptedCompressionNamesCallOption struct { + names []string +} + +func (o acceptedCompressionNamesCallOption) before(c *callInfo) error { + cfg, err := newAcceptedCompressionConfig(o.names) + if err != nil { + return err + } + c.acceptedResponseCompressors = cfg + return nil +} + +func (acceptedCompressionNamesCallOption) after(*callInfo, *csAttempt) {} + // CallContentSubtype returns a CallOption that will set the content-subtype // for a call. For example, if content-subtype is "json", the Content-Type over // the wire will be "application/grpc+json". The content-subtype is converted @@ -821,7 +893,7 @@ func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time } } -func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status { +func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool, acceptedCfg *acceptedCompressionConfig) *status.Status { switch pf { case compressionNone: case compressionMade: @@ -834,6 +906,9 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool } return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) } + if !isServer && acceptedCfg != nil && !acceptedCfg.allows(recvCompress) { + return status.Newf(codes.FailedPrecondition, "grpc: peer compressed the response with %q which is not allowed by AcceptedCompressionNames", recvCompress) + } default: return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) } @@ -857,7 +932,7 @@ func (p *payloadInfo) free() { // the buffer is no longer needed. // TODO: Refactor this function to reduce the number of arguments. // See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists -func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, +func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig, ) (out mem.BufferSlice, err error) { pf, compressed, err := p.recvMsg(maxReceiveMessageSize) if err != nil { @@ -866,7 +941,7 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM compressedLength := compressed.Len() - if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil { + if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer, acceptedCfg); st != nil { compressed.Free() return nil, st.Err() } @@ -941,8 +1016,8 @@ type recvCompressor interface { // For the two compressor parameters, both should not be set, but if they are, // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? -func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error { - data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer) +func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, acceptedCfg *acceptedCompressionConfig) error { + data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer, acceptedCfg) if err != nil { return err } diff --git a/rpc_util_test.go b/rpc_util_test.go index a5c5cb8b17e2..a9da704e7303 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -48,6 +48,85 @@ const ( decompressionErrorMsg = "invalid compression format" ) +func (s) TestNewAcceptedCompressionConfig(t *testing.T) { + tests := []struct { + name string + input []string + wantHeader string + wantAllowed map[string]struct{} + wantErr bool + }{ + { + name: "identity-only", + input: nil, + wantHeader: "", + wantAllowed: map[string]struct{}{}, + }, + { + name: "single valid", + input: []string{"gzip"}, + wantHeader: "gzip", + wantAllowed: map[string]struct{}{"gzip": {}}, + }, + { + name: "dedupe and trim", + input: []string{" gzip ", "gzip"}, + wantHeader: "gzip", + wantAllowed: map[string]struct{}{"gzip": {}}, + }, + { + name: "ignores identity", + input: []string{"identity", "gzip"}, + wantHeader: "gzip", + wantAllowed: map[string]struct{}{"gzip": {}}, + }, + { + name: "invalid compressor", + input: []string{"does-not-exist"}, + wantErr: true, + }, + { + name: "only whitespace", + input: []string{" ", "\t"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := newAcceptedCompressionConfig(tt.input) + if (err != nil) != tt.wantErr { + t.Fatalf("newAcceptedCompressionConfig(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + if tt.wantErr { + return + } + if cfg.headerValue != tt.wantHeader { + t.Fatalf("headerValue = %q, want %q", cfg.headerValue, tt.wantHeader) + } + if diff := cmp.Diff(tt.wantAllowed, cfg.allowed); diff != "" { + t.Fatalf("allowed diff (-want +got): %v", diff) + } + }) + } +} + +func (s) TestCheckRecvPayloadHonorsAcceptedCompressors(t *testing.T) { + cfg, err := newAcceptedCompressionConfig([]string{"gzip"}) + if err != nil { + t.Fatalf("newAcceptedCompressionConfig returned error: %v", err) + } + + if st := checkRecvPayload(compressionMade, "gzip", true, false, cfg); st != nil { + t.Fatalf("checkRecvPayload returned error for allowed compressor: %v", st) + } + + st := checkRecvPayload(compressionMade, "snappy", true, false, cfg) + if st == nil || st.Code() != codes.FailedPrecondition { + t.Fatalf("checkRecvPayload = %v, want code %v", st, codes.FailedPrecondition) + } +} + type fullReader struct { data []byte } diff --git a/server.go b/server.go index ddd377341191..2099cff4be9d 100644 --- a/server.go +++ b/server.go @@ -1381,7 +1381,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt defer payInfo.free() } - d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true) + d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true, nil) if err != nil { if e := stream.WriteStatus(status.Convert(err)); e != nil { channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e) diff --git a/stream.go b/stream.go index ca87ff9776ef..a9feb0593d03 100644 --- a/stream.go +++ b/stream.go @@ -301,6 +301,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client DoneFunc: doneFunc, Authority: callInfo.authority, } + if cfg := callInfo.acceptedResponseCompressors; cfg != nil { + callHdr.AcceptedCompressors = &cfg.headerValue + } // Set our outgoing compression according to the UseCompressor CallOption, if // set. In that case, also find the compressor from the encoding package. @@ -1141,7 +1144,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { // Only initialize this state once per stream. a.decompressorSet = true } - if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err != nil { if err == io.EOF { if statusErr := a.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1179,7 +1182,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false, cs.callInfo.acceptedResponseCompressors); err == io.EOF { return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1486,7 +1489,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Only initialize this state once per stream. as.decompressorSet = true } - if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err != nil { if err == io.EOF { if statusErr := as.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1508,7 +1511,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false, as.callInfo.acceptedResponseCompressors); err == io.EOF { return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1785,7 +1788,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { payInfo = &payloadInfo{} defer payInfo.free() } - if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true, nil); err != nil { if err == io.EOF { if len(ss.binlogs) != 0 { chc := &binarylog.ClientHalfClose{} @@ -1829,7 +1832,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { } // Special handling for non-client-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true, nil); err == io.EOF { return nil } else if err != nil { return err diff --git a/test/compressor_test.go b/test/compressor_test.go index dbdc06222220..fb80206b164e 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/experimental" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -533,6 +534,30 @@ func (s) TestClientSupportedCompressors(t *testing.T) { } } +func (s) TestAcceptedCompressionNamesCallOption(t *testing.T) { + const want = "gzip" + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + md, _ := metadata.FromIncomingContext(ctx) + if got := md.Get("grpc-accept-encoding"); len(got) != 1 || got[0] != want { + t.Fatalf("unexpected grpc-accept-encoding header: %v", got) + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("failed to start server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, experimental.AcceptedCompressionNames(want)); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } +} + func (s) TestCompressorRegister(t *testing.T) { for _, e := range listTestEnv() { testCompressorRegister(t, e)