Skip to content

Commit e131124

Browse files
committed
GODRIVER-124: batch splitting implementation
Change-Id: I4371a92c1cd74769dd25ffac272da5ca018afa1d
1 parent 17dab82 commit e131124

File tree

5 files changed

+300
-36
lines changed

5 files changed

+300
-36
lines changed

core/command/insert.go

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ import (
1616
"github.com/mongodb/mongo-go-driver/core/wiremessage"
1717
)
1818

19+
// this is the amount of reserved buffer space in a message that the
20+
// driver reserves for command overhead.
21+
const reservedCommandBufferBytes = 16 * 10 * 10 * 10
22+
1923
// Insert represents the insert command.
2024
//
2125
// The insert command inserts a set of documents into the database.
@@ -27,24 +31,76 @@ type Insert struct {
2731
Docs []*bson.Document
2832
Opts []option.InsertOptioner
2933

30-
result result.Insert
31-
err error
34+
result result.Insert
35+
err error
36+
continueOnError bool
3237
}
3338

34-
// Encode will encode this command into a wire message for the given server description.
35-
func (i *Insert) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
39+
func (i *Insert) split(maxCount, targetBatchSize int) ([][]*bson.Document, error) {
40+
batches := [][]*bson.Document{}
41+
42+
if targetBatchSize > reservedCommandBufferBytes {
43+
targetBatchSize -= reservedCommandBufferBytes
44+
}
45+
46+
if maxCount <= 0 {
47+
maxCount = 1
48+
}
49+
50+
startAt := 0
51+
splitInserts:
52+
for {
53+
size := 0
54+
batch := []*bson.Document{}
55+
assembleBatch:
56+
for idx := startAt; idx < len(i.Docs); idx++ {
57+
itsize, err := i.Docs[idx].Validate()
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
if size+int(itsize) > targetBatchSize {
63+
break assembleBatch
64+
}
65+
66+
size += int(itsize)
67+
batch = append(batch, i.Docs[idx])
68+
startAt++
69+
if len(batch) == maxCount {
70+
break assembleBatch
71+
}
72+
}
73+
batches = append(batches, batch)
74+
if startAt == len(i.Docs) {
75+
break splitInserts
76+
}
77+
}
78+
79+
return batches, nil
80+
}
81+
82+
func (i *Insert) encodeBatch(docs []*bson.Document, desc description.SelectedServer) (wiremessage.WireMessage, error) {
83+
3684
command := bson.NewDocument(bson.EC.String("insert", i.NS.Collection))
37-
vals := make([]*bson.Value, 0, len(i.Docs))
38-
for _, doc := range i.Docs {
85+
86+
vals := make([]*bson.Value, 0, len(docs))
87+
for _, doc := range docs {
3988
vals = append(vals, bson.VC.Document(doc))
4089
}
4190
command.Append(bson.EC.ArrayFromElements("documents", vals...))
4291

43-
for _, option := range i.Opts {
44-
if option == nil {
92+
for _, opt := range i.Opts {
93+
if opt == nil {
4594
continue
4695
}
47-
err := option.Option(command)
96+
97+
if ordered, ok := opt.(option.OptOrdered); ok {
98+
if !ordered {
99+
i.continueOnError = true
100+
}
101+
}
102+
103+
err := opt.Option(command)
48104
if err != nil {
49105
return nil, err
50106
}
@@ -53,6 +109,26 @@ func (i *Insert) Encode(desc description.SelectedServer) (wiremessage.WireMessag
53109
return (&Command{DB: i.NS.DB, Command: command, isWrite: true}).Encode(desc)
54110
}
55111

112+
// Encode will encode this command into a wire message for the given server description.
113+
func (i *Insert) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
114+
out := []wiremessage.WireMessage{}
115+
batches, err := i.split(int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
116+
if err != nil {
117+
return nil, err
118+
}
119+
120+
for _, docs := range batches {
121+
cmd, err := i.encodeBatch(docs, desc)
122+
if err != nil {
123+
return nil, err
124+
}
125+
126+
out = append(out, cmd)
127+
}
128+
129+
return out, nil
130+
}
131+
56132
// Decode will decode the wire message using the provided server description. Errors during decoding
57133
// are deferred until either the Result or Err methods are called.
58134
func (i *Insert) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Insert {
@@ -79,18 +155,40 @@ func (i *Insert) Err() error { return i.err }
79155

80156
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
81157
func (i *Insert) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.Insert, error) {
82-
wm, err := i.Encode(desc)
83-
if err != nil {
84-
return result.Insert{}, err
85-
}
158+
res := result.Insert{}
86159

87-
err = rw.WriteWireMessage(ctx, wm)
160+
wms, err := i.Encode(desc)
88161
if err != nil {
89-
return result.Insert{}, err
162+
return res, err
90163
}
91-
wm, err = rw.ReadWireMessage(ctx)
92-
if err != nil {
93-
return result.Insert{}, err
164+
165+
for _, wm := range wms {
166+
err = rw.WriteWireMessage(ctx, wm)
167+
if err != nil {
168+
return res, err
169+
}
170+
wm, err = rw.ReadWireMessage(ctx)
171+
if err != nil {
172+
return res, err
173+
}
174+
175+
r, err := i.Decode(desc, wm).Result()
176+
if err != nil {
177+
return res, err
178+
}
179+
180+
res.WriteErrors = append(res.WriteErrors, r.WriteErrors...)
181+
182+
if r.WriteConcernError != nil {
183+
res.WriteConcernError = r.WriteConcernError
184+
}
185+
186+
if !i.continueOnError && len(res.WriteErrors) > 0 {
187+
return res, nil
188+
}
189+
190+
res.N += r.N
94191
}
95-
return i.Decode(desc, wm).Result()
192+
193+
return res, nil
96194
}

core/command/insert_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package command
2+
3+
import (
4+
"testing"
5+
6+
"github.com/mongodb/mongo-go-driver/bson"
7+
"github.com/mongodb/mongo-go-driver/core/description"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestInsertCommandSplitting(t *testing.T) {
12+
const (
13+
megabyte = 10 * 10 * 10 * 10 * 10 * 10
14+
kilobyte = 10 * 10 * 10
15+
)
16+
17+
ss := description.SelectedServer{}
18+
t.Run("split_smoke_test", func(t *testing.T) {
19+
i := &Insert{}
20+
for n := 0; n < 100; n++ {
21+
i.Docs = append(i.Docs, bson.NewDocument(bson.EC.Int32("a", int32(n))))
22+
}
23+
24+
batches, err := i.split(10, kilobyte) // 1kb
25+
assert.NoError(t, err)
26+
assert.Len(t, batches, 10)
27+
for _, b := range batches {
28+
assert.Len(t, b, 10)
29+
wm, err := i.encodeBatch(b, ss)
30+
assert.NoError(t, err)
31+
assert.True(t, wm.Len() < 16*megabyte)
32+
}
33+
})
34+
t.Run("split_with_small_target_Size", func(t *testing.T) {
35+
i := &Insert{}
36+
for n := 0; n < 100; n++ {
37+
i.Docs = append(i.Docs, bson.NewDocument(bson.EC.Int32("a", int32(n))))
38+
}
39+
40+
batches, err := i.split(100, 32) // 32 bytes?
41+
assert.NoError(t, err)
42+
assert.Len(t, batches, 50)
43+
for _, b := range batches {
44+
assert.Len(t, b, 2)
45+
wm, err := i.encodeBatch(b, ss)
46+
assert.NoError(t, err)
47+
assert.True(t, wm.Len() < 16*megabyte)
48+
}
49+
})
50+
t.Run("invalid_max_counts", func(t *testing.T) {
51+
i := &Insert{}
52+
for n := 0; n < 100; n++ {
53+
i.Docs = append(i.Docs, bson.NewDocument(bson.EC.Int32("a", int32(n))))
54+
}
55+
56+
for _, ct := range []int{-1, 0, -1000} {
57+
batches, err := i.split(ct, 100*megabyte)
58+
assert.NoError(t, err)
59+
assert.Len(t, batches, 100)
60+
for _, b := range batches {
61+
assert.Len(t, b, 1)
62+
wm, err := i.encodeBatch(b, ss)
63+
assert.NoError(t, err)
64+
assert.True(t, wm.Len() < 16*megabyte)
65+
}
66+
}
67+
68+
})
69+
}

core/description/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ type Server struct {
3939
LastError error
4040
LastUpdateTime time.Time
4141
LastWriteTime time.Time
42-
MaxBatchCount uint16
42+
MaxBatchCount uint32
4343
MaxDocumentSize uint32
4444
MaxMessageSize uint32
4545
Members []address.Address

core/result/result.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ type IsMaster struct {
9090
LastWriteTimestamp time.Time `bson:"lastWriteDate,omitempty"`
9191
MaxBSONObjectSize uint32 `bson:"maxBsonObjectSize,omitempty"`
9292
MaxMessageSizeBytes uint32 `bson:"maxMessageSizeBytes,omitempty"`
93-
MaxWriteBatchSize uint16 `bson:"maxWriteBatchSize,omitempty"`
93+
MaxWriteBatchSize uint32 `bson:"maxWriteBatchSize,omitempty"`
9494
Me string `bson:"me,omitempty"`
9595
MaxWireVersion int32 `bson:"maxWireVersion,omitempty"`
9696
MinWireVersion int32 `bson:"minWireVersion,omitempty"`

0 commit comments

Comments
 (0)