Skip to content

Commit b8309ff

Browse files
committed
WIP
1 parent 45d22f8 commit b8309ff

File tree

9 files changed

+191
-184
lines changed

9 files changed

+191
-184
lines changed

mongo/client_bulk_write.go

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727

2828
// bulkWrite performs a bulkwrite operation
2929
type clientBulkWrite struct {
30-
models []interface{}
30+
models []clientWriteModel
3131
errorsOnly bool
3232
ordered *bool
3333
bypassDocumentValidation *bool
@@ -45,12 +45,17 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
4545
if len(bw.models) == 0 {
4646
return errors.New("empty write models")
4747
}
48+
retryMode := driver.RetryNone
49+
if bw.client.retryWrites {
50+
retryMode = driver.RetryOncePerCommand
51+
}
4852
batches := &modelBatches{
49-
session: bw.session,
50-
client: bw.client,
51-
ordered: bw.ordered,
52-
models: bw.models,
53-
result: &bw.result,
53+
session: bw.session,
54+
client: bw.client,
55+
ordered: bw.ordered,
56+
models: bw.models,
57+
result: &bw.result,
58+
retryMode: retryMode,
5459
}
5560
err := driver.Operation{
5661
CommandFn: bw.newCommand(),
@@ -142,7 +147,7 @@ type modelBatches struct {
142147
client *Client
143148

144149
ordered *bool
145-
models []interface{}
150+
models []clientWriteModel
146151

147152
offset int
148153

@@ -222,17 +227,14 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
222227
mb.newIDMap = make(map[int]interface{})
223228

224229
nsMap := make(map[string]int)
225-
getNsIndex := func(namespace string) (int, bsoncore.Document) {
226-
idx, doc := bsoncore.AppendDocumentStart(nil)
227-
doc = bsoncore.AppendStringElement(doc, "ns", namespace)
228-
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
229-
230-
if v, ok := nsMap[namespace]; ok {
231-
return v, doc
230+
getNsIndex := func(namespace string) (int, bool) {
231+
v, ok := nsMap[namespace]
232+
if ok {
233+
return v, ok
232234
}
233235
nsIdx := len(nsMap)
234236
nsMap[namespace] = nsIdx
235-
return nsIdx, doc
237+
return nsIdx, ok
236238
}
237239

238240
canRetry := true
@@ -249,12 +251,13 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
249251
break
250252
}
251253

252-
var nsIdx int
253-
var ns, doc bsoncore.Document
254+
ns := mb.models[i].namespace
255+
nsIdx, exists := getNsIndex(ns)
256+
257+
var doc bsoncore.Document
254258
var err error
255-
switch model := mb.models[i].(type) {
259+
switch model := mb.models[i].model.(type) {
256260
case *ClientInsertOneModel:
257-
nsIdx, ns = getNsIndex(model.Namespace)
258261
mb.cursorHandlers[i] = mb.appendInsertResult
259262
var id interface{}
260263
id, doc, err = (&clientInsertDoc{
@@ -266,7 +269,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
266269
}
267270
mb.newIDMap[i] = id
268271
case *ClientUpdateOneModel:
269-
nsIdx, ns = getNsIndex(model.Namespace)
270272
mb.cursorHandlers[i] = mb.appendUpdateResult
271273
doc, err = (&clientUpdateDoc{
272274
namespace: nsIdx,
@@ -281,7 +283,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
281283
}).marshal(mb.client.bsonOpts, mb.client.registry)
282284
case *ClientUpdateManyModel:
283285
canRetry = false
284-
nsIdx, ns = getNsIndex(model.Namespace)
285286
mb.cursorHandlers[i] = mb.appendUpdateResult
286287
doc, err = (&clientUpdateDoc{
287288
namespace: nsIdx,
@@ -295,7 +296,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
295296
checkDollarKey: true,
296297
}).marshal(mb.client.bsonOpts, mb.client.registry)
297298
case *ClientReplaceOneModel:
298-
nsIdx, ns = getNsIndex(model.Namespace)
299299
mb.cursorHandlers[i] = mb.appendUpdateResult
300300
doc, err = (&clientUpdateDoc{
301301
namespace: nsIdx,
@@ -309,7 +309,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
309309
checkDollarKey: false,
310310
}).marshal(mb.client.bsonOpts, mb.client.registry)
311311
case *ClientDeleteOneModel:
312-
nsIdx, ns = getNsIndex(model.Namespace)
313312
mb.cursorHandlers[i] = mb.appendDeleteResult
314313
doc, err = (&clientDeleteDoc{
315314
namespace: nsIdx,
@@ -320,7 +319,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
320319
}).marshal(mb.client.bsonOpts, mb.client.registry)
321320
case *ClientDeleteManyModel:
322321
canRetry = false
323-
nsIdx, ns = getNsIndex(model.Namespace)
324322
mb.cursorHandlers[i] = mb.appendDeleteResult
325323
doc, err = (&clientDeleteDoc{
326324
namespace: nsIdx,
@@ -343,7 +341,12 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
343341
}
344342

345343
dst = fn.appendDocument(dst, strconv.Itoa(n), doc)
346-
nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), ns)
344+
if !exists {
345+
idx, doc := bsoncore.AppendDocumentStart(nil)
346+
doc = bsoncore.AppendStringElement(doc, "ns", ns)
347+
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
348+
nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), doc)
349+
}
347350
n++
348351
}
349352
if n == 0 {
@@ -430,7 +433,7 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
430433
if int(cur.Idx) >= len(mb.cursorHandlers) {
431434
continue
432435
}
433-
ok = ok && mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current)
436+
ok = mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current) && ok
434437
}
435438
err = cursor.Err()
436439
if err != nil {
@@ -456,32 +459,51 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
456459
}
457460

458461
func (mb *modelBatches) appendDeleteResult(cur *cursorInfo, raw bson.Raw) bool {
462+
if err := cur.extractError(); err != nil {
463+
err.Raw = raw
464+
if mb.writeErrors == nil {
465+
mb.writeErrors = make(map[int]WriteError)
466+
}
467+
mb.writeErrors[int(cur.Idx)] = *err
468+
return false
469+
}
470+
459471
if mb.result.DeleteResults == nil {
460472
mb.result.DeleteResults = make(map[int]ClientDeleteResult)
461473
}
462474
mb.result.DeleteResults[int(cur.Idx)] = ClientDeleteResult{int64(cur.N)}
475+
476+
return true
477+
}
478+
479+
func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool {
463480
if err := cur.extractError(); err != nil {
464481
err.Raw = raw
482+
if mb.writeErrors == nil {
483+
mb.writeErrors = make(map[int]WriteError)
484+
}
465485
mb.writeErrors[int(cur.Idx)] = *err
466486
return false
467487
}
468-
return true
469-
}
470488

471-
func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool {
472489
if mb.result.InsertResults == nil {
473490
mb.result.InsertResults = make(map[int]ClientInsertResult)
474491
}
475492
mb.result.InsertResults[int(cur.Idx)] = ClientInsertResult{mb.newIDMap[int(cur.Idx)]}
493+
494+
return true
495+
}
496+
497+
func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool {
476498
if err := cur.extractError(); err != nil {
477499
err.Raw = raw
500+
if mb.writeErrors == nil {
501+
mb.writeErrors = make(map[int]WriteError)
502+
}
478503
mb.writeErrors[int(cur.Idx)] = *err
479504
return false
480505
}
481-
return true
482-
}
483506

484-
func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool {
485507
if mb.result.UpdateResults == nil {
486508
mb.result.UpdateResults = make(map[int]ClientUpdateResult)
487509
}
@@ -495,11 +517,7 @@ func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool {
495517
result.UpsertedID = cur.Upserted.ID
496518
}
497519
mb.result.UpdateResults[int(cur.Idx)] = result
498-
if err := cur.extractError(); err != nil {
499-
err.Raw = raw
500-
mb.writeErrors[int(cur.Idx)] = *err
501-
return false
502-
}
520+
503521
return true
504522
}
505523

0 commit comments

Comments
 (0)