@@ -284,7 +284,7 @@ type Operation struct {
284284 AppendBatchArray (dst []byte , maxCount int , maxDocSize int , totalSize int ) (int , []byte , error )
285285 IsOrdered () * bool
286286 AdvanceBatches (n int )
287- End () bool
287+ Size () int
288288 }
289289
290290 // Legacy sets the legacy type for this operation. There are only 3 types that require legacy
@@ -719,8 +719,9 @@ func (op Operation) Execute(ctx context.Context) error {
719719
720720 desc := description.SelectedServer {Server : conn .Description (), Kind : op .Deployment .Kind ()}
721721
722+ var moreToCome bool
722723 var startedInfo startedInformation
723- * wm , startedInfo , err = op .createWireMessage (ctx , maxTimeMS , (* wm )[:0 ], desc , conn , requestID )
724+ * wm , moreToCome , startedInfo , err = op .createWireMessage (ctx , maxTimeMS , (* wm )[:0 ], desc , conn , requestID )
724725
725726 if err != nil {
726727 return err
@@ -746,9 +747,6 @@ func (op Operation) Execute(ctx context.Context) error {
746747
747748 op .publishStartedEvent (ctx , startedInfo )
748749
749- // get the moreToCome flag information before we compress
750- moreToCome := wiremessage .IsMsgMoreToCome (* wm )
751-
752750 // compress wiremessage if allowed
753751 if compressor , ok := conn .(Compressor ); ok && op .canCompress (startedInfo .cmdName ) {
754752 b := memoryPool .Get ().(* []byte )
@@ -872,15 +870,14 @@ func (op Operation) Execute(ctx context.Context) error {
872870 // }
873871 }
874872
875- if op .Batches != nil && len (tt .WriteErrors ) > 0 && currIndex > 0 {
876- for i := range tt .WriteErrors {
877- tt .WriteErrors [i ].Index += int64 (currIndex )
878- }
879- }
880-
881873 // If batching is enabled and either ordered is the default (which is true) or
882874 // explicitly set to true and we have write errors, return the errors.
883875 if op .Batches != nil && len (tt .WriteErrors ) > 0 {
876+ if currIndex > 0 {
877+ for i := range tt .WriteErrors {
878+ tt .WriteErrors [i ].Index += int64 (currIndex )
879+ }
880+ }
884881 if isOrdered := op .Batches .IsOrdered (); isOrdered == nil || * isOrdered {
885882 return tt
886883 }
@@ -1015,7 +1012,6 @@ func (op Operation) Execute(ctx context.Context) error {
10151012 }
10161013 perr := op .ProcessResponseFn (ctx , res , info )
10171014 if perr != nil {
1018- fmt .Println ("op" , perr )
10191015 return perr
10201016 }
10211017 }
@@ -1036,7 +1032,7 @@ func (op Operation) Execute(ctx context.Context) error {
10361032 // If we're batching and there are batches remaining, advance to the next batch. This isn't
10371033 // a retry, so increment the transaction number, reset the retries number, and don't set
10381034 // server or connection to nil to continue using the same connection.
1039- if op .Batches != nil {
1035+ if op .Batches != nil && op . Batches . Size () > startedInfo . processedBatches {
10401036 // If retries are supported for the current operation on the current server description,
10411037 // the session isn't nil, and client retries are enabled, increment the txn number.
10421038 // Calling IncrementTxnNumber() for server descriptions or topologies that do not
@@ -1053,7 +1049,7 @@ func (op Operation) Execute(ctx context.Context) error {
10531049 }
10541050 currIndex += startedInfo .processedBatches
10551051 op .Batches .AdvanceBatches (startedInfo .processedBatches )
1056- if ! op .Batches .End () {
1052+ if op .Batches .Size () > 0 {
10571053 continue
10581054 }
10591055 }
@@ -1289,21 +1285,11 @@ func (op Operation) createMsgWireMessage(
12891285 cmdFn func ([]byte , description.SelectedServer ) ([]byte , error ),
12901286) ([]byte , []byte , error ) {
12911287 var flags wiremessage.MsgFlag
1292- // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
1293- // aren't batching or we are encoding the last batch.
1294- var batching bool
1295- if op .Batches != nil && ! op .Batches .End () {
1296- batching = true
1297- }
1298- if op .WriteConcern != nil && ! writeconcern .AckWrite (op .WriteConcern ) && ! batching {
1299- flags = wiremessage .MoreToCome
1300- }
13011288 // Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can
13021289 // respond with the MoreToCome flag and then stream responses over this connection.
13031290 if streamer , ok := conn .(StreamerConnection ); ok && streamer .SupportsStreaming () {
1304- flags | = wiremessage .ExhaustAllowed
1291+ flags = wiremessage .ExhaustAllowed
13051292 }
1306-
13071293 dst = wiremessage .AppendMsgFlags (dst , flags )
13081294 // Body
13091295 dst = wiremessage .AppendMsgSectionType (dst , wiremessage .SingleDocument )
@@ -1365,11 +1351,12 @@ func (op Operation) createWireMessage(
13651351 desc description.SelectedServer ,
13661352 conn Connection ,
13671353 requestID int32 ,
1368- ) ([]byte , startedInformation , error ) {
1354+ ) ([]byte , bool , startedInformation , error ) {
13691355 var info startedInformation
13701356 var wmindex int32
13711357 var err error
13721358
1359+ fIdx := len (dst )
13731360 isLegacy := isLegacyHandshake (op , desc )
13741361 shouldEncrypt := op .shouldEncrypt ()
13751362 if ! isLegacy && ! shouldEncrypt {
@@ -1395,23 +1382,11 @@ func (op Operation) createWireMessage(
13951382 }
13961383 } else if shouldEncrypt {
13971384 if desc .WireVersion .Max < cryptMinWireVersion {
1398- return dst , info , errors .New ("auto-encryption requires a MongoDB version of 4.2" )
1385+ return dst , false , info , errors .New ("auto-encryption requires a MongoDB version of 4.2" )
13991386 }
14001387 cmdFn := func (dst []byte , desc description.SelectedServer ) ([]byte , error ) {
1401- // create temporary command document
1402- var cmdDst []byte
1403- info .processedBatches , cmdDst , err = op .addEncryptCommandFields (nil , desc )
1404- if err != nil {
1405- return nil , err
1406- }
1407- // encrypt the command
1408- encrypted , err := op .Crypt .Encrypt (ctx , op .Database , cmdDst )
1409- if err != nil {
1410- return nil , err
1411- }
1412- // append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
1413- dst = append (dst , encrypted [4 :len (encrypted )- 1 ]... )
1414- return dst , nil
1388+ info .processedBatches , dst , err = op .addEncryptCommandFields (ctx , dst , desc )
1389+ return dst , err
14151390 }
14161391 wmindex , dst = wiremessage .AppendHeaderStart (dst , requestID , 0 , wiremessage .OpMsg )
14171392 dst , info .cmd , err = op .createMsgWireMessage (maxTimeMS , dst , desc , conn , cmdFn )
@@ -1425,32 +1400,43 @@ func (op Operation) createWireMessage(
14251400 dst , info .cmd , err = op .createLegacyHandshakeWireMessage (maxTimeMS , dst , desc , cmdFn )
14261401 }
14271402 if err != nil {
1428- return nil , info , err
1403+ return nil , false , info , err
1404+ }
1405+
1406+ var moreToCome bool
1407+ // We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
1408+ // aren't batching or we are encoding the last batch.
1409+ unacknowledged := op .WriteConcern != nil && ! writeconcern .AckWrite (op .WriteConcern )
1410+ batching := op .Batches != nil && op .Batches .Size () > info .processedBatches
1411+ if ! isLegacy && unacknowledged && ! batching {
1412+ dst [fIdx ] |= byte (wiremessage .MoreToCome )
1413+ moreToCome = true
14291414 }
14301415 info .requestID = requestID
1431- return bsoncore .UpdateLength (dst , wmindex , int32 (len (dst [wmindex :]))), info , nil
1416+ return bsoncore .UpdateLength (dst , wmindex , int32 (len (dst [wmindex :]))), moreToCome , info , nil
14321417}
14331418
1434- func (op Operation ) addEncryptCommandFields (dst []byte , desc description.SelectedServer ) (int , []byte , error ) {
1435- var idx int32
1436- idx , dst = bsoncore .AppendDocumentStart (dst )
1419+ func (op Operation ) addEncryptCommandFields (ctx context.Context , dst []byte , desc description.SelectedServer ) (int , []byte , error ) {
1420+ idx , cmdDst := bsoncore .AppendDocumentStart (nil )
14371421 var err error
1438- dst , err = op .CommandFn (dst , desc )
1422+ // create temporary command document
1423+ cmdDst , err = op .CommandFn (cmdDst , desc )
14391424 if err != nil {
14401425 return 0 , nil , err
14411426 }
14421427 var n int
14431428 if op .Batches != nil {
14441429 maxBatchCount := int (desc .MaxBatchCount )
14451430 maxDocumentSize := int (desc .MaxDocumentSize )
1431+ fmt .Println ("addEncryptCommandFields" , cryptMaxBsonObjectSize , maxDocumentSize )
14461432 if maxBatchCount > 1 {
1447- n , dst , err = op .Batches .AppendBatchArray (dst , maxBatchCount , cryptMaxBsonObjectSize , maxDocumentSize )
1433+ n , cmdDst , err = op .Batches .AppendBatchArray (cmdDst , maxBatchCount , cryptMaxBsonObjectSize , maxDocumentSize )
14481434 if err != nil {
14491435 return 0 , nil , err
14501436 }
14511437 }
14521438 if n == 0 {
1453- n , dst , err = op .Batches .AppendBatchArray (dst , 1 , maxDocumentSize , maxDocumentSize )
1439+ n , cmdDst , err = op .Batches .AppendBatchArray (cmdDst , 1 , maxDocumentSize , maxDocumentSize )
14541440 if err != nil {
14551441 return 0 , nil , err
14561442 }
@@ -1459,10 +1445,17 @@ func (op Operation) addEncryptCommandFields(dst []byte, desc description.Selecte
14591445 }
14601446 }
14611447 }
1462- dst , err = bsoncore .AppendDocumentEnd (dst , idx )
1448+ cmdDst , err = bsoncore .AppendDocumentEnd (cmdDst , idx )
1449+ if err != nil {
1450+ return 0 , nil , err
1451+ }
1452+ // encrypt the command
1453+ encrypted , err := op .Crypt .Encrypt (ctx , op .Database , cmdDst )
14631454 if err != nil {
14641455 return 0 , nil , err
14651456 }
1457+ // append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
1458+ dst = append (dst , encrypted [4 :len (encrypted )- 1 ]... )
14661459 return n , dst , nil
14671460}
14681461
0 commit comments