Skip to content

Commit 97d5b6b

Browse files
committed
WIP
1 parent 63a973b commit 97d5b6b

File tree

4 files changed

+67
-62
lines changed

4 files changed

+67
-62
lines changed

mongo/client_bulk_write.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,16 @@ func (mb *modelBatches) IsOrdered() *bool {
161161

162162
func (mb *modelBatches) AdvanceBatches(n int) {
163163
mb.offset += n
164+
if mb.offset > len(mb.models) {
165+
mb.offset = len(mb.models)
166+
}
164167
}
165168

166-
func (mb *modelBatches) End() bool {
167-
return len(mb.models) <= mb.offset
169+
func (mb *modelBatches) Size() int {
170+
if mb.offset > len(mb.models) {
171+
return 0
172+
}
173+
return len(mb.models) - mb.offset
168174
}
169175

170176
func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
@@ -208,7 +214,7 @@ type functionSet struct {
208214
}
209215

210216
func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
211-
if mb.End() {
217+
if mb.Size() == 0 {
212218
return 0, dst, io.EOF
213219
}
214220

mongo/integration/client_side_encryption_prose_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
470470
cpt.cseStarted = cpt.cseStarted[:0]
471471
_, err = cpt.cseColl.InsertMany(context.Background(), []interface{}{firstBulkDoc, secondBulkDoc})
472472
assert.Nil(mt, err, "InsertMany error for large documents: %v", err)
473-
assert.Equal(mt, 2, len(cpt.cseStarted), "expected 2 insert events, got %d", len(cpt.cseStarted))
473+
assert.Equal(mt, 2, len(cpt.cseStarted), "expected 2 insert events, got %d with size %d %d", len(cpt.cseStarted), len(str), len(limitsDoc))
474474

475475
// insert a document slightly smaller than 16MiB and expect the operation to succeed
476476
doc = bson.D{{"_id", "under_16mib"}, {"unencrypted", complete16mbStr[:maxBsonObjSize-2000]}}

x/mongo/driver/batches.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type Batches struct {
2525
}
2626

2727
func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
28-
if b.End() {
28+
if b.Size() == 0 {
2929
return 0, dst, io.EOF
3030
}
3131
l := len(dst)
@@ -34,7 +34,7 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz
3434
idx, dst = bsoncore.ReserveLength(dst)
3535
dst = append(dst, b.Identifier...)
3636
dst = append(dst, 0x00)
37-
size := len(dst) - l
37+
var size int
3838
var n int
3939
for i := b.offset; i < len(b.Documents); i++ {
4040
if n == maxCount {
@@ -45,7 +45,7 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz
4545
break
4646
}
4747
size += len(doc)
48-
if size >= totalSize {
48+
if size > maxDocSize {
4949
break
5050
}
5151
dst = append(dst, doc...)
@@ -59,12 +59,12 @@ func (b *Batches) AppendBatchSequence(dst []byte, maxCount, maxDocSize, totalSiz
5959
}
6060

6161
func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize int) (int, []byte, error) {
62-
if b.End() {
62+
if b.Size() == 0 {
6363
return 0, dst, io.EOF
6464
}
6565
l := len(dst)
6666
aidx, dst := bsoncore.AppendArrayElementStart(dst, b.Identifier)
67-
size := len(dst) - l
67+
var size int
6868
var n int
6969
for i := b.offset; i < len(b.Documents); i++ {
7070
if n == maxCount {
@@ -75,7 +75,7 @@ func (b *Batches) AppendBatchArray(dst []byte, maxCount, maxDocSize, totalSize i
7575
break
7676
}
7777
size += len(doc)
78-
if size >= totalSize {
78+
if size > totalSize {
7979
break
8080
}
8181
dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(n), doc)
@@ -98,8 +98,14 @@ func (b *Batches) IsOrdered() *bool {
9898

9999
func (b *Batches) AdvanceBatches(n int) {
100100
b.offset += n
101+
if b.offset > len(b.Documents) {
102+
b.offset = len(b.Documents)
103+
}
101104
}
102105

103-
func (b *Batches) End() bool {
104-
return len(b.Documents) <= b.offset
106+
func (b *Batches) Size() int {
107+
if b.offset > len(b.Documents) {
108+
return 0
109+
}
110+
return len(b.Documents) - b.offset
105111
}

x/mongo/driver/operation.go

Lines changed: 43 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ type Operation struct {
284284
AppendBatchArray(dst []byte, maxCount int, maxDocSize int, totalSize int) (int, []byte, error)
285285
IsOrdered() *bool
286286
AdvanceBatches(n int)
287-
End() bool
287+
Size() int
288288
}
289289

290290
// Legacy sets the legacy type for this operation. There are only 3 types that require legacy
@@ -719,8 +719,9 @@ func (op Operation) Execute(ctx context.Context) error {
719719

720720
desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()}
721721

722+
var moreToCome bool
722723
var startedInfo startedInformation
723-
*wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)
724+
*wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)
724725

725726
if err != nil {
726727
return err
@@ -746,9 +747,6 @@ func (op Operation) Execute(ctx context.Context) error {
746747

747748
op.publishStartedEvent(ctx, startedInfo)
748749

749-
// get the moreToCome flag information before we compress
750-
moreToCome := wiremessage.IsMsgMoreToCome(*wm)
751-
752750
// compress wiremessage if allowed
753751
if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) {
754752
b := memoryPool.Get().(*[]byte)
@@ -872,15 +870,14 @@ func (op Operation) Execute(ctx context.Context) error {
872870
// }
873871
}
874872

875-
if op.Batches != nil && len(tt.WriteErrors) > 0 && currIndex > 0 {
876-
for i := range tt.WriteErrors {
877-
tt.WriteErrors[i].Index += int64(currIndex)
878-
}
879-
}
880-
881873
// If batching is enabled and either ordered is the default (which is true) or
882874
// explicitly set to true and we have write errors, return the errors.
883875
if op.Batches != nil && len(tt.WriteErrors) > 0 {
876+
if currIndex > 0 {
877+
for i := range tt.WriteErrors {
878+
tt.WriteErrors[i].Index += int64(currIndex)
879+
}
880+
}
884881
if isOrdered := op.Batches.IsOrdered(); isOrdered == nil || *isOrdered {
885882
return tt
886883
}
@@ -1015,7 +1012,6 @@ func (op Operation) Execute(ctx context.Context) error {
10151012
}
10161013
perr := op.ProcessResponseFn(ctx, res, info)
10171014
if perr != nil {
1018-
fmt.Println("op", perr)
10191015
return perr
10201016
}
10211017
}
@@ -1036,7 +1032,7 @@ func (op Operation) Execute(ctx context.Context) error {
10361032
// If we're batching and there are batches remaining, advance to the next batch. This isn't
10371033
// a retry, so increment the transaction number, reset the retries number, and don't set
10381034
// server or connection to nil to continue using the same connection.
1039-
if op.Batches != nil {
1035+
if op.Batches != nil && op.Batches.Size() > startedInfo.processedBatches {
10401036
// If retries are supported for the current operation on the current server description,
10411037
// the session isn't nil, and client retries are enabled, increment the txn number.
10421038
// Calling IncrementTxnNumber() for server descriptions or topologies that do not
@@ -1053,7 +1049,7 @@ func (op Operation) Execute(ctx context.Context) error {
10531049
}
10541050
currIndex += startedInfo.processedBatches
10551051
op.Batches.AdvanceBatches(startedInfo.processedBatches)
1056-
if !op.Batches.End() {
1052+
if op.Batches.Size() > 0 {
10571053
continue
10581054
}
10591055
}
@@ -1289,21 +1285,11 @@ func (op Operation) createMsgWireMessage(
12891285
cmdFn func([]byte, description.SelectedServer) ([]byte, error),
12901286
) ([]byte, []byte, error) {
12911287
var flags wiremessage.MsgFlag
1292-
// We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
1293-
// aren't batching or we are encoding the last batch.
1294-
var batching bool
1295-
if op.Batches != nil && !op.Batches.End() {
1296-
batching = true
1297-
}
1298-
if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && !batching {
1299-
flags = wiremessage.MoreToCome
1300-
}
13011288
// Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can
13021289
// respond with the MoreToCome flag and then stream responses over this connection.
13031290
if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() {
1304-
flags |= wiremessage.ExhaustAllowed
1291+
flags = wiremessage.ExhaustAllowed
13051292
}
1306-
13071293
dst = wiremessage.AppendMsgFlags(dst, flags)
13081294
// Body
13091295
dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument)
@@ -1365,11 +1351,12 @@ func (op Operation) createWireMessage(
13651351
desc description.SelectedServer,
13661352
conn Connection,
13671353
requestID int32,
1368-
) ([]byte, startedInformation, error) {
1354+
) ([]byte, bool, startedInformation, error) {
13691355
var info startedInformation
13701356
var wmindex int32
13711357
var err error
13721358

1359+
fIdx := len(dst)
13731360
isLegacy := isLegacyHandshake(op, desc)
13741361
shouldEncrypt := op.shouldEncrypt()
13751362
if !isLegacy && !shouldEncrypt {
@@ -1395,23 +1382,11 @@ func (op Operation) createWireMessage(
13951382
}
13961383
} else if shouldEncrypt {
13971384
if desc.WireVersion.Max < cryptMinWireVersion {
1398-
return dst, info, errors.New("auto-encryption requires a MongoDB version of 4.2")
1385+
return dst, false, info, errors.New("auto-encryption requires a MongoDB version of 4.2")
13991386
}
14001387
cmdFn := func(dst []byte, desc description.SelectedServer) ([]byte, error) {
1401-
// create temporary command document
1402-
var cmdDst []byte
1403-
info.processedBatches, cmdDst, err = op.addEncryptCommandFields(nil, desc)
1404-
if err != nil {
1405-
return nil, err
1406-
}
1407-
// encrypt the command
1408-
encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
1409-
if err != nil {
1410-
return nil, err
1411-
}
1412-
// append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
1413-
dst = append(dst, encrypted[4:len(encrypted)-1]...)
1414-
return dst, nil
1388+
info.processedBatches, dst, err = op.addEncryptCommandFields(ctx, dst, desc)
1389+
return dst, err
14151390
}
14161391
wmindex, dst = wiremessage.AppendHeaderStart(dst, requestID, 0, wiremessage.OpMsg)
14171392
dst, info.cmd, err = op.createMsgWireMessage(maxTimeMS, dst, desc, conn, cmdFn)
@@ -1425,32 +1400,43 @@ func (op Operation) createWireMessage(
14251400
dst, info.cmd, err = op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc, cmdFn)
14261401
}
14271402
if err != nil {
1428-
return nil, info, err
1403+
return nil, false, info, err
1404+
}
1405+
1406+
var moreToCome bool
1407+
// We set the MoreToCome bit if we have a write concern, it's unacknowledged, and we either
1408+
// aren't batching or we are encoding the last batch.
1409+
unacknowledged := op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern)
1410+
batching := op.Batches != nil && op.Batches.Size() > info.processedBatches
1411+
if !isLegacy && unacknowledged && !batching {
1412+
dst[fIdx] |= byte(wiremessage.MoreToCome)
1413+
moreToCome = true
14291414
}
14301415
info.requestID = requestID
1431-
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
1416+
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), moreToCome, info, nil
14321417
}
14331418

1434-
func (op Operation) addEncryptCommandFields(dst []byte, desc description.SelectedServer) (int, []byte, error) {
1435-
var idx int32
1436-
idx, dst = bsoncore.AppendDocumentStart(dst)
1419+
func (op Operation) addEncryptCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) (int, []byte, error) {
1420+
idx, cmdDst := bsoncore.AppendDocumentStart(nil)
14371421
var err error
1438-
dst, err = op.CommandFn(dst, desc)
1422+
// create temporary command document
1423+
cmdDst, err = op.CommandFn(cmdDst, desc)
14391424
if err != nil {
14401425
return 0, nil, err
14411426
}
14421427
var n int
14431428
if op.Batches != nil {
14441429
maxBatchCount := int(desc.MaxBatchCount)
14451430
maxDocumentSize := int(desc.MaxDocumentSize)
1431+
fmt.Println("addEncryptCommandFields", cryptMaxBsonObjectSize, maxDocumentSize)
14461432
if maxBatchCount > 1 {
1447-
n, dst, err = op.Batches.AppendBatchArray(dst, maxBatchCount, cryptMaxBsonObjectSize, maxDocumentSize)
1433+
n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, maxBatchCount, cryptMaxBsonObjectSize, maxDocumentSize)
14481434
if err != nil {
14491435
return 0, nil, err
14501436
}
14511437
}
14521438
if n == 0 {
1453-
n, dst, err = op.Batches.AppendBatchArray(dst, 1, maxDocumentSize, maxDocumentSize)
1439+
n, cmdDst, err = op.Batches.AppendBatchArray(cmdDst, 1, maxDocumentSize, maxDocumentSize)
14541440
if err != nil {
14551441
return 0, nil, err
14561442
}
@@ -1459,10 +1445,17 @@ func (op Operation) addEncryptCommandFields(dst []byte, desc description.Selecte
14591445
}
14601446
}
14611447
}
1462-
dst, err = bsoncore.AppendDocumentEnd(dst, idx)
1448+
cmdDst, err = bsoncore.AppendDocumentEnd(cmdDst, idx)
1449+
if err != nil {
1450+
return 0, nil, err
1451+
}
1452+
// encrypt the command
1453+
encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
14631454
if err != nil {
14641455
return 0, nil, err
14651456
}
1457+
// append encrypted command to original destination, removing the first 4 bytes (length) and final byte (terminator)
1458+
dst = append(dst, encrypted[4:len(encrypted)-1]...)
14661459
return n, dst, nil
14671460
}
14681461

0 commit comments

Comments
 (0)