Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mongo/change_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ AggregateExecuteLoop:
break AggregateExecuteLoop
}

switch tt := err.(type) {
case driver.Error:
var tt driver.Error
if errors.As(err, &tt) {
// If error is not retryable, do not retry.
if !tt.RetryableRead() {
break AggregateExecuteLoop
Expand Down Expand Up @@ -370,7 +370,7 @@ AggregateExecuteLoop:

// Reset deployment.
cs.aggregate.Deployment(cs.createOperationDeployment(server, conn))
default:
} else {
// Do not retry if error is not a driver error.
break AggregateExecuteLoop
}
Expand Down
7 changes: 4 additions & 3 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,8 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption

err = op.Execute(a.ctx)
if err != nil {
if wce, ok := err.(driver.WriteCommandError); ok && wce.WriteConcernError != nil {
var wce driver.WriteCommandError
if errors.As(err, &wce) && wce.WriteConcernError != nil {
return nil, *convertDriverWriteConcernError(wce.WriteConcernError)
}
return nil, replaceErrors(err)
Expand Down Expand Up @@ -2041,8 +2042,8 @@ func (coll *Collection) drop(ctx context.Context) error {
err = op.Execute(ctx)

// ignore namespace not found errors
driverErr, ok := err.(driver.Error)
if !ok || (ok && !driverErr.NamespaceNotFound()) {
var driverErr driver.Error
if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() {
return replaceErrors(err)
}
return nil
Expand Down
4 changes: 2 additions & 2 deletions mongo/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ func (db *Database) Drop(ctx context.Context) error {

err = op.Execute(ctx)

driverErr, ok := err.(driver.Error)
if err != nil && (!ok || !driverErr.NamespaceNotFound()) {
var driverErr driver.Error
if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) {
return replaceErrors(err)
}
return nil
Expand Down
8 changes: 4 additions & 4 deletions mongo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func replaceErrors(err error) error {
// ignored. For non-DDL write commands (insert, update, etc), acknowledgement
// should be be propagated at the result-level: e.g.,
// SingleResult.Acknowledged.
if err == driver.ErrUnacknowledgedWrite {
if errors.Is(err, driver.ErrUnacknowledgedWrite) {
return nil
}

Expand Down Expand Up @@ -721,12 +721,12 @@ func processWriteError(err error) (returnResult, error) {
// ignored. For non-DDL write commands (insert, update, etc), acknowledgement
// should be be propagated at the result-level: e.g.,
// SingleResult.Acknowledged.
if err == driver.ErrUnacknowledgedWrite {
if errors.Is(err, driver.ErrUnacknowledgedWrite) {
return rrAllUnacknowledged, nil
}

wce, ok := err.(driver.WriteCommandError)
if !ok {
var wce driver.WriteCommandError
if !errors.As(err, &wce) {
return rrNone, replaceErrors(err)
}

Expand Down
3 changes: 2 additions & 1 deletion mongo/index_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ type IndexModel struct {
}

func isNamespaceNotFoundError(err error) bool {
if de, ok := err.(driver.Error); ok {
var de driver.Error
if errors.As(err, &de) {
return de.Code == 26
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Can we use the NamespaceNotFound method here?
Optional: If we use NamespaceNotFound, can we remove this function?

}
return false
Expand Down
4 changes: 3 additions & 1 deletion mongo/search_index_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package mongo

import (
"context"
"errors"
"fmt"
"strconv"

Expand Down Expand Up @@ -229,7 +230,8 @@ func (siv SearchIndexView) DropOne(
Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator)

err = op.Execute(ctx)
if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() {
var de driver.Error
if errors.As(err, &de) && de.NamespaceNotFound() {
return nil
}
return err
Expand Down
3 changes: 2 additions & 1 deletion mongo/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ func (s *Session) WithTransaction(
default:
}

if cerr, ok := err.(CommandError); ok {
var cerr CommandError
if errors.As(err, &cerr) {
if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() {
continue
}
Expand Down
18 changes: 10 additions & 8 deletions mongo/with_transactions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ func TestConvenientTransactions(t *testing.T) {
{"killAllSessions", bson.A{}},
}).Err()
if err != nil {
if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
var ce CommandError
if !errors.As(err, &ce) || ce.Code != errorInterrupted {
t.Fatalf("killAllSessions error: %v", err)
}
}
Expand Down Expand Up @@ -115,8 +116,8 @@ func TestConvenientTransactions(t *testing.T) {
return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}}
})
assert.NotNil(t, err, "expected WithTransaction error, got nil")
cmdErr, ok := err.(CommandError)
assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
var cmdErr CommandError
assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err)
assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
})
Expand Down Expand Up @@ -148,8 +149,8 @@ func TestConvenientTransactions(t *testing.T) {
return nil, err
})
assert.NotNil(t, err, "expected WithTransaction error, got nil")
cmdErr, ok := err.(CommandError)
assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
var cmdErr CommandError
assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err)
assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult),
"expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr)
})
Expand Down Expand Up @@ -181,8 +182,8 @@ func TestConvenientTransactions(t *testing.T) {
return nil, err
})
assert.NotNil(t, err, "expected WithTransaction error, got nil")
cmdErr, ok := err.(CommandError)
assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
var cmdErr CommandError
assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err)
assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
})
Expand Down Expand Up @@ -392,7 +393,8 @@ func TestConvenientTransactions(t *testing.T) {
{"killAllSessions", bson.A{}},
}).Err()
if err != nil {
if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
var ce CommandError
if !errors.As(err, &ce) || ce.Code != errorInterrupted {
t.Fatalf("killAllSessions error: %v", err)
}
}
Expand Down
Loading