@@ -36,6 +36,7 @@ type Insert struct {
36
36
WriteConcern * writeconcern.WriteConcern
37
37
Session * session.Client
38
38
39
+ batches []* Write
39
40
result result.Insert
40
41
err error
41
42
continueOnError bool
@@ -89,13 +90,13 @@ splitInserts:
89
90
90
91
// Encode will encode this command into a wire message for the given server description.
91
92
func (i * Insert ) Encode (desc description.SelectedServer ) ([]wiremessage.WireMessage , error ) {
92
- cmds , err := i .encode (desc )
93
+ err := i .encode (desc )
93
94
if err != nil {
94
95
return nil , err
95
96
}
96
97
97
- wms := make ([]wiremessage.WireMessage , len (cmds ))
98
- for _ , cmd := range cmds {
98
+ wms := make ([]wiremessage.WireMessage , len (i . batches ))
99
+ for _ , cmd := range i . batches {
99
100
wm , err := cmd .Encode (desc )
100
101
if err != nil {
101
102
return nil , err
@@ -143,23 +144,21 @@ func (i *Insert) encodeBatch(docs []*bson.Document, desc description.SelectedSer
143
144
}, nil
144
145
}
145
146
146
- func (i * Insert ) encode (desc description.SelectedServer ) ([]* Write , error ) {
147
- out := []* Write {}
147
+ func (i * Insert ) encode (desc description.SelectedServer ) error {
148
148
batches , err := i .split (int (desc .MaxBatchCount ), int (desc .MaxDocumentSize ))
149
149
if err != nil {
150
- return nil , err
150
+ return err
151
151
}
152
152
153
153
for _ , docs := range batches {
154
154
cmd , err := i .encodeBatch (docs , desc )
155
155
if err != nil {
156
- return nil , err
156
+ return err
157
157
}
158
158
159
- out = append (out , cmd )
159
+ i . batches = append (i . batches , cmd )
160
160
}
161
-
162
- return out , nil
161
+ return nil
163
162
}
164
163
165
164
// Decode will decode the wire message using the provided server description. Errors during decoding
@@ -193,14 +192,25 @@ func (i *Insert) Err() error { return i.err }
193
192
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
194
193
func (i * Insert ) RoundTrip (ctx context.Context , desc description.SelectedServer , rw wiremessage.ReadWriter ) (result.Insert , error ) {
195
194
res := result.Insert {}
196
- cmds , err := i .encode (desc )
197
- if err != nil {
198
- return res , err
195
+ if i .batches == nil {
196
+ err := i .encode (desc )
197
+ if err != nil {
198
+ return res , err
199
+ }
199
200
}
200
201
201
- for _ , cmd := range cmds {
202
+ // hold onto txnNumber, reset it when loop exits to ensure reuse of same
203
+ // transaction number if retry is needed
204
+ var txnNumber int64
205
+ if i .Session != nil && i .Session .RetryWrite {
206
+ txnNumber = i .Session .TxnNumber
207
+ }
208
+ for j , cmd := range i .batches {
202
209
rdr , err := cmd .RoundTrip (ctx , desc , rw )
203
210
if err != nil {
211
+ if i .Session != nil && i .Session .RetryWrite {
212
+ i .Session .TxnNumber = txnNumber + int64 (j )
213
+ }
204
214
return res , err
205
215
}
206
216
@@ -213,13 +223,29 @@ func (i *Insert) RoundTrip(ctx context.Context, desc description.SelectedServer,
213
223
214
224
if r .WriteConcernError != nil {
215
225
res .WriteConcernError = r .WriteConcernError
226
+ if i .Session != nil && i .Session .RetryWrite {
227
+ i .Session .TxnNumber = txnNumber
228
+ return res , nil // report writeconcernerror for retry
229
+ }
216
230
}
217
231
218
232
res .N += r .N
219
233
220
234
if ! i .continueOnError && len (res .WriteErrors ) > 0 {
221
235
return res , nil
222
236
}
237
+
238
+ // Increment txnNumber for each batch
239
+ if i .Session != nil && i .Session .RetryWrite {
240
+ i .Session .IncrementTxnNumber ()
241
+ i .batches = i .batches [1 :] // if batch encoded successfully, remove it from the slice
242
+ }
243
+ }
244
+
245
+ if i .Session != nil && i .Session .RetryWrite {
246
+ // if retryable write succeeded, transaction number will be incremented one extra time,
247
+ // so we decrement it here
248
+ i .Session .TxnNumber --
223
249
}
224
250
225
251
return res , nil
0 commit comments