@@ -240,7 +240,6 @@ func (cs *ChangeStream) createOperationDeployment(server driver.Server, connecti
240
240
func (cs * ChangeStream ) executeOperation (ctx context.Context , resuming bool ) error {
241
241
var server driver.Server
242
242
var conn driver.Connection
243
- var err error
244
243
245
244
if server , cs .err = cs .client .deployment .SelectServer (ctx , cs .selector ); cs .err != nil {
246
245
return cs .Err ()
@@ -284,48 +283,65 @@ func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) err
284
283
// Cancel the timeout-derived context at the end of executeOperation to avoid a context leak.
285
284
defer cancelFunc ()
286
285
}
287
- if original := cs .aggregate .Execute (ctx ); original != nil {
288
- retryableRead := cs .client .retryReads && cs .wireVersion != nil && cs .wireVersion .Max >= 6
289
- if ! retryableRead {
290
- cs .err = replaceErrors (original )
291
- return cs .err
286
+
287
+ // Execute the aggregate, retrying on retryable errors once (1) if retryable reads are enabled and
288
+ // infinitely (-1) if context is a Timeout context.
289
+ var retries int
290
+ if cs .client .retryReads && cs .wireVersion != nil && cs .wireVersion .Max >= 6 {
291
+ retries = 1
292
+ }
293
+ if internal .IsTimeoutContext (ctx ) {
294
+ retries = - 1
295
+ }
296
+
297
+ var err error
298
+ AggregateExecuteLoop:
299
+ for {
300
+ err = cs .aggregate .Execute (ctx )
301
+ // If no error or no retries remain, do not retry.
302
+ if err == nil || retries == 0 {
303
+ break AggregateExecuteLoop
292
304
}
293
305
294
- cs .err = original
295
- switch tt := original .(type ) {
306
+ switch tt := err .(type ) {
296
307
case driver.Error :
308
+ // If error is not retryable, do not retry.
297
309
if ! tt .RetryableRead () {
298
- break
310
+ break AggregateExecuteLoop
299
311
}
300
312
313
+ // If error is retryable: decrement retries, redo server selection, check out
314
+ // a new connection, and restart the loop.
315
+ retries --
301
316
server , err = cs .client .deployment .SelectServer (ctx , cs .selector )
302
317
if err != nil {
303
- break
318
+ break AggregateExecuteLoop
304
319
}
305
320
306
321
conn .Close ()
307
322
conn , err = server .Connection (ctx )
308
323
if err != nil {
309
- break
324
+ break AggregateExecuteLoop
310
325
}
311
326
defer conn .Close ()
312
- cs .wireVersion = conn .Description ().WireVersion
313
327
328
+ // If wire version is now < 6, do not retry.
329
+ cs .wireVersion = conn .Description ().WireVersion
314
330
if cs .wireVersion == nil || cs .wireVersion .Max < 6 {
315
- break
331
+ break AggregateExecuteLoop
316
332
}
317
333
334
+ // Rebuild the aggregate's deployment from the newly selected server and connection.
318
335
cs .aggregate .Deployment (cs .createOperationDeployment (server , conn ))
319
- cs .err = cs .aggregate .Execute (ctx )
336
+ default :
337
+ // Do not retry if error is not a driver error.
338
+ break AggregateExecuteLoop
320
339
}
321
-
322
- if cs .err != nil {
323
- cs .err = replaceErrors (cs .err )
324
- return cs .Err ()
325
- }
326
-
327
340
}
328
- cs .err = nil
341
+ if err != nil {
342
+ cs .err = replaceErrors (err )
343
+ return cs .err
344
+ }
329
345
330
346
cr := cs .aggregate .ResultCursorResponse ()
331
347
cr .Server = server
0 commit comments