@@ -243,16 +243,16 @@ func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
243
243
conn , err := s .pool .get (ctx )
244
244
if err != nil {
245
245
s .sem .Release (1 )
246
- connErr , ok := err .( ConnectionError )
247
- if ! ok {
246
+ wrappedConnErr := unwrapConnectionError ( err )
247
+ if wrappedConnErr == nil {
248
248
return nil , err
249
249
}
250
250
251
251
// Since the only kind of ConnectionError we receive from pool.Get will be an initialization
252
252
// error, we should set the description.Server appropriately.
253
253
desc := description.Server {
254
254
Kind : description .Unknown ,
255
- LastError : connErr . Wrapped ,
255
+ LastError : wrappedConnErr ,
256
256
}
257
257
s .updateDescription (desc , false )
258
258
@@ -317,8 +317,9 @@ func (s *Server) RequestImmediateCheck() {
317
317
318
318
// ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
319
319
func (s * Server ) ProcessError (err error ) {
320
- // Invalidate server description if not master or node recovering error occurs
321
- if cerr , ok := err .(driver.Error ); ok && (cerr .NetworkError () || cerr .NodeIsRecovering () || cerr .NotMaster ()) {
320
+ // Invalidate server description if not master or node recovering error occurs.
321
+ // These errors can be reported as a command error or a write concern error.
322
+ if cerr , ok := err .(driver.Error ); ok && (cerr .NodeIsRecovering () || cerr .NotMaster ()) {
322
323
desc := s .Description ()
323
324
desc .Kind = description .Unknown
324
325
desc .LastError = err
@@ -345,15 +346,16 @@ func (s *Server) ProcessError(err error) {
345
346
return
346
347
}
347
348
348
- ne , ok := err .( ConnectionError )
349
- if ! ok {
349
+ wrappedConnErr := unwrapConnectionError ( err )
350
+ if wrappedConnErr == nil {
350
351
return
351
352
}
352
353
353
- if netErr , ok := ne .Wrapped .(net.Error ); ok && netErr .Timeout () {
354
+ // Ignore transient timeout errors.
355
+ if netErr , ok := wrappedConnErr .(net.Error ); ok && netErr .Timeout () {
354
356
return
355
357
}
356
- if ne . Wrapped == context .Canceled || ne . Wrapped == context .DeadlineExceeded {
358
+ if wrappedConnErr == context .Canceled || wrappedConnErr == context .DeadlineExceeded {
357
359
return
358
360
}
359
361
@@ -362,6 +364,7 @@ func (s *Server) ProcessError(err error) {
362
364
desc .LastError = err
363
365
// updates description to unknown
364
366
s .updateDescription (desc , false )
367
+ s .pool .clear ()
365
368
}
366
369
367
370
// update handles performing heartbeats and updating any subscribers of the
@@ -551,7 +554,7 @@ func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
551
554
if err != nil {
552
555
saved = err
553
556
conn = nil
554
- if _ , ok := err .( ConnectionError ); ok {
557
+ if wrappedConnErr := unwrapConnectionError ( err ); wrappedConnErr != nil {
555
558
s .pool .drain ()
556
559
// If the server is not connected, give up and exit loop
557
560
if s .Description ().Kind == description .Unknown {
@@ -637,3 +640,18 @@ func (ss *ServerSubscription) Unsubscribe() error {
637
640
638
641
return nil
639
642
}
643
+
644
+ // unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error.
645
+ func unwrapConnectionError (err error ) error {
646
+ connErr , ok := err .(ConnectionError )
647
+ if ok {
648
+ return connErr .Wrapped
649
+ }
650
+
651
+ driverErr , ok := err .(driver.Error )
652
+ if ok && driverErr .NetworkError () {
653
+ return driverErr .Wrapped
654
+ }
655
+
656
+ return nil
657
+ }
0 commit comments