Skip to content

Commit 9443c03

Browse files
authored
GODRIVER-2721 Fully support "errors.Is" and "errors.As". (#1969)
1 parent 6a504a3 commit 9443c03

File tree

8 files changed

+30
-30
lines changed

8 files changed

+30
-30
lines changed

mongo/change_stream.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ AggregateExecuteLoop:
339339
break AggregateExecuteLoop
340340
}
341341

342-
switch tt := err.(type) {
343-
case driver.Error:
342+
var tt driver.Error
343+
if errors.As(err, &tt) {
344344
// If error is not retryable, do not retry.
345345
if !tt.RetryableRead() {
346346
break AggregateExecuteLoop
@@ -370,7 +370,7 @@ AggregateExecuteLoop:
370370

371371
// Reset deployment.
372372
cs.aggregate.Deployment(cs.createOperationDeployment(server, conn))
373-
default:
373+
} else {
374374
// Do not retry if error is not a driver error.
375375
break AggregateExecuteLoop
376376
}

mongo/collection.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,8 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption
10451045

10461046
err = op.Execute(a.ctx)
10471047
if err != nil {
1048-
if wce, ok := err.(driver.WriteCommandError); ok && wce.WriteConcernError != nil {
1048+
var wce driver.WriteCommandError
1049+
if errors.As(err, &wce) && wce.WriteConcernError != nil {
10491050
return nil, *convertDriverWriteConcernError(wce.WriteConcernError)
10501051
}
10511052
return nil, replaceErrors(err)
@@ -2041,8 +2042,8 @@ func (coll *Collection) drop(ctx context.Context) error {
20412042
err = op.Execute(ctx)
20422043

20432044
// ignore namespace not found errors
2044-
driverErr, ok := err.(driver.Error)
2045-
if !ok || (ok && !driverErr.NamespaceNotFound()) {
2045+
var driverErr driver.Error
2046+
if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() {
20462047
return replaceErrors(err)
20472048
}
20482049
return nil

mongo/database.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,8 @@ func (db *Database) Drop(ctx context.Context) error {
346346

347347
err = op.Execute(ctx)
348348

349-
driverErr, ok := err.(driver.Error)
350-
if err != nil && (!ok || !driverErr.NamespaceNotFound()) {
349+
var driverErr driver.Error
350+
if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) {
351351
return replaceErrors(err)
352352
}
353353
return nil

mongo/errors.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func replaceErrors(err error) error {
5757
// ignored. For non-DDL write commands (insert, update, etc), acknowledgement
5858
// should be be propagated at the result-level: e.g.,
5959
// SingleResult.Acknowledged.
60-
if err == driver.ErrUnacknowledgedWrite {
60+
if errors.Is(err, driver.ErrUnacknowledgedWrite) {
6161
return nil
6262
}
6363

@@ -721,12 +721,12 @@ func processWriteError(err error) (returnResult, error) {
721721
// ignored. For non-DDL write commands (insert, update, etc), acknowledgement
722722
// should be be propagated at the result-level: e.g.,
723723
// SingleResult.Acknowledged.
724-
if err == driver.ErrUnacknowledgedWrite {
724+
if errors.Is(err, driver.ErrUnacknowledgedWrite) {
725725
return rrAllUnacknowledged, nil
726726
}
727727

728-
wce, ok := err.(driver.WriteCommandError)
729-
if !ok {
728+
var wce driver.WriteCommandError
729+
if !errors.As(err, &wce) {
730730
return rrNone, replaceErrors(err)
731731
}
732732

mongo/index_view.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,6 @@ type IndexModel struct {
5151
Options *options.IndexOptionsBuilder
5252
}
5353

54-
func isNamespaceNotFoundError(err error) bool {
55-
if de, ok := err.(driver.Error); ok {
56-
return de.Code == 26
57-
}
58-
return false
59-
}
60-
6154
// List executes a listIndexes command and returns a cursor over the indexes in the collection.
6255
//
6356
// The opts parameter can be used to specify options for this operation (see the options.ListIndexesOptions
@@ -120,7 +113,8 @@ func (iv IndexView) List(ctx context.Context, opts ...options.Lister[options.Lis
120113
if err != nil {
121114
// for namespaceNotFound errors, return an empty cursor and do not throw an error
122115
closeImplicitSession(sess)
123-
if isNamespaceNotFoundError(err) {
116+
var de driver.Error
117+
if errors.As(err, &de) && de.NamespaceNotFound() {
124118
return newEmptyCursor(), nil
125119
}
126120

mongo/search_index_view.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package mongo
88

99
import (
1010
"context"
11+
"errors"
1112
"fmt"
1213
"strconv"
1314

@@ -229,7 +230,8 @@ func (siv SearchIndexView) DropOne(
229230
Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator)
230231

231232
err = op.Execute(ctx)
232-
if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() {
233+
var de driver.Error
234+
if errors.As(err, &de) && de.NamespaceNotFound() {
233235
return nil
234236
}
235237
return err

mongo/session.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ func (s *Session) WithTransaction(
193193
default:
194194
}
195195

196-
if cerr, ok := err.(CommandError); ok {
196+
var cerr CommandError
197+
if errors.As(err, &cerr) {
197198
if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() {
198199
continue
199200
}

mongo/with_transactions_test.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ func TestConvenientTransactions(t *testing.T) {
5555
{"killAllSessions", bson.A{}},
5656
}).Err()
5757
if err != nil {
58-
if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
58+
var ce CommandError
59+
if !errors.As(err, &ce) || ce.Code != errorInterrupted {
5960
t.Fatalf("killAllSessions error: %v", err)
6061
}
6162
}
@@ -115,8 +116,8 @@ func TestConvenientTransactions(t *testing.T) {
115116
return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}}
116117
})
117118
assert.NotNil(t, err, "expected WithTransaction error, got nil")
118-
cmdErr, ok := err.(CommandError)
119-
assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
119+
var cmdErr CommandError
120+
assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err)
120121
assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
121122
"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
122123
})
@@ -148,8 +149,8 @@ func TestConvenientTransactions(t *testing.T) {
148149
return nil, err
149150
})
150151
assert.NotNil(t, err, "expected WithTransaction error, got nil")
151-
cmdErr, ok := err.(CommandError)
152-
assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
152+
var cmdErr CommandError
153+
assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err)
153154
assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult),
154155
"expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr)
155156
})
@@ -181,8 +182,8 @@ func TestConvenientTransactions(t *testing.T) {
181182
return nil, err
182183
})
183184
assert.NotNil(t, err, "expected WithTransaction error, got nil")
184-
cmdErr, ok := err.(CommandError)
185-
assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
185+
var cmdErr CommandError
186+
assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err)
186187
assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
187188
"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
188189
})
@@ -392,7 +393,8 @@ func TestConvenientTransactions(t *testing.T) {
392393
{"killAllSessions", bson.A{}},
393394
}).Err()
394395
if err != nil {
395-
if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
396+
var ce CommandError
397+
if !errors.As(err, &ce) || ce.Code != errorInterrupted {
396398
t.Fatalf("killAllSessions error: %v", err)
397399
}
398400
}

0 commit comments

Comments
 (0)