Skip to content

Commit 7f8b1d0

Browse files
charlievieth authored and matthewdale committed
GODRIVER-2914 x/mongo/driver: enable parallel zlib compression and improve zstd decompression (#1320)
Co-authored-by: Matt Dale <[email protected]>
1 parent e7c7154 commit 7f8b1d0

File tree

3 files changed

+349
-53
lines changed

3 files changed

+349
-53
lines changed

x/mongo/driver/compression.go

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,48 +26,72 @@ type CompressionOpts struct {
2626
UncompressedSize int32
2727
}
2828

29-
var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
29+
// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil
30+
// destination writer. It panics on any errors and should only be used at
31+
// package initialization time.
32+
func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
33+
enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
34+
if err != nil {
35+
panic(err)
36+
}
37+
return enc
38+
}
39+
40+
var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
41+
0: nil, // zstd.speedNotSet
42+
zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest),
43+
zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault),
44+
zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
45+
zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression),
46+
}
3047

3148
func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
32-
if v, ok := zstdEncoders.Load(level); ok {
33-
return v.(*zstd.Encoder), nil
34-
}
35-
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
36-
if err != nil {
37-
return nil, err
49+
if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
50+
return zstdEncoders[level], nil
3851
}
39-
zstdEncoders.Store(level, encoder)
40-
return encoder, nil
52+
// The level is outside the expected range, return an error.
53+
return nil, fmt.Errorf("invalid zstd compression level: %d", level)
4154
}
4255

43-
var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
56+
// zlibEncodersOffset is added to a zlib compression level to obtain its index
// in zlibEncoders, so that the lowest level, zlib.HuffmanOnly (-2), maps to
// index 0.
const zlibEncodersOffset = -zlib.HuffmanOnly

// zlibEncoders keeps one sync.Pool per zlib compression level, covering
// zlib.HuffmanOnly through zlib.BestCompression inclusive.
var zlibEncoders [zlib.BestCompression - zlib.HuffmanOnly + 1]sync.Pool
4461

4562
func getZlibEncoder(level int) (*zlibEncoder, error) {
46-
if v, ok := zlibEncoders.Load(level); ok {
47-
return v.(*zlibEncoder), nil
48-
}
49-
writer, err := zlib.NewWriterLevel(nil, level)
50-
if err != nil {
51-
return nil, err
63+
if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
64+
if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
65+
return enc, nil
66+
}
67+
writer, err := zlib.NewWriterLevel(nil, level)
68+
if err != nil {
69+
return nil, err
70+
}
71+
enc := &zlibEncoder{writer: writer, level: level}
72+
return enc, nil
5273
}
53-
encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
54-
zlibEncoders.Store(level, encoder)
74+
// The level is outside the expected range, return an error.
75+
return nil, fmt.Errorf("invalid zlib compression level: %d", level)
76+
}
5577

56-
return encoder, nil
78+
func putZlibEncoder(enc *zlibEncoder) {
79+
if enc != nil {
80+
zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
81+
}
5782
}
5883

5984
type zlibEncoder struct {
60-
mu sync.Mutex
6185
writer *zlib.Writer
62-
buf *bytes.Buffer
86+
buf bytes.Buffer
87+
level int
6388
}
6489

6590
func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
66-
e.mu.Lock()
67-
defer e.mu.Unlock()
91+
defer putZlibEncoder(e)
6892

6993
e.buf.Reset()
70-
e.writer.Reset(e.buf)
94+
e.writer.Reset(&e.buf)
7195

7296
_, err := e.writer.Write(src)
7397
if err != nil {
@@ -105,8 +129,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
105129
}
106130
}
107131

132+
var zstdReaderPool = sync.Pool{
133+
New: func() interface{} {
134+
r, _ := zstd.NewReader(nil)
135+
return r
136+
},
137+
}
138+
108139
// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
109-
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
140+
func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
110141
switch opts.Compressor {
111142
case wiremessage.CompressorNoOp:
112143
return in, nil
@@ -117,34 +148,29 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er
117148
} else if int32(l) != opts.UncompressedSize {
118149
return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
119150
}
120-
uncompressed = make([]byte, opts.UncompressedSize)
121-
return snappy.Decode(uncompressed, in)
151+
out := make([]byte, opts.UncompressedSize)
152+
return snappy.Decode(out, in)
122153
case wiremessage.CompressorZLib:
123154
r, err := zlib.NewReader(bytes.NewReader(in))
124155
if err != nil {
125156
return nil, err
126157
}
127-
defer func() {
128-
err = r.Close()
129-
}()
130-
uncompressed = make([]byte, opts.UncompressedSize)
131-
_, err = io.ReadFull(r, uncompressed)
132-
if err != nil {
158+
out := make([]byte, opts.UncompressedSize)
159+
if _, err := io.ReadFull(r, out); err != nil {
133160
return nil, err
134161
}
135-
return uncompressed, nil
136-
case wiremessage.CompressorZstd:
137-
r, err := zstd.NewReader(bytes.NewBuffer(in))
138-
if err != nil {
139-
return nil, err
140-
}
141-
defer r.Close()
142-
uncompressed = make([]byte, opts.UncompressedSize)
143-
_, err = io.ReadFull(r, uncompressed)
144-
if err != nil {
162+
if err := r.Close(); err != nil {
145163
return nil, err
146164
}
147-
return uncompressed, nil
165+
return out, nil
166+
case wiremessage.CompressorZstd:
167+
buf := make([]byte, 0, opts.UncompressedSize)
168+
// Using a pool here is about ~20% faster
169+
// than using a single global zstd.Reader
170+
r := zstdReaderPool.Get().(*zstd.Decoder)
171+
out, err := r.DecodeAll(in, buf)
172+
zstdReaderPool.Put(r)
173+
return out, err
148174
default:
149175
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
150176
}

x/mongo/driver/compression_test.go

Lines changed: 128 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77
package driver
88

99
import (
10+
"bytes"
11+
"compress/zlib"
1012
"os"
1113
"testing"
1214

15+
"github.com/golang/snappy"
16+
"github.com/klauspost/compress/zstd"
17+
1318
"go.mongodb.org/mongo-driver/internal/assert"
1419
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
1520
)
@@ -41,6 +46,43 @@ func TestCompression(t *testing.T) {
4146
}
4247
}
4348

49+
func TestCompressionLevels(t *testing.T) {
50+
in := []byte("abc")
51+
wr := new(bytes.Buffer)
52+
53+
t.Run("ZLib", func(t *testing.T) {
54+
opts := CompressionOpts{
55+
Compressor: wiremessage.CompressorZLib,
56+
}
57+
for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ {
58+
opts.ZlibLevel = lvl
59+
_, err1 := CompressPayload(in, opts)
60+
_, err2 := zlib.NewWriterLevel(wr, lvl)
61+
if err2 != nil {
62+
assert.Error(t, err1, "expected an error for ZLib level %d", lvl)
63+
} else {
64+
assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl)
65+
}
66+
}
67+
})
68+
69+
t.Run("Zstd", func(t *testing.T) {
70+
opts := CompressionOpts{
71+
Compressor: wiremessage.CompressorZstd,
72+
}
73+
for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ {
74+
opts.ZstdLevel = int(lvl)
75+
_, err1 := CompressPayload(in, opts)
76+
_, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
77+
if err2 != nil {
78+
assert.Error(t, err1, "expected an error for Zstd level %d", lvl)
79+
} else {
80+
assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl)
81+
}
82+
}
83+
})
84+
}
85+
4486
func TestDecompressFailures(t *testing.T) {
4587
t.Parallel()
4688

@@ -62,18 +104,57 @@ func TestDecompressFailures(t *testing.T) {
62104
})
63105
}
64106

65-
func BenchmarkCompressPayload(b *testing.B) {
66-
payload := func() []byte {
67-
buf, err := os.ReadFile("compression.go")
107+
// compressionPayload is the uncompressed benchmark input. It is populated by
// initCompressionPayload and stays nil until then.
var compressionPayload []byte

// compressedSnappyPayload is compressionPayload encoded with snappy.
var compressedSnappyPayload []byte

// compressedZLibPayload is compressionPayload compressed with zlib.
var compressedZLibPayload []byte

// compressedZstdPayload is compressionPayload compressed with zstd.
var compressedZstdPayload []byte
113+
114+
func initCompressionPayload(b *testing.B) {
115+
if compressionPayload != nil {
116+
return
117+
}
118+
data, err := os.ReadFile("testdata/compression.go")
119+
if err != nil {
120+
b.Fatal(err)
121+
}
122+
for i := 1; i < 10; i++ {
123+
data = append(data, data...)
124+
}
125+
compressionPayload = data
126+
127+
compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data)
128+
129+
{
130+
var buf bytes.Buffer
131+
enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault))
68132
if err != nil {
69-
b.Log(err)
70-
b.FailNow()
133+
b.Fatal(err)
71134
}
72-
for i := 1; i < 10; i++ {
73-
buf = append(buf, buf...)
135+
compressedZstdPayload = enc.EncodeAll(data, nil)
136+
}
137+
138+
{
139+
var buf bytes.Buffer
140+
enc := zlib.NewWriter(&buf)
141+
if _, err := enc.Write(data); err != nil {
142+
b.Fatal(err)
74143
}
75-
return buf
76-
}()
144+
if err := enc.Close(); err != nil {
145+
b.Fatal(err)
146+
}
147+
if err := enc.Close(); err != nil {
148+
b.Fatal(err)
149+
}
150+
compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...)
151+
}
152+
153+
b.ResetTimer()
154+
}
155+
156+
func BenchmarkCompressPayload(b *testing.B) {
157+
initCompressionPayload(b)
77158

78159
compressors := []wiremessage.CompressorID{
79160
wiremessage.CompressorSnappy,
@@ -88,6 +169,9 @@ func BenchmarkCompressPayload(b *testing.B) {
88169
ZlibLevel: wiremessage.DefaultZlibLevel,
89170
ZstdLevel: wiremessage.DefaultZstdLevel,
90171
}
172+
payload := compressionPayload
173+
b.SetBytes(int64(len(payload)))
174+
b.ReportAllocs()
91175
b.RunParallel(func(pb *testing.PB) {
92176
for pb.Next() {
93177
_, err := CompressPayload(payload, opts)
@@ -99,3 +183,38 @@ func BenchmarkCompressPayload(b *testing.B) {
99183
})
100184
}
101185
}
186+
187+
func BenchmarkDecompressPayload(b *testing.B) {
188+
initCompressionPayload(b)
189+
190+
benchmarks := []struct {
191+
compressor wiremessage.CompressorID
192+
payload []byte
193+
}{
194+
{wiremessage.CompressorSnappy, compressedSnappyPayload},
195+
{wiremessage.CompressorZLib, compressedZLibPayload},
196+
{wiremessage.CompressorZstd, compressedZstdPayload},
197+
}
198+
199+
for _, bench := range benchmarks {
200+
b.Run(bench.compressor.String(), func(b *testing.B) {
201+
opts := CompressionOpts{
202+
Compressor: bench.compressor,
203+
ZlibLevel: wiremessage.DefaultZlibLevel,
204+
ZstdLevel: wiremessage.DefaultZstdLevel,
205+
UncompressedSize: int32(len(compressionPayload)),
206+
}
207+
payload := bench.payload
208+
b.SetBytes(int64(len(compressionPayload)))
209+
b.ReportAllocs()
210+
b.RunParallel(func(pb *testing.PB) {
211+
for pb.Next() {
212+
_, err := DecompressPayload(payload, opts)
213+
if err != nil {
214+
b.Fatal(err)
215+
}
216+
}
217+
})
218+
})
219+
}
220+
}

0 commit comments

Comments
 (0)