@@ -865,7 +865,7 @@ func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) (
865
865
}
866
866
867
867
func (op Operation ) readWireMessage (ctx context.Context , conn Connection ) (result []byte , err error ) {
868
- wm , err := conn .ReadWireMessage (ctx , nil )
868
+ wm , err := conn .ReadWireMessage (ctx )
869
869
if err != nil {
870
870
return nil , op .networkError (err )
871
871
}
@@ -876,14 +876,21 @@ func (op Operation) readWireMessage(ctx context.Context, conn Connection) (resul
876
876
streamer .SetStreaming (wiremessage .IsMsgMoreToCome (wm ))
877
877
}
878
878
879
- // decompress wiremessage
880
- wm , err = op .decompressWireMessage (wm )
881
- if err != nil {
882
- return nil , err
879
+ length , _ , _ , opcode , rem , ok := wiremessage .ReadHeader (wm )
880
+ if ! ok || len (wm ) < int (length ) {
881
+ return nil , errors .New ("malformed wire message: insufficient bytes" )
882
+ }
883
+ if opcode == wiremessage .OpCompressed {
884
+ rawsize := length - 16 // remove header size
885
+ // decompress wiremessage
886
+ opcode , rem , err = op .decompressWireMessage (rem [:rawsize ])
887
+ if err != nil {
888
+ return nil , err
889
+ }
883
890
}
884
891
885
892
// decode
886
- res , err := op .decodeResult (wm )
893
+ res , err := op .decodeResult (opcode , rem )
887
894
// Update cluster/operation time and recovery tokens before handling the error to ensure we're properly updating
888
895
// everything.
889
896
op .updateClusterTimes (res )
@@ -940,51 +947,39 @@ func (op *Operation) moreToComeRoundTrip(ctx context.Context, conn Connection, w
940
947
return bsoncore .BuildDocument (nil , bsoncore .AppendInt32Element (nil , "ok" , 1 )), err
941
948
}
942
949
943
- // decompressWireMessage handles decompressing a wiremessage. If the wiremessage
944
- // is not compressed, this method will return the wiremessage.
945
- func (Operation ) decompressWireMessage (wm []byte ) ([]byte , error ) {
946
- // read the header and ensure this is a compressed wire message
947
- length , reqid , respto , opcode , rem , ok := wiremessage .ReadHeader (wm )
948
- if ! ok || len (wm ) < int (length ) {
949
- return nil , errors .New ("malformed wire message: insufficient bytes" )
950
- }
951
- if opcode != wiremessage .OpCompressed {
952
- return wm , nil
953
- }
950
+ // decompressWireMessage handles decompressing a wiremessage without the header.
951
+ func (Operation ) decompressWireMessage (wm []byte ) (wiremessage.OpCode , []byte , error ) {
954
952
// get the original opcode and uncompressed size
955
- opcode , rem , ok = wiremessage .ReadCompressedOriginalOpCode (rem )
953
+ opcode , rem , ok : = wiremessage .ReadCompressedOriginalOpCode (wm )
956
954
if ! ok {
957
- return nil , errors .New ("malformed OP_COMPRESSED: missing original opcode" )
955
+ return 0 , nil , errors .New ("malformed OP_COMPRESSED: missing original opcode" )
958
956
}
959
957
uncompressedSize , rem , ok := wiremessage .ReadCompressedUncompressedSize (rem )
960
958
if ! ok {
961
- return nil , errors .New ("malformed OP_COMPRESSED: missing uncompressed size" )
959
+ return 0 , nil , errors .New ("malformed OP_COMPRESSED: missing uncompressed size" )
962
960
}
963
961
// get the compressor ID and decompress the message
964
962
compressorID , rem , ok := wiremessage .ReadCompressedCompressorID (rem )
965
963
if ! ok {
966
- return nil , errors .New ("malformed OP_COMPRESSED: missing compressor ID" )
964
+ return 0 , nil , errors .New ("malformed OP_COMPRESSED: missing compressor ID" )
967
965
}
968
- compressedSize := length - 25 // header (16) + original opcode (4) + uncompressed size (4) + compressor ID (1)
966
+ compressedSize := len ( wm ) - 9 // original opcode (4) + uncompressed size (4) + compressor ID (1)
969
967
// return the original wiremessage
970
- msg , _ , ok := wiremessage .ReadCompressedCompressedMessage (rem , compressedSize )
968
+ msg , _ , ok := wiremessage .ReadCompressedCompressedMessage (rem , int32 ( compressedSize ) )
971
969
if ! ok {
972
- return nil , errors .New ("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage" )
970
+ return 0 , nil , errors .New ("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage" )
973
971
}
974
972
975
- wm = make ([]byte , 0 , int (uncompressedSize )+ 16 )
976
- wm = wiremessage .AppendHeader (wm , uncompressedSize + 16 , reqid , respto , opcode )
977
973
opts := CompressionOpts {
978
974
Compressor : compressorID ,
979
975
UncompressedSize : uncompressedSize ,
980
976
}
981
977
uncompressed , err := DecompressPayload (msg , opts )
982
978
if err != nil {
983
- return nil , err
979
+ return 0 , nil , err
984
980
}
985
- wm = append (wm , uncompressed ... )
986
981
987
- return wm , nil
982
+ return opcode , uncompressed , nil
988
983
}
989
984
990
985
func (op Operation ) createWireMessage (
@@ -1541,28 +1536,12 @@ func (Operation) canCompress(cmd string) bool {
1541
1536
}
1542
1537
1543
1538
// decodeOpReply extracts the necessary information from an OP_REPLY wire message.
1544
- // includesHeader: specifies whether or not wm includes the message header
1545
1539
// Returns the decoded OP_REPLY. If the err field of the returned opReply is non-nil, an error occurred while decoding
1546
1540
// or validating the response and the other fields are undefined.
1547
- func (Operation ) decodeOpReply (wm []byte , includesHeader bool ) opReply {
1541
+ func (Operation ) decodeOpReply (wm []byte ) opReply {
1548
1542
var reply opReply
1549
1543
var ok bool
1550
1544
1551
- if includesHeader {
1552
- wmLength := len (wm )
1553
- var length int32
1554
- var opcode wiremessage.OpCode
1555
- length , _ , _ , opcode , wm , ok = wiremessage .ReadHeader (wm )
1556
- if ! ok || int (length ) > wmLength {
1557
- reply .err = errors .New ("malformed wire message: insufficient bytes" )
1558
- return reply
1559
- }
1560
- if opcode != wiremessage .OpReply {
1561
- reply .err = errors .New ("malformed wire message: incorrect opcode" )
1562
- return reply
1563
- }
1564
- }
1565
-
1566
1545
reply .responseFlags , wm , ok = wiremessage .ReadReplyFlags (wm )
1567
1546
if ! ok {
1568
1547
reply .err = errors .New ("malformed OP_REPLY: missing flags" )
@@ -1583,7 +1562,7 @@ func (Operation) decodeOpReply(wm []byte, includesHeader bool) opReply {
1583
1562
reply .err = errors .New ("malformed OP_REPLY: missing numberReturned" )
1584
1563
return reply
1585
1564
}
1586
- reply .documents , wm , ok = wiremessage .ReadReplyDocuments (wm )
1565
+ reply .documents , _ , ok = wiremessage .ReadReplyDocuments (wm )
1587
1566
if ! ok {
1588
1567
reply .err = errors .New ("malformed OP_REPLY: could not read documents from reply" )
1589
1568
}
@@ -1607,18 +1586,10 @@ func (Operation) decodeOpReply(wm []byte, includesHeader bool) opReply {
1607
1586
return reply
1608
1587
}
1609
1588
1610
- func (op Operation ) decodeResult (wm []byte ) (bsoncore.Document , error ) {
1611
- wmLength := len (wm )
1612
- length , _ , _ , opcode , wm , ok := wiremessage .ReadHeader (wm )
1613
- if ! ok || int (length ) > wmLength {
1614
- return nil , errors .New ("malformed wire message: insufficient bytes" )
1615
- }
1616
-
1617
- wm = wm [:wmLength - 16 ] // constrain to just this wiremessage, incase there are multiple in the slice
1618
-
1589
+ func (op Operation ) decodeResult (opcode wiremessage.OpCode , wm []byte ) (bsoncore.Document , error ) {
1619
1590
switch opcode {
1620
1591
case wiremessage .OpReply :
1621
- reply := op .decodeOpReply (wm , false )
1592
+ reply := op .decodeOpReply (wm )
1622
1593
if reply .err != nil {
1623
1594
return nil , reply .err
1624
1595
}
@@ -1635,7 +1606,7 @@ func (op Operation) decodeResult(wm []byte) (bsoncore.Document, error) {
1635
1606
1636
1607
return rdr , ExtractErrorFromServerResponse (rdr )
1637
1608
case wiremessage .OpMsg :
1638
- _ , wm , ok = wiremessage .ReadMsgFlags (wm )
1609
+ _ , wm , ok : = wiremessage .ReadMsgFlags (wm )
1639
1610
if ! ok {
1640
1611
return nil , errors .New ("malformed wire message: missing OP_MSG flags" )
1641
1612
}
0 commit comments