diff --git a/mongo/change_stream.go b/mongo/change_stream.go index bde1ebc800..009e68e4e4 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -339,8 +339,8 @@ AggregateExecuteLoop: break AggregateExecuteLoop } - switch tt := err.(type) { - case driver.Error: + var tt driver.Error + if errors.As(err, &tt) { // If error is not retryable, do not retry. if !tt.RetryableRead() { break AggregateExecuteLoop @@ -370,7 +370,7 @@ AggregateExecuteLoop: // Reset deployment. cs.aggregate.Deployment(cs.createOperationDeployment(server, conn)) - default: + } else { // Do not retry if error is not a driver error. break AggregateExecuteLoop } diff --git a/mongo/collection.go b/mongo/collection.go index a9279f1381..4e7a7ccb95 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1045,7 +1045,8 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption err = op.Execute(a.ctx) if err != nil { - if wce, ok := err.(driver.WriteCommandError); ok && wce.WriteConcernError != nil { + var wce driver.WriteCommandError + if errors.As(err, &wce) && wce.WriteConcernError != nil { return nil, *convertDriverWriteConcernError(wce.WriteConcernError) } return nil, replaceErrors(err) @@ -2041,8 +2042,8 @@ func (coll *Collection) drop(ctx context.Context) error { err = op.Execute(ctx) // ignore namespace not found errors - driverErr, ok := err.(driver.Error) - if !ok || (ok && !driverErr.NamespaceNotFound()) { + var driverErr driver.Error + if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() { return replaceErrors(err) } return nil diff --git a/mongo/database.go b/mongo/database.go index 414971c870..40a1760dc5 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -343,8 +343,8 @@ func (db *Database) Drop(ctx context.Context) error { err = op.Execute(ctx) - driverErr, ok := err.(driver.Error) - if err != nil && (!ok || !driverErr.NamespaceNotFound()) { + var driverErr driver.Error + if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) { return replaceErrors(err) } return nil diff --git a/mongo/errors.go b/mongo/errors.go index 07d713fd43..4d8cf99f11 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -57,7 +57,7 @@ func replaceErrors(err error) error { // ignored. For non-DDL write commands (insert, update, etc), acknowledgement // should be be propagated at the result-level: e.g., // SingleResult.Acknowledged. - if err == driver.ErrUnacknowledgedWrite { + if errors.Is(err, driver.ErrUnacknowledgedWrite) { return nil } @@ -721,12 +721,12 @@ func processWriteError(err error) (returnResult, error) { // ignored. For non-DDL write commands (insert, update, etc), acknowledgement // should be be propagated at the result-level: e.g., // SingleResult.Acknowledged. - if err == driver.ErrUnacknowledgedWrite { + if errors.Is(err, driver.ErrUnacknowledgedWrite) { return rrAllUnacknowledged, nil } - wce, ok := err.(driver.WriteCommandError) - if !ok { + var wce driver.WriteCommandError + if !errors.As(err, &wce) { return rrNone, replaceErrors(err) } diff --git a/mongo/index_view.go b/mongo/index_view.go index 748957da1b..07b7b16b08 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -51,13 +51,6 @@ type IndexModel struct { Options *options.IndexOptionsBuilder } -func isNamespaceNotFoundError(err error) bool { - if de, ok := err.(driver.Error); ok { - return de.Code == 26 - } - return false -} - // List executes a listIndexes command and returns a cursor over the indexes in the collection. // // The opts parameter can be used to specify options for this operation (see the options.ListIndexesOptions @@ -120,7 +113,8 @@ func (iv IndexView) List(ctx context.Context, opts ...options.Lister[options.Lis if err != nil { // for namespaceNotFound errors, return an empty cursor and do not throw an error closeImplicitSession(sess) - if isNamespaceNotFoundError(err) { + var de driver.Error + if errors.As(err, &de) && de.NamespaceNotFound() { return newEmptyCursor(), nil } diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index f62e4769ca..d0e1007051 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "errors" "fmt" "strconv" @@ -229,7 +230,8 @@ func (siv SearchIndexView) DropOne( Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) - if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { + var de driver.Error + if errors.As(err, &de) && de.NamespaceNotFound() { return nil } return err diff --git a/mongo/session.go b/mongo/session.go index fb4da5bb05..418b06d3d0 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -193,7 +193,8 @@ func (s *Session) WithTransaction( default: } - if cerr, ok := err.(CommandError); ok { + var cerr CommandError + if errors.As(err, &cerr) { if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() { continue } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index d8901fe57e..0a74f9deeb 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -55,7 +55,8 @@ func TestConvenientTransactions(t *testing.T) { {"killAllSessions", bson.A{}}, }).Err() if err != nil { - if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { + var ce CommandError + if !errors.As(err, &ce) || ce.Code != errorInterrupted { t.Fatalf("killAllSessions error: %v", err) } } @@ -115,8 +116,8 @@ func TestConvenientTransactions(t *testing.T) { return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}} }) assert.NotNil(t, err, "expected WithTransaction error, got nil") - cmdErr, ok := err.(CommandError) - assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) + var cmdErr CommandError + assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err) assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr) }) @@ -148,8 +149,8 @@ func TestConvenientTransactions(t *testing.T) { return nil, err }) assert.NotNil(t, err, "expected WithTransaction error, got nil") - cmdErr, ok := err.(CommandError) - assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) + var cmdErr CommandError + assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err) assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult), "expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr) }) @@ -181,8 +182,8 @@ func TestConvenientTransactions(t *testing.T) { return nil, err }) assert.NotNil(t, err, "expected WithTransaction error, got nil") - cmdErr, ok := err.(CommandError) - assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) + var cmdErr CommandError + assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err) assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr) }) @@ -392,7 +393,8 @@ func TestConvenientTransactions(t *testing.T) { {"killAllSessions", bson.A{}}, }).Err() if err != nil { - if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { + var ce CommandError + if !errors.As(err, &ce) || ce.Code != errorInterrupted { t.Fatalf("killAllSessions error: %v", err) } }