Skip to content

Commit 27a230b

Browse files
authored
Godriver-2259 Reduce memory consumption. (#985)
1 parent 95de0fb commit 27a230b

File tree

5 files changed

+97
-14
lines changed

5 files changed

+97
-14
lines changed

x/mongo/driver/compression.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"compress/zlib"
1212
"fmt"
1313
"io"
14+
"sync"
1415

1516
"github.com/golang/snappy"
1617
"github.com/klauspost/compress/zstd"
@@ -25,6 +26,21 @@ type CompressionOpts struct {
2526
UncompressedSize int32
2627
}
2728

29+
var zstdEncoders = &sync.Map{}
30+
31+
func getZstdEncoder(l zstd.EncoderLevel) (*zstd.Encoder, error) {
32+
v, ok := zstdEncoders.Load(l)
33+
if ok {
34+
return v.(*zstd.Encoder), nil
35+
}
36+
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(l))
37+
if err != nil {
38+
return nil, err
39+
}
40+
zstdEncoders.Store(l, encoder)
41+
return encoder, nil
42+
}
43+
2844
// CompressPayload takes a byte slice and compresses it according to the options passed
2945
func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
3046
switch opts.Compressor {
@@ -48,21 +64,11 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
4864
}
4965
return b.Bytes(), nil
5066
case wiremessage.CompressorZstd:
51-
var b bytes.Buffer
52-
w, err := zstd.NewWriter(&b, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
53-
if err != nil {
54-
return nil, err
55-
}
56-
_, err = io.Copy(w, bytes.NewBuffer(in))
67+
encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
5768
if err != nil {
58-
_ = w.Close()
5969
return nil, err
6070
}
61-
err = w.Close()
62-
if err != nil {
63-
return nil, err
64-
}
65-
return b.Bytes(), nil
71+
return encoder.EncodeAll(in, nil), nil
6672
default:
6773
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
6874
}

x/mongo/driver/compression_test.go

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
package driver
88

99
import (
10-
"strconv"
10+
"os"
1111
"testing"
1212

1313
"github.com/stretchr/testify/assert"
@@ -23,7 +23,7 @@ func TestCompression(t *testing.T) {
2323
}
2424

2525
for _, compressor := range compressors {
26-
t.Run(strconv.Itoa(int(compressor)), func(t *testing.T) {
26+
t.Run(compressor.String(), func(t *testing.T) {
2727
payload := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt")
2828
opts := CompressionOpts{
2929
Compressor: compressor,
@@ -40,3 +40,41 @@ func TestCompression(t *testing.T) {
4040
})
4141
}
4242
}
43+
44+
func BenchmarkCompressPayload(b *testing.B) {
45+
payload := func() []byte {
46+
buf, err := os.ReadFile("compression.go")
47+
if err != nil {
48+
b.Log(err)
49+
b.FailNow()
50+
}
51+
for i := 1; i < 10; i++ {
52+
buf = append(buf, buf...)
53+
}
54+
return buf
55+
}()
56+
57+
compressors := []wiremessage.CompressorID{
58+
wiremessage.CompressorSnappy,
59+
wiremessage.CompressorZLib,
60+
wiremessage.CompressorZstd,
61+
}
62+
63+
for _, compressor := range compressors {
64+
b.Run(compressor.String(), func(b *testing.B) {
65+
opts := CompressionOpts{
66+
Compressor: compressor,
67+
ZlibLevel: wiremessage.DefaultZlibLevel,
68+
ZstdLevel: wiremessage.DefaultZstdLevel,
69+
}
70+
b.RunParallel(func(pb *testing.PB) {
71+
for pb.Next() {
72+
_, err := CompressPayload(payload, opts)
73+
if err != nil {
74+
b.Error(err)
75+
}
76+
}
77+
})
78+
})
79+
}
80+
}

x/mongo/driver/topology/connection.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,9 @@ func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
638638
return dst, ErrConnectionClosed
639639
}
640640
if c.connection.compressor == wiremessage.CompressorNoOp {
641+
if len(dst) == 0 {
642+
return src, nil
643+
}
641644
return append(dst, src...), nil
642645
}
643646
_, reqid, respto, origcode, rem, ok := wiremessage.ReadHeader(src)

x/mongo/driver/topology/connection_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"crypto/tls"
1212
"errors"
13+
"math/rand"
1314
"net"
1415
"sync"
1516
"sync/atomic"
@@ -22,6 +23,7 @@ import (
2223
"go.mongodb.org/mongo-driver/mongo/address"
2324
"go.mongodb.org/mongo-driver/mongo/description"
2425
"go.mongodb.org/mongo-driver/x/mongo/driver"
26+
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
2527
)
2628

2729
type testHandshaker struct {
@@ -926,6 +928,24 @@ func TestConnection(t *testing.T) {
926928
})
927929
}
928930

931+
func BenchmarkConnection(b *testing.B) {
932+
b.Run("CompressWireMessage CompressorNoOp", func(b *testing.B) {
933+
buf := make([]byte, 256)
934+
_, err := rand.Read(buf)
935+
if err != nil {
936+
b.Log(err)
937+
b.FailNow()
938+
}
939+
conn := Connection{connection: &connection{compressor: wiremessage.CompressorNoOp}}
940+
for i := 0; i < b.N; i++ {
941+
_, err := conn.CompressWireMessage(buf, nil)
942+
if err != nil {
943+
b.Error(err)
944+
}
945+
}
946+
})
947+
}
948+
929949
// cancellationTestNetConn is a net.Conn implementation that is used to test context.Cancellation during an in-progress
930950
// network read or write. This type has two unbuffered channels: operationStartedChan and continueChan. When Read/Write
931951
// starts, the type will write to operationStartedChan, which will block until the test reads from it. This signals to

x/mongo/driver/wiremessage/wiremessage.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,22 @@ const (
188188
CompressorZstd
189189
)
190190

191+
// String implements the fmt.Stringer interface.
192+
func (id CompressorID) String() string {
193+
switch id {
194+
case CompressorNoOp:
195+
return "CompressorNoOp"
196+
case CompressorSnappy:
197+
return "CompressorSnappy"
198+
case CompressorZLib:
199+
return "CompressorZLib"
200+
case CompressorZstd:
201+
return "CompressorZstd"
202+
default:
203+
return "CompressorInvalid"
204+
}
205+
}
206+
191207
const (
192208
// DefaultZlibLevel is the default level for zlib compression
193209
DefaultZlibLevel = 6

0 commit comments

Comments
 (0)