@@ -9,14 +9,24 @@ package command
9
9
import (
10
10
"errors"
11
11
12
+ "context"
13
+
12
14
"github.com/mongodb/mongo-go-driver/bson"
13
15
"github.com/mongodb/mongo-go-driver/core/description"
16
+ "github.com/mongodb/mongo-go-driver/core/option"
14
17
"github.com/mongodb/mongo-go-driver/core/readconcern"
18
+ "github.com/mongodb/mongo-go-driver/core/result"
15
19
"github.com/mongodb/mongo-go-driver/core/session"
16
20
"github.com/mongodb/mongo-go-driver/core/wiremessage"
17
21
"github.com/mongodb/mongo-go-driver/core/writeconcern"
18
22
)
19
23
24
+ // WriteBatch represents a single batch for a write operation.
25
+ type WriteBatch struct {
26
+ * Write
27
+ numDocs int
28
+ }
29
+
20
30
// DecodeError attempts to decode the wiremessage as an error
21
31
func DecodeError (wm wiremessage.WireMessage ) error {
22
32
var rdr bson.Reader
@@ -376,6 +386,266 @@ func opmsgCreateDocSequence(arr *bson.Array, identifier string) (wiremessage.Sec
376
386
return docSequence , nil
377
387
}
378
388
389
+ func splitBatches (docs []* bson.Document , maxCount , targetBatchSize int ) ([][]* bson.Document , error ) {
390
+ batches := [][]* bson.Document {}
391
+
392
+ if targetBatchSize > reservedCommandBufferBytes {
393
+ targetBatchSize -= reservedCommandBufferBytes
394
+ }
395
+
396
+ if maxCount <= 0 {
397
+ maxCount = 1
398
+ }
399
+
400
+ startAt := 0
401
+ splitInserts:
402
+ for {
403
+ size := 0
404
+ batch := []* bson.Document {}
405
+ assembleBatch:
406
+ for idx := startAt ; idx < len (docs ); idx ++ {
407
+ itsize , err := docs [idx ].Validate ()
408
+ if err != nil {
409
+ return nil , err
410
+ }
411
+
412
+ if int (itsize ) > targetBatchSize {
413
+ return nil , ErrDocumentTooLarge
414
+ }
415
+ if size + int (itsize ) > targetBatchSize {
416
+ break assembleBatch
417
+ }
418
+
419
+ size += int (itsize )
420
+ batch = append (batch , docs [idx ])
421
+ startAt ++
422
+ if len (batch ) == maxCount {
423
+ break assembleBatch
424
+ }
425
+ }
426
+ batches = append (batches , batch )
427
+ if startAt == len (docs ) {
428
+ break splitInserts
429
+ }
430
+ }
431
+
432
+ return batches , nil
433
+ }
434
+
435
+ func encodeBatch (
436
+ docs []* bson.Document ,
437
+ opts []option.Optioner ,
438
+ cmdKind WriteCommandKind ,
439
+ collName string ,
440
+ ) (* bson.Document , error ) {
441
+ var cmdName string
442
+ var docString string
443
+
444
+ switch cmdKind {
445
+ case InsertCommand :
446
+ cmdName = "insert"
447
+ docString = "documents"
448
+ case UpdateCommand :
449
+ cmdName = "update"
450
+ docString = "updates"
451
+ case DeleteCommand :
452
+ cmdName = "delete"
453
+ docString = "deletes"
454
+ }
455
+
456
+ cmd := bson .NewDocument (
457
+ bson .EC .String (cmdName , collName ),
458
+ )
459
+
460
+ vals := make ([]* bson.Value , 0 , len (docs ))
461
+ for _ , doc := range docs {
462
+ vals = append (vals , bson .VC .Document (doc ))
463
+ }
464
+ cmd .Append (bson .EC .ArrayFromElements (docString , vals ... ))
465
+
466
+ for _ , opt := range opts {
467
+ if opt == nil {
468
+ continue
469
+ }
470
+
471
+ err := opt .Option (cmd )
472
+ if err != nil {
473
+ return nil , err
474
+ }
475
+ }
476
+
477
+ return cmd , nil
478
+ }
479
+
480
+ // converts batches of Write Commands to wire messages
481
+ func batchesToWireMessage (batches []* WriteBatch , desc description.SelectedServer ) ([]wiremessage.WireMessage , error ) {
482
+ wms := make ([]wiremessage.WireMessage , len (batches ))
483
+ for _ , cmd := range batches {
484
+ wm , err := cmd .Encode (desc )
485
+ if err != nil {
486
+ return nil , err
487
+ }
488
+
489
+ wms = append (wms , wm )
490
+ }
491
+
492
+ return wms , nil
493
+ }
494
+
495
+ // Roundtrips the write batches, returning the result structs (as interface),
496
+ // the write batches that weren't round tripped and any errors
497
+ func roundTripBatches (
498
+ ctx context.Context ,
499
+ desc description.SelectedServer ,
500
+ rw wiremessage.ReadWriter ,
501
+ batches []* WriteBatch ,
502
+ continueOnError bool ,
503
+ sess * session.Client ,
504
+ cmdKind WriteCommandKind ,
505
+ ) (interface {}, []* WriteBatch , error ) {
506
+ var res interface {}
507
+ var upsertIndex int64 // the operation index for the upserted IDs map
508
+
509
+ // hold onto txnNumber, reset it when loop exits to ensure reuse of same
510
+ // transaction number if retry is needed
511
+ var txnNumber int64
512
+ if sess != nil && sess .RetryWrite {
513
+ txnNumber = sess .TxnNumber
514
+ }
515
+ for j , cmd := range batches {
516
+ rdr , err := cmd .RoundTrip (ctx , desc , rw )
517
+ if err != nil {
518
+ if sess != nil && sess .RetryWrite {
519
+ sess .TxnNumber = txnNumber + int64 (j )
520
+ }
521
+ return res , batches , err
522
+ }
523
+
524
+ // TODO can probably DRY up this code
525
+ switch cmdKind {
526
+ case InsertCommand :
527
+ if res == nil {
528
+ res = result.Insert {}
529
+ }
530
+
531
+ conv , _ := res .(result.Insert )
532
+ insertCmd := & Insert {}
533
+ r , err := insertCmd .decode (desc , rdr ).Result ()
534
+ if err != nil {
535
+ return res , batches , err
536
+ }
537
+
538
+ conv .WriteErrors = append (conv .WriteErrors , r .WriteErrors ... )
539
+
540
+ if r .WriteConcernError != nil {
541
+ conv .WriteConcernError = r .WriteConcernError
542
+ if sess != nil && sess .RetryWrite {
543
+ sess .TxnNumber = txnNumber
544
+ return conv , batches , nil // report writeconcernerror for retry
545
+ }
546
+ }
547
+
548
+ conv .N += r .N
549
+
550
+ if ! continueOnError && len (conv .WriteErrors ) > 0 {
551
+ return conv , batches , nil
552
+ }
553
+
554
+ res = conv
555
+ case UpdateCommand :
556
+ if res == nil {
557
+ res = result.Update {}
558
+ }
559
+
560
+ conv , _ := res .(result.Update )
561
+ updateCmd := & Update {}
562
+ r , err := updateCmd .decode (desc , rdr ).Result ()
563
+ if err != nil {
564
+ return conv , batches , err
565
+ }
566
+
567
+ conv .WriteErrors = append (conv .WriteErrors , r .WriteErrors ... )
568
+
569
+ if r .WriteConcernError != nil {
570
+ conv .WriteConcernError = r .WriteConcernError
571
+ if sess != nil && sess .RetryWrite {
572
+ sess .TxnNumber = txnNumber
573
+ return conv , batches , nil // report writeconcernerror for retry
574
+ }
575
+ }
576
+
577
+ conv .MatchedCount += r .MatchedCount
578
+ conv .ModifiedCount += r .ModifiedCount
579
+ for _ , upsert := range r .Upserted {
580
+ conv .Upserted = append (conv .Upserted , result.Upsert {
581
+ Index : upsert .Index + upsertIndex ,
582
+ ID : upsert .ID ,
583
+ })
584
+ }
585
+
586
+ if ! continueOnError && len (conv .WriteErrors ) > 0 {
587
+ return conv , batches , nil
588
+ }
589
+
590
+ res = conv
591
+ upsertIndex += int64 (cmd .numDocs )
592
+ case DeleteCommand :
593
+ if res == nil {
594
+ res = result.Delete {}
595
+ }
596
+
597
+ conv , _ := res .(result.Delete )
598
+ deleteCmd := & Delete {}
599
+ r , err := deleteCmd .decode (desc , rdr ).Result ()
600
+ if err != nil {
601
+ return conv , batches , err
602
+ }
603
+
604
+ conv .WriteErrors = append (conv .WriteErrors , r .WriteErrors ... )
605
+
606
+ if r .WriteConcernError != nil {
607
+ conv .WriteConcernError = r .WriteConcernError
608
+ if sess != nil && sess .RetryWrite {
609
+ sess .TxnNumber = txnNumber
610
+ return conv , batches , nil // report writeconcernerror for retry
611
+ }
612
+ }
613
+
614
+ conv .N += r .N
615
+
616
+ if ! continueOnError && len (conv .WriteErrors ) > 0 {
617
+ return conv , batches , nil
618
+ }
619
+
620
+ res = conv
621
+ }
622
+
623
+ // Increment txnNumber for each batch
624
+ if sess != nil && sess .RetryWrite {
625
+ sess .IncrementTxnNumber ()
626
+ batches = batches [1 :] // if batch encoded successfully, remove it from the slice
627
+ }
628
+ }
629
+
630
+ if sess != nil && sess .RetryWrite {
631
+ // if retryable write succeeded, transaction number will be incremented one extra time,
632
+ // so we decrement it here
633
+ sess .TxnNumber --
634
+ }
635
+
636
+ return res , batches , nil
637
+ }
638
+
379
639
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
380
640
// write concern.
381
641
var ErrUnacknowledgedWrite = errors .New ("unacknowledged write" )
642
+
643
+ // WriteCommandKind is the type of command represented by a Write
644
+ type WriteCommandKind int8
645
+
646
+ // These constants represent the valid types of write commands.
647
+ const (
648
+ InsertCommand WriteCommandKind = iota
649
+ UpdateCommand
650
+ DeleteCommand
651
+ )
0 commit comments