Skip to content

Commit 7235bb7

Browse files
authored
encoding: Add a test-only function for temporarily registering compressors (#8587)
Fixes: #7960 This PR adds a function that allows tests to register a compressor with arbitrary names and un-register them at the end of the test. This prevents the compressor names from showing up in the encoding header in subsequent tests. Previously, tests were using the name of the existing compressor "gzip" and re-registering the original compressor to workaround this problem. RELEASE NOTES: N/A
1 parent 5028ef7 commit 7235bb7

File tree

5 files changed

+288
-206
lines changed

5 files changed

+288
-206
lines changed

encoding/compressor_test.go

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
*
3+
* Copyright 2025 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
package encoding_test
20+
21+
import (
22+
"bytes"
23+
"context"
24+
"io"
25+
"sync/atomic"
26+
"testing"
27+
28+
"google.golang.org/grpc"
29+
"google.golang.org/grpc/codes"
30+
"google.golang.org/grpc/encoding"
31+
"google.golang.org/grpc/encoding/internal"
32+
"google.golang.org/grpc/internal/stubserver"
33+
"google.golang.org/grpc/status"
34+
35+
testgrpc "google.golang.org/grpc/interop/grpc_testing"
36+
testpb "google.golang.org/grpc/interop/grpc_testing"
37+
38+
_ "google.golang.org/grpc/encoding/gzip"
39+
)
40+
41+
// wrapCompressor is a wrapper of encoding.Compressor which maintains count of
42+
// Compressor method invokes.
43+
type wrapCompressor struct {
44+
encoding.Compressor
45+
compressInvokes int32
46+
}
47+
48+
func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
49+
atomic.AddInt32(&wc.compressInvokes, 1)
50+
return wc.Compressor.Compress(w)
51+
}
52+
53+
func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
54+
regFn := internal.RegisterCompressorForTesting.(func(encoding.Compressor) func())
55+
c := &wrapCompressor{Compressor: encoding.GetCompressor("gzip")}
56+
unreg := regFn(c)
57+
t.Cleanup(unreg)
58+
return c
59+
}
60+
61+
func (s) TestSetSendCompressorSuccess(t *testing.T) {
62+
for _, tt := range []struct {
63+
name string
64+
desc string
65+
payload *testpb.Payload
66+
dialOpts []grpc.DialOption
67+
resCompressor string
68+
wantCompressInvokes int32
69+
}{
70+
{
71+
name: "identity_request_and_gzip_response",
72+
desc: "request is uncompressed and response is gzip compressed",
73+
payload: &testpb.Payload{Body: []byte("payload")},
74+
resCompressor: "gzip",
75+
wantCompressInvokes: 1,
76+
},
77+
{
78+
name: "identity_request_and_empty_response",
79+
desc: "request is uncompressed and response is gzip compressed",
80+
payload: nil,
81+
resCompressor: "gzip",
82+
wantCompressInvokes: 0,
83+
},
84+
{
85+
name: "gzip_request_and_identity_response",
86+
desc: "request is gzip compressed and response is uncompressed with identity",
87+
payload: &testpb.Payload{Body: []byte("payload")},
88+
resCompressor: "identity",
89+
dialOpts: []grpc.DialOption{
90+
// Use WithCompressor instead of UseCompressor to avoid counting
91+
// the client's compressor usage.
92+
grpc.WithCompressor(grpc.NewGZIPCompressor()),
93+
},
94+
wantCompressInvokes: 0,
95+
},
96+
} {
97+
t.Run(tt.name, func(t *testing.T) {
98+
t.Run("unary", func(t *testing.T) {
99+
testUnarySetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
100+
})
101+
102+
t.Run("stream", func(t *testing.T) {
103+
testStreamSetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
104+
})
105+
})
106+
}
107+
}
108+
109+
func testUnarySetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
110+
wc := setupGzipWrapCompressor(t)
111+
ss := &stubserver.StubServer{
112+
UnaryCallF: func(ctx context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
113+
if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
114+
return nil, err
115+
}
116+
return &testpb.SimpleResponse{
117+
Payload: payload,
118+
}, nil
119+
},
120+
}
121+
if err := ss.Start(nil, dialOpts...); err != nil {
122+
t.Fatalf("Error starting endpoint server: %v", err)
123+
}
124+
defer ss.Stop()
125+
126+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
127+
defer cancel()
128+
129+
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
130+
t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
131+
}
132+
133+
compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
134+
if compressInvokes != wantCompressInvokes {
135+
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
136+
}
137+
}
138+
139+
func testStreamSetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
140+
wc := setupGzipWrapCompressor(t)
141+
ss := &stubserver.StubServer{
142+
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
143+
if _, err := stream.Recv(); err != nil {
144+
return err
145+
}
146+
147+
if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
148+
return err
149+
}
150+
151+
return stream.Send(&testpb.StreamingOutputCallResponse{
152+
Payload: payload,
153+
})
154+
},
155+
}
156+
if err := ss.Start(nil, dialOpts...); err != nil {
157+
t.Fatalf("Error starting endpoint server: %v", err)
158+
}
159+
defer ss.Stop()
160+
161+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
162+
defer cancel()
163+
164+
s, err := ss.Client.FullDuplexCall(ctx)
165+
if err != nil {
166+
t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
167+
}
168+
169+
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
170+
t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
171+
}
172+
173+
if _, err := s.Recv(); err != nil {
174+
t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
175+
}
176+
177+
compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
178+
if compressInvokes != wantCompressInvokes {
179+
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
180+
}
181+
}
182+
183+
// fakeCompressor returns a messages of a configured size, irrespective of the
184+
// input.
185+
type fakeCompressor struct {
186+
decompressedMessageSize int
187+
}
188+
189+
func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
190+
return nopWriteCloser{w}, nil
191+
}
192+
193+
func (f *fakeCompressor) Decompress(io.Reader) (io.Reader, error) {
194+
return bytes.NewReader(make([]byte, f.decompressedMessageSize)), nil
195+
}
196+
197+
func (f *fakeCompressor) Name() string {
198+
// Use the name of an existing compressor to avoid interactions with other
199+
// tests since compressors can't be un-registered.
200+
return "fake"
201+
}
202+
203+
type nopWriteCloser struct {
204+
io.Writer
205+
}
206+
207+
func (nopWriteCloser) Close() error {
208+
return nil
209+
}
210+
211+
// TestDecompressionExceedsMaxMessageSize uses a fake compressor that produces
212+
// messages of size 100 bytes on decompression. A server is started with the
213+
// max receive message size restricted to 99 bytes. The test verifies that the
214+
// client receives a ResourceExhausted response from the server.
215+
func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) {
216+
const messageLen = 100
217+
regFn := internal.RegisterCompressorForTesting.(func(encoding.Compressor) func())
218+
compressor := &fakeCompressor{decompressedMessageSize: messageLen}
219+
unreg := regFn(compressor)
220+
defer unreg()
221+
ss := &stubserver.StubServer{
222+
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
223+
return &testpb.SimpleResponse{}, nil
224+
},
225+
}
226+
if err := ss.Start([]grpc.ServerOption{grpc.MaxRecvMsgSize(messageLen - 1)}); err != nil {
227+
t.Fatalf("Error starting endpoint server: %v", err)
228+
}
229+
defer ss.Stop()
230+
231+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
232+
defer cancel()
233+
234+
req := &testpb.SimpleRequest{Payload: &testpb.Payload{}}
235+
_, err := ss.Client.UnaryCall(ctx, req, grpc.UseCompressor(compressor.Name()))
236+
if got, want := status.Code(err), codes.ResourceExhausted; got != want {
237+
t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want)
238+
}
239+
}

encoding/encoding.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,35 @@ package encoding
2727

2828
import (
2929
"io"
30+
"slices"
3031
"strings"
3132

33+
"google.golang.org/grpc/encoding/internal"
3234
"google.golang.org/grpc/internal/grpcutil"
3335
)
3436

3537
// Identity specifies the optional encoding for uncompressed streams.
3638
// It is intended for grpc internal use only.
3739
const Identity = "identity"
3840

41+
func init() {
42+
internal.RegisterCompressorForTesting = func(c Compressor) func() {
43+
name := c.Name()
44+
curCompressor, found := registeredCompressor[name]
45+
RegisterCompressor(c)
46+
return func() {
47+
if found {
48+
registeredCompressor[name] = curCompressor
49+
return
50+
}
51+
delete(registeredCompressor, name)
52+
grpcutil.RegisteredCompressorNames = slices.DeleteFunc(grpcutil.RegisteredCompressorNames, func(s string) bool {
53+
return s == name
54+
})
55+
}
56+
}
57+
}
58+
3959
// Compressor is used for compressing and decompressing when sending or
4060
// receiving messages.
4161
//

encoding/encoding_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (s) TestDuplicateCompressorRegister(t *testing.T) {
7777
t.Fatalf("Unexpected compressor, got: %+v, want:%+v", got, mc)
7878
}
7979

80-
wantNames := []string{"mock-compressor"}
80+
wantNames := []string{"gzip", "mock-compressor"}
8181
if !cmp.Equal(wantNames, grpcutil.RegisteredCompressorNames) {
8282
t.Fatalf("Unexpected compressor names, got: %+v, want:%+v", grpcutil.RegisteredCompressorNames, wantNames)
8383
}

encoding/internal/internal.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
*
3+
* Copyright 2025 gRPC authors.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
*/
18+
19+
// Package internal contains code internal to the encoding package.
20+
package internal
21+
22+
// RegisterCompressorForTesting registers a compressor in the global compressor
23+
// registry. It returns a cleanup function that should be called at the end
24+
// of the test to unregister the compressor.
25+
//
26+
// This prevents compressors registered in one test from appearing in the
27+
// encoding headers of subsequent tests.
28+
var RegisterCompressorForTesting any // func RegisterCompressor(c Compressor) func()

0 commit comments

Comments
 (0)