@@ -20,7 +20,8 @@ import (
2020 "bytes"
2121 "errors"
2222 "io"
23- "io/ioutil"
23+ "runtime"
24+ "sync"
2425
2526 "github.com/klauspost/compress/zstd"
2627 "google.golang.org/grpc/encoding"
@@ -34,9 +35,22 @@ var encoderOptions = []zstd.EOption{
3435 zstd .WithWindowSize (512 * 1024 ),
3536}
3637
38+ var decoderOptions = []zstd.DOption {
39+ // If the decoder concurrency level is not 1, we would need to call
40+ // Close() to avoid leaking resources when the object is released
41+ // from compressor.decoderPool.
42+ zstd .WithDecoderConcurrency (1 ),
43+ }
44+
45+ // We will set a finalizer on these objects, so when the go-grpc code is
46+ // finished with them, they will be added back to compressor.decoderPool.
47+ type decoderWrapper struct {
48+ * zstd.Decoder
49+ }
50+
3751type compressor struct {
38- encoder * zstd.Encoder
39- decoder * zstd.Decoder
52+ encoder * zstd.Encoder
53+ decoderPool sync. Pool // To hold *zstd.Decoder's.
4054}
4155
4256func PretendInit (clobbering bool ) {
@@ -45,10 +59,8 @@ func PretendInit(clobbering bool) {
4559 }
4660
4761 enc , _ := zstd .NewWriter (nil , encoderOptions ... )
48- dec , _ := zstd .NewReader (nil )
4962 c := & compressor {
5063 encoder : enc ,
51- decoder : dec ,
5264 }
5365 encoding .RegisterCompressor (c )
5466}
@@ -97,17 +109,36 @@ func (z *zstdWriteCloser) Close() error {
97109}
98110
99111func (c * compressor ) Decompress (r io.Reader ) (io.Reader , error ) {
100- compressed , err := ioutil .ReadAll (r )
101- if err != nil {
102- return nil , err
112+ var err error
113+ var found bool
114+ var decoder * zstd.Decoder
115+
116+ // Note: avoid the use of zstd.Decoder.DecodeAll here, since
117+ // malicious payloads could DoS us with a decompression bomb.
118+
119+ decoder , found = c .decoderPool .Get ().(* zstd.Decoder )
120+ if ! found {
121+ decoder , err = zstd .NewReader (r , decoderOptions ... )
122+ if err != nil {
123+ return nil , err
124+ }
125+ } else {
126+ err = decoder .Reset (r )
127+ if err != nil {
128+ c .decoderPool .Put (decoder )
129+ return nil , err
130+ }
103131 }
104132
105- uncompressed , err := c .decoder .DecodeAll (compressed , nil )
106- if err != nil {
107- return nil , err
108- }
133+ wrapper := & decoderWrapper {Decoder : decoder }
134+ runtime .SetFinalizer (wrapper , func (dw * decoderWrapper ) {
135+ err := dw .Reset (nil )
136+ if err == nil {
137+ c .decoderPool .Put (dw .Decoder )
138+ }
139+ })
109140
110- return bytes . NewReader ( uncompressed ) , nil
141+ return wrapper , nil
111142}
112143
113144func (c * compressor ) Name () string {
0 commit comments