Skip to content

Commit b6bdeda

Browse files
qingyang-huprestonvasquez
authored andcommitted
GODRIVER-2778 Reduce memory usage on compressed loads. (#1200)
1 parent 169f337 commit b6bdeda

File tree

10 files changed

+116
-125
lines changed

10 files changed

+116
-125
lines changed

mongo/integration/mtest/opmsg_deployment.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ func (c *connection) WriteWireMessage(_ context.Context, wm []byte) error {
5757
}
5858

5959
// ReadWireMessage returns the next response in the connection's list of responses.
60-
func (c *connection) ReadWireMessage(_ context.Context, dst []byte) ([]byte, error) {
60+
func (c *connection) ReadWireMessage(_ context.Context) ([]byte, error) {
61+
var dst []byte
6162
if len(c.responses) == 0 {
6263
return dst, errors.New("no responses remaining")
6364
}

mongo/integration/mtest/wiremessage_helpers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func parseOpCompressed(wm []byte) (wiremessage.OpCode, []byte, error) {
5454
return originalOpcode, nil, errors.New("failed to read compressor ID")
5555
}
5656

57-
compressedMsg, wm, ok := wiremessage.ReadCompressedCompressedMessage(wm, int32(len(wm)))
57+
compressedMsg, _, ok := wiremessage.ReadCompressedCompressedMessage(wm, int32(len(wm)))
5858
if !ok {
5959
return originalOpcode, nil, errors.New("failed to read compressed message")
6060
}

x/mongo/driver/compression.go

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,61 @@ type CompressionOpts struct {
2626
UncompressedSize int32
2727
}
2828

29-
var zstdEncoders = &sync.Map{}
29+
var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
3030

31-
func getZstdEncoder(l zstd.EncoderLevel) (*zstd.Encoder, error) {
32-
v, ok := zstdEncoders.Load(l)
33-
if ok {
31+
func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
32+
if v, ok := zstdEncoders.Load(level); ok {
3433
return v.(*zstd.Encoder), nil
3534
}
36-
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(l))
35+
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
3736
if err != nil {
3837
return nil, err
3938
}
40-
zstdEncoders.Store(l, encoder)
39+
zstdEncoders.Store(level, encoder)
4140
return encoder, nil
4241
}
4342

43+
var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
44+
45+
func getZlibEncoder(level int) (*zlibEncoder, error) {
46+
if v, ok := zlibEncoders.Load(level); ok {
47+
return v.(*zlibEncoder), nil
48+
}
49+
writer, err := zlib.NewWriterLevel(nil, level)
50+
if err != nil {
51+
return nil, err
52+
}
53+
encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
54+
zlibEncoders.Store(level, encoder)
55+
56+
return encoder, nil
57+
}
58+
59+
type zlibEncoder struct {
60+
mu sync.Mutex
61+
writer *zlib.Writer
62+
buf *bytes.Buffer
63+
}
64+
65+
func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
66+
e.mu.Lock()
67+
defer e.mu.Unlock()
68+
69+
e.buf.Reset()
70+
e.writer.Reset(e.buf)
71+
72+
_, err := e.writer.Write(src)
73+
if err != nil {
74+
return nil, err
75+
}
76+
err = e.writer.Close()
77+
if err != nil {
78+
return nil, err
79+
}
80+
dst = append(dst[:0], e.buf.Bytes()...)
81+
return dst, nil
82+
}
83+
4484
// CompressPayload takes a byte slice and compresses it according to the options passed
4585
func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
4686
switch opts.Compressor {
@@ -49,20 +89,11 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
4989
case wiremessage.CompressorSnappy:
5090
return snappy.Encode(nil, in), nil
5191
case wiremessage.CompressorZLib:
52-
var b bytes.Buffer
53-
w, err := zlib.NewWriterLevel(&b, opts.ZlibLevel)
92+
encoder, err := getZlibEncoder(opts.ZlibLevel)
5493
if err != nil {
5594
return nil, err
5695
}
57-
_, err = w.Write(in)
58-
if err != nil {
59-
return nil, err
60-
}
61-
err = w.Close()
62-
if err != nil {
63-
return nil, err
64-
}
65-
return b.Bytes(), nil
96+
return encoder.Encode(nil, in)
6697
case wiremessage.CompressorZstd:
6798
encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
6899
if err != nil {
@@ -75,20 +106,23 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
75106
}
76107

77108
// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
78-
func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
109+
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
79110
switch opts.Compressor {
80111
case wiremessage.CompressorNoOp:
81112
return in, nil
82113
case wiremessage.CompressorSnappy:
83-
uncompressed := make([]byte, opts.UncompressedSize)
114+
uncompressed = make([]byte, opts.UncompressedSize)
84115
return snappy.Decode(uncompressed, in)
85116
case wiremessage.CompressorZLib:
86-
decompressor, err := zlib.NewReader(bytes.NewReader(in))
117+
r, err := zlib.NewReader(bytes.NewReader(in))
87118
if err != nil {
88119
return nil, err
89120
}
90-
uncompressed := make([]byte, opts.UncompressedSize)
91-
_, err = io.ReadFull(decompressor, uncompressed)
121+
defer func() {
122+
err = r.Close()
123+
}()
124+
uncompressed = make([]byte, opts.UncompressedSize)
125+
_, err = io.ReadFull(r, uncompressed)
92126
if err != nil {
93127
return nil, err
94128
}
@@ -99,7 +133,7 @@ func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
99133
return nil, err
100134
}
101135
defer r.Close()
102-
uncompressed := make([]byte, opts.UncompressedSize)
136+
uncompressed = make([]byte, opts.UncompressedSize)
103137
_, err = io.ReadFull(r, uncompressed)
104138
if err != nil {
105139
return nil, err

x/mongo/driver/driver.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ type Server interface {
5959
// Connection represents a connection to a MongoDB server.
6060
type Connection interface {
6161
WriteWireMessage(context.Context, []byte) error
62-
ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error)
62+
ReadWireMessage(ctx context.Context) ([]byte, error)
6363
Description() description.Server
6464

6565
// Close closes any underlying connection and returns or frees any resources held by the

x/mongo/driver/drivertest/channel_conn.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,15 @@ func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error {
4040
}
4141

4242
// ReadWireMessage implements the driver.Connection interface.
43-
func (c *ChannelConn) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
44-
dst = dst[:0]
43+
func (c *ChannelConn) ReadWireMessage(ctx context.Context) ([]byte, error) {
4544
var wm []byte
4645
var err error
4746
select {
4847
case wm = <-c.ReadResp:
4948
case err = <-c.ReadErr:
5049
case <-ctx.Done():
5150
}
52-
if l := len(wm); l > 0 {
53-
if l > cap(dst) {
54-
dst = make([]byte, 0, l)
55-
}
56-
dst = append(dst, wm...)
57-
}
58-
return dst, err
51+
return wm, err
5952
}
6053

6154
// Description implements the driver.Connection interface.

x/mongo/driver/operation.go

Lines changed: 29 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) (
865865
}
866866

867867
func (op Operation) readWireMessage(ctx context.Context, conn Connection) (result []byte, err error) {
868-
wm, err := conn.ReadWireMessage(ctx, nil)
868+
wm, err := conn.ReadWireMessage(ctx)
869869
if err != nil {
870870
return nil, op.networkError(err)
871871
}
@@ -876,14 +876,21 @@ func (op Operation) readWireMessage(ctx context.Context, conn Connection) (resul
876876
streamer.SetStreaming(wiremessage.IsMsgMoreToCome(wm))
877877
}
878878

879-
// decompress wiremessage
880-
wm, err = op.decompressWireMessage(wm)
881-
if err != nil {
882-
return nil, err
879+
length, _, _, opcode, rem, ok := wiremessage.ReadHeader(wm)
880+
if !ok || len(wm) < int(length) {
881+
return nil, errors.New("malformed wire message: insufficient bytes")
882+
}
883+
if opcode == wiremessage.OpCompressed {
884+
rawsize := length - 16 // remove header size
885+
// decompress wiremessage
886+
opcode, rem, err = op.decompressWireMessage(rem[:rawsize])
887+
if err != nil {
888+
return nil, err
889+
}
883890
}
884891

885892
// decode
886-
res, err := op.decodeResult(wm)
893+
res, err := op.decodeResult(opcode, rem)
887894
// Update cluster/operation time and recovery tokens before handling the error to ensure we're properly updating
888895
// everything.
889896
op.updateClusterTimes(res)
@@ -940,51 +947,39 @@ func (op *Operation) moreToComeRoundTrip(ctx context.Context, conn Connection, w
940947
return bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1)), err
941948
}
942949

943-
// decompressWireMessage handles decompressing a wiremessage. If the wiremessage
944-
// is not compressed, this method will return the wiremessage.
945-
func (Operation) decompressWireMessage(wm []byte) ([]byte, error) {
946-
// read the header and ensure this is a compressed wire message
947-
length, reqid, respto, opcode, rem, ok := wiremessage.ReadHeader(wm)
948-
if !ok || len(wm) < int(length) {
949-
return nil, errors.New("malformed wire message: insufficient bytes")
950-
}
951-
if opcode != wiremessage.OpCompressed {
952-
return wm, nil
953-
}
950+
// decompressWireMessage handles decompressing a wiremessage without the header.
951+
func (Operation) decompressWireMessage(wm []byte) (wiremessage.OpCode, []byte, error) {
954952
// get the original opcode and uncompressed size
955-
opcode, rem, ok = wiremessage.ReadCompressedOriginalOpCode(rem)
953+
opcode, rem, ok := wiremessage.ReadCompressedOriginalOpCode(wm)
956954
if !ok {
957-
return nil, errors.New("malformed OP_COMPRESSED: missing original opcode")
955+
return 0, nil, errors.New("malformed OP_COMPRESSED: missing original opcode")
958956
}
959957
uncompressedSize, rem, ok := wiremessage.ReadCompressedUncompressedSize(rem)
960958
if !ok {
961-
return nil, errors.New("malformed OP_COMPRESSED: missing uncompressed size")
959+
return 0, nil, errors.New("malformed OP_COMPRESSED: missing uncompressed size")
962960
}
963961
// get the compressor ID and decompress the message
964962
compressorID, rem, ok := wiremessage.ReadCompressedCompressorID(rem)
965963
if !ok {
966-
return nil, errors.New("malformed OP_COMPRESSED: missing compressor ID")
964+
return 0, nil, errors.New("malformed OP_COMPRESSED: missing compressor ID")
967965
}
968-
compressedSize := length - 25 // header (16) + original opcode (4) + uncompressed size (4) + compressor ID (1)
966+
compressedSize := len(wm) - 9 // original opcode (4) + uncompressed size (4) + compressor ID (1)
969967
// return the original wiremessage
970-
msg, _, ok := wiremessage.ReadCompressedCompressedMessage(rem, compressedSize)
968+
msg, _, ok := wiremessage.ReadCompressedCompressedMessage(rem, int32(compressedSize))
971969
if !ok {
972-
return nil, errors.New("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage")
970+
return 0, nil, errors.New("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage")
973971
}
974972

975-
wm = make([]byte, 0, int(uncompressedSize)+16)
976-
wm = wiremessage.AppendHeader(wm, uncompressedSize+16, reqid, respto, opcode)
977973
opts := CompressionOpts{
978974
Compressor: compressorID,
979975
UncompressedSize: uncompressedSize,
980976
}
981977
uncompressed, err := DecompressPayload(msg, opts)
982978
if err != nil {
983-
return nil, err
979+
return 0, nil, err
984980
}
985-
wm = append(wm, uncompressed...)
986981

987-
return wm, nil
982+
return opcode, uncompressed, nil
988983
}
989984

990985
func (op Operation) createWireMessage(
@@ -1541,28 +1536,12 @@ func (Operation) canCompress(cmd string) bool {
15411536
}
15421537

15431538
// decodeOpReply extracts the necessary information from an OP_REPLY wire message.
1544-
// includesHeader: specifies whether or not wm includes the message header
15451539
// Returns the decoded OP_REPLY. If the err field of the returned opReply is non-nil, an error occurred while decoding
15461540
// or validating the response and the other fields are undefined.
1547-
func (Operation) decodeOpReply(wm []byte, includesHeader bool) opReply {
1541+
func (Operation) decodeOpReply(wm []byte) opReply {
15481542
var reply opReply
15491543
var ok bool
15501544

1551-
if includesHeader {
1552-
wmLength := len(wm)
1553-
var length int32
1554-
var opcode wiremessage.OpCode
1555-
length, _, _, opcode, wm, ok = wiremessage.ReadHeader(wm)
1556-
if !ok || int(length) > wmLength {
1557-
reply.err = errors.New("malformed wire message: insufficient bytes")
1558-
return reply
1559-
}
1560-
if opcode != wiremessage.OpReply {
1561-
reply.err = errors.New("malformed wire message: incorrect opcode")
1562-
return reply
1563-
}
1564-
}
1565-
15661545
reply.responseFlags, wm, ok = wiremessage.ReadReplyFlags(wm)
15671546
if !ok {
15681547
reply.err = errors.New("malformed OP_REPLY: missing flags")
@@ -1583,7 +1562,7 @@ func (Operation) decodeOpReply(wm []byte, includesHeader bool) opReply {
15831562
reply.err = errors.New("malformed OP_REPLY: missing numberReturned")
15841563
return reply
15851564
}
1586-
reply.documents, wm, ok = wiremessage.ReadReplyDocuments(wm)
1565+
reply.documents, _, ok = wiremessage.ReadReplyDocuments(wm)
15871566
if !ok {
15881567
reply.err = errors.New("malformed OP_REPLY: could not read documents from reply")
15891568
}
@@ -1607,18 +1586,10 @@ func (Operation) decodeOpReply(wm []byte, includesHeader bool) opReply {
16071586
return reply
16081587
}
16091588

1610-
func (op Operation) decodeResult(wm []byte) (bsoncore.Document, error) {
1611-
wmLength := len(wm)
1612-
length, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm)
1613-
if !ok || int(length) > wmLength {
1614-
return nil, errors.New("malformed wire message: insufficient bytes")
1615-
}
1616-
1617-
wm = wm[:wmLength-16] // constrain to just this wiremessage, incase there are multiple in the slice
1618-
1589+
func (op Operation) decodeResult(opcode wiremessage.OpCode, wm []byte) (bsoncore.Document, error) {
16191590
switch opcode {
16201591
case wiremessage.OpReply:
1621-
reply := op.decodeOpReply(wm, false)
1592+
reply := op.decodeOpReply(wm)
16221593
if reply.err != nil {
16231594
return nil, reply.err
16241595
}
@@ -1635,7 +1606,7 @@ func (op Operation) decodeResult(wm []byte) (bsoncore.Document, error) {
16351606

16361607
return rdr, ExtractErrorFromServerResponse(rdr)
16371608
case wiremessage.OpMsg:
1638-
_, wm, ok = wiremessage.ReadMsgFlags(wm)
1609+
_, wm, ok := wiremessage.ReadMsgFlags(wm)
16391610
if !ok {
16401611
return nil, errors.New("malformed wire message: missing OP_MSG flags")
16411612
}

x/mongo/driver/operation_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,6 @@ func (m *mockServerSelector) SelectServer(description.Topology, []description.Se
740740
type mockConnection struct {
741741
// parameters
742742
pWriteWM []byte
743-
pReadDst []byte
744743

745744
// returns
746745
rWriteErr error
@@ -770,8 +769,7 @@ func (m *mockConnection) WriteWireMessage(_ context.Context, wm []byte) error {
770769
return m.rWriteErr
771770
}
772771

773-
func (m *mockConnection) ReadWireMessage(_ context.Context, dst []byte) ([]byte, error) {
774-
m.pReadDst = dst
772+
func (m *mockConnection) ReadWireMessage(_ context.Context) ([]byte, error) {
775773
return m.rReadWM, m.rReadErr
776774
}
777775

x/mongo/driver/session/client_session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (s TransactionState) String() string {
8282
type LoadBalancedTransactionConnection interface {
8383
// Functions copied over from driver.Connection.
8484
WriteWireMessage(context.Context, []byte) error
85-
ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error)
85+
ReadWireMessage(ctx context.Context) ([]byte, error)
8686
Description() description.Server
8787
Close() error
8888
ID() string

0 commit comments

Comments
 (0)