Skip to content

Commit 4cb519e

Browse files
author
Divjot Arora
committed
Implement BulkWrite API
GODRIVER-123 Change-Id: Ica3dd7322d93efecaa953c6ff691dace6e184088
1 parent b5a8fcf commit 4cb519e

26 files changed

+3548
-270
lines changed

core/command/command.go

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,24 @@ package command
99
import (
1010
"errors"
1111

12+
"context"
13+
1214
"github.com/mongodb/mongo-go-driver/bson"
1315
"github.com/mongodb/mongo-go-driver/core/description"
16+
"github.com/mongodb/mongo-go-driver/core/option"
1417
"github.com/mongodb/mongo-go-driver/core/readconcern"
18+
"github.com/mongodb/mongo-go-driver/core/result"
1519
"github.com/mongodb/mongo-go-driver/core/session"
1620
"github.com/mongodb/mongo-go-driver/core/wiremessage"
1721
"github.com/mongodb/mongo-go-driver/core/writeconcern"
1822
)
1923

24+
// WriteBatch represents a single batch for a write operation.
25+
type WriteBatch struct {
26+
*Write
27+
numDocs int
28+
}
29+
2030
// DecodeError attempts to decode the wiremessage as an error
2131
func DecodeError(wm wiremessage.WireMessage) error {
2232
var rdr bson.Reader
@@ -376,6 +386,266 @@ func opmsgCreateDocSequence(arr *bson.Array, identifier string) (wiremessage.Sec
376386
return docSequence, nil
377387
}
378388

389+
func splitBatches(docs []*bson.Document, maxCount, targetBatchSize int) ([][]*bson.Document, error) {
390+
batches := [][]*bson.Document{}
391+
392+
if targetBatchSize > reservedCommandBufferBytes {
393+
targetBatchSize -= reservedCommandBufferBytes
394+
}
395+
396+
if maxCount <= 0 {
397+
maxCount = 1
398+
}
399+
400+
startAt := 0
401+
splitInserts:
402+
for {
403+
size := 0
404+
batch := []*bson.Document{}
405+
assembleBatch:
406+
for idx := startAt; idx < len(docs); idx++ {
407+
itsize, err := docs[idx].Validate()
408+
if err != nil {
409+
return nil, err
410+
}
411+
412+
if int(itsize) > targetBatchSize {
413+
return nil, ErrDocumentTooLarge
414+
}
415+
if size+int(itsize) > targetBatchSize {
416+
break assembleBatch
417+
}
418+
419+
size += int(itsize)
420+
batch = append(batch, docs[idx])
421+
startAt++
422+
if len(batch) == maxCount {
423+
break assembleBatch
424+
}
425+
}
426+
batches = append(batches, batch)
427+
if startAt == len(docs) {
428+
break splitInserts
429+
}
430+
}
431+
432+
return batches, nil
433+
}
434+
435+
func encodeBatch(
436+
docs []*bson.Document,
437+
opts []option.Optioner,
438+
cmdKind WriteCommandKind,
439+
collName string,
440+
) (*bson.Document, error) {
441+
var cmdName string
442+
var docString string
443+
444+
switch cmdKind {
445+
case InsertCommand:
446+
cmdName = "insert"
447+
docString = "documents"
448+
case UpdateCommand:
449+
cmdName = "update"
450+
docString = "updates"
451+
case DeleteCommand:
452+
cmdName = "delete"
453+
docString = "deletes"
454+
}
455+
456+
cmd := bson.NewDocument(
457+
bson.EC.String(cmdName, collName),
458+
)
459+
460+
vals := make([]*bson.Value, 0, len(docs))
461+
for _, doc := range docs {
462+
vals = append(vals, bson.VC.Document(doc))
463+
}
464+
cmd.Append(bson.EC.ArrayFromElements(docString, vals...))
465+
466+
for _, opt := range opts {
467+
if opt == nil {
468+
continue
469+
}
470+
471+
err := opt.Option(cmd)
472+
if err != nil {
473+
return nil, err
474+
}
475+
}
476+
477+
return cmd, nil
478+
}
479+
480+
// converts batches of Write Commands to wire messages
481+
func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
482+
wms := make([]wiremessage.WireMessage, len(batches))
483+
for _, cmd := range batches {
484+
wm, err := cmd.Encode(desc)
485+
if err != nil {
486+
return nil, err
487+
}
488+
489+
wms = append(wms, wm)
490+
}
491+
492+
return wms, nil
493+
}
494+
495+
// Roundtrips the write batches, returning the result structs (as interface),
496+
// the write batches that weren't round tripped and any errors
497+
func roundTripBatches(
498+
ctx context.Context,
499+
desc description.SelectedServer,
500+
rw wiremessage.ReadWriter,
501+
batches []*WriteBatch,
502+
continueOnError bool,
503+
sess *session.Client,
504+
cmdKind WriteCommandKind,
505+
) (interface{}, []*WriteBatch, error) {
506+
var res interface{}
507+
var upsertIndex int64 // the operation index for the upserted IDs map
508+
509+
// hold onto txnNumber, reset it when loop exits to ensure reuse of same
510+
// transaction number if retry is needed
511+
var txnNumber int64
512+
if sess != nil && sess.RetryWrite {
513+
txnNumber = sess.TxnNumber
514+
}
515+
for j, cmd := range batches {
516+
rdr, err := cmd.RoundTrip(ctx, desc, rw)
517+
if err != nil {
518+
if sess != nil && sess.RetryWrite {
519+
sess.TxnNumber = txnNumber + int64(j)
520+
}
521+
return res, batches, err
522+
}
523+
524+
// TODO can probably DRY up this code
525+
switch cmdKind {
526+
case InsertCommand:
527+
if res == nil {
528+
res = result.Insert{}
529+
}
530+
531+
conv, _ := res.(result.Insert)
532+
insertCmd := &Insert{}
533+
r, err := insertCmd.decode(desc, rdr).Result()
534+
if err != nil {
535+
return res, batches, err
536+
}
537+
538+
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
539+
540+
if r.WriteConcernError != nil {
541+
conv.WriteConcernError = r.WriteConcernError
542+
if sess != nil && sess.RetryWrite {
543+
sess.TxnNumber = txnNumber
544+
return conv, batches, nil // report writeconcernerror for retry
545+
}
546+
}
547+
548+
conv.N += r.N
549+
550+
if !continueOnError && len(conv.WriteErrors) > 0 {
551+
return conv, batches, nil
552+
}
553+
554+
res = conv
555+
case UpdateCommand:
556+
if res == nil {
557+
res = result.Update{}
558+
}
559+
560+
conv, _ := res.(result.Update)
561+
updateCmd := &Update{}
562+
r, err := updateCmd.decode(desc, rdr).Result()
563+
if err != nil {
564+
return conv, batches, err
565+
}
566+
567+
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
568+
569+
if r.WriteConcernError != nil {
570+
conv.WriteConcernError = r.WriteConcernError
571+
if sess != nil && sess.RetryWrite {
572+
sess.TxnNumber = txnNumber
573+
return conv, batches, nil // report writeconcernerror for retry
574+
}
575+
}
576+
577+
conv.MatchedCount += r.MatchedCount
578+
conv.ModifiedCount += r.ModifiedCount
579+
for _, upsert := range r.Upserted {
580+
conv.Upserted = append(conv.Upserted, result.Upsert{
581+
Index: upsert.Index + upsertIndex,
582+
ID: upsert.ID,
583+
})
584+
}
585+
586+
if !continueOnError && len(conv.WriteErrors) > 0 {
587+
return conv, batches, nil
588+
}
589+
590+
res = conv
591+
upsertIndex += int64(cmd.numDocs)
592+
case DeleteCommand:
593+
if res == nil {
594+
res = result.Delete{}
595+
}
596+
597+
conv, _ := res.(result.Delete)
598+
deleteCmd := &Delete{}
599+
r, err := deleteCmd.decode(desc, rdr).Result()
600+
if err != nil {
601+
return conv, batches, err
602+
}
603+
604+
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
605+
606+
if r.WriteConcernError != nil {
607+
conv.WriteConcernError = r.WriteConcernError
608+
if sess != nil && sess.RetryWrite {
609+
sess.TxnNumber = txnNumber
610+
return conv, batches, nil // report writeconcernerror for retry
611+
}
612+
}
613+
614+
conv.N += r.N
615+
616+
if !continueOnError && len(conv.WriteErrors) > 0 {
617+
return conv, batches, nil
618+
}
619+
620+
res = conv
621+
}
622+
623+
// Increment txnNumber for each batch
624+
if sess != nil && sess.RetryWrite {
625+
sess.IncrementTxnNumber()
626+
batches = batches[1:] // if batch encoded successfully, remove it from the slice
627+
}
628+
}
629+
630+
if sess != nil && sess.RetryWrite {
631+
// if retryable write succeeded, transaction number will be incremented one extra time,
632+
// so we decrement it here
633+
sess.TxnNumber--
634+
}
635+
636+
return res, batches, nil
637+
}
638+
379639
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
380640
// write concern.
381641
var ErrUnacknowledgedWrite = errors.New("unacknowledged write")
642+
643+
// WriteCommandKind is the type of command represented by a Write
644+
type WriteCommandKind int8
645+
646+
// These constants represent the valid types of write commands.
647+
const (
648+
InsertCommand WriteCommandKind = iota
649+
UpdateCommand
650+
DeleteCommand
651+
)

0 commit comments

Comments
 (0)