From 01427461fe25efcdb4bae6931717dcdf0a600efb Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 4 Mar 2025 11:53:28 -0500 Subject: [PATCH 1/4] use errors functions in mongo package. --- mongo/change_stream.go | 6 +++--- mongo/collection.go | 7 ++++--- mongo/database.go | 4 ++-- mongo/errors.go | 16 ++++++++-------- mongo/index_view.go | 3 ++- mongo/search_index_view.go | 4 +++- mongo/session.go | 3 ++- mongo/with_transactions_test.go | 18 ++++++++++-------- 8 files changed, 34 insertions(+), 27 deletions(-) diff --git a/mongo/change_stream.go b/mongo/change_stream.go index bde1ebc800..009e68e4e4 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -339,8 +339,8 @@ AggregateExecuteLoop: break AggregateExecuteLoop } - switch tt := err.(type) { - case driver.Error: + var tt driver.Error + if errors.As(err, &tt) { // If error is not retryable, do not retry. if !tt.RetryableRead() { break AggregateExecuteLoop @@ -370,7 +370,7 @@ AggregateExecuteLoop: // Reset deployment. cs.aggregate.Deployment(cs.createOperationDeployment(server, conn)) - default: + } else { // Do not retry if error is not a driver error. break AggregateExecuteLoop } diff --git a/mongo/collection.go b/mongo/collection.go index a9279f1381..4e7a7ccb95 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1045,7 +1045,8 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption err = op.Execute(a.ctx) if err != nil { - if wce, ok := err.(driver.WriteCommandError); ok && wce.WriteConcernError != nil { + var wce driver.WriteCommandError + if errors.As(err, &wce) && wce.WriteConcernError != nil { return nil, *convertDriverWriteConcernError(wce.WriteConcernError) } return nil, replaceErrors(err) @@ -2041,8 +2042,8 @@ func (coll *Collection) drop(ctx context.Context) error { err = op.Execute(ctx) // ignore namespace not found errors - driverErr, ok := err.(driver.Error) - if !ok || (ok && !driverErr.NamespaceNotFound()) { + var driverErr driver.Error + if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() { return replaceErrors(err) } return nil diff --git a/mongo/database.go b/mongo/database.go index 414971c870..40a1760dc5 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -343,8 +343,8 @@ func (db *Database) Drop(ctx context.Context) error { err = op.Execute(ctx) - driverErr, ok := err.(driver.Error) - if err != nil && (!ok || !driverErr.NamespaceNotFound()) { + var driverErr driver.Error + if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) { return replaceErrors(err) } return nil diff --git a/mongo/errors.go b/mongo/errors.go index 07d713fd43..6c95af3438 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -57,14 +57,14 @@ func replaceErrors(err error) error { // ignored. For non-DDL write commands (insert, update, etc), acknowledgement // should be be propagated at the result-level: e.g., // SingleResult.Acknowledged. - if err == driver.ErrUnacknowledgedWrite { + if errors.Is(err, driver.ErrUnacknowledgedWrite) { return nil } if errors.Is(err, topology.ErrTopologyClosed) { return ErrClientDisconnected } - if de, ok := err.(driver.Error); ok { + if de := new(driver.Error); errors.As(err, de) { return CommandError{ Code: de.Code, Message: de.Message, @@ -74,7 +74,7 @@ func replaceErrors(err error) error { Raw: bson.Raw(de.Raw), } } - if qe, ok := err.(driver.QueryFailureError); ok { + if qe := new(driver.QueryFailureError); errors.As(err, qe) { // qe.Message is "command failure" ce := CommandError{ Name: qe.Message, @@ -93,7 +93,7 @@ func replaceErrors(err error) error { return ce } - if me, ok := err.(mongocrypt.Error); ok { + if me := new(mongocrypt.Error); errors.As(err, me) { return MongocryptError{Code: me.Code, Message: me.Message} } @@ -101,7 +101,7 @@ func replaceErrors(err error) error { return ErrNilValue } - if marshalErr, ok := err.(codecutil.MarshalError); ok { + if marshalErr := new(codecutil.MarshalError); errors.As(err, marshalErr) { return MarshalError{ Value: marshalErr.Value, Err: marshalErr.Err, @@ -721,12 +721,12 @@ func processWriteError(err error) (returnResult, error) { // ignored. For non-DDL write commands (insert, update, etc), acknowledgement // should be be propagated at the result-level: e.g., // SingleResult.Acknowledged. - if err == driver.ErrUnacknowledgedWrite { + if errors.Is(err, driver.ErrUnacknowledgedWrite) { return rrAllUnacknowledged, nil } - wce, ok := err.(driver.WriteCommandError) - if !ok { + var wce driver.WriteCommandError + if !errors.As(err, &wce) { return rrNone, replaceErrors(err) } diff --git a/mongo/index_view.go b/mongo/index_view.go index 748957da1b..e04f0c0e86 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -52,7 +52,8 @@ type IndexModel struct { } func isNamespaceNotFoundError(err error) bool { - if de, ok := err.(driver.Error); ok { + var de driver.Error + if errors.As(err, &de) { return de.Code == 26 } return false diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index f62e4769ca..d0e1007051 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "errors" "fmt" "strconv" @@ -229,7 +230,8 @@ func (siv SearchIndexView) DropOne( Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) - if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { + var de driver.Error + if errors.As(err, &de) && de.NamespaceNotFound() { return nil } return err diff --git a/mongo/session.go b/mongo/session.go index fb4da5bb05..418b06d3d0 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -193,7 +193,8 @@ func (s *Session) WithTransaction( default: } - if cerr, ok := err.(CommandError); ok { + var cerr CommandError + if errors.As(err, &cerr) { if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() { continue } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index d8901fe57e..0a74f9deeb 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -55,7 +55,8 @@ func TestConvenientTransactions(t *testing.T) { {"killAllSessions", bson.A{}}, }).Err() if err != nil { - if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { + var ce CommandError + if !errors.As(err, &ce) || ce.Code != errorInterrupted { t.Fatalf("killAllSessions error: %v", err) } } @@ -115,8 +116,8 @@ func TestConvenientTransactions(t *testing.T) { return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}} }) assert.NotNil(t, err, "expected WithTransaction error, got nil") - cmdErr, ok := err.(CommandError) - assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) + var cmdErr CommandError + assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err) assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr) }) @@ -148,8 +149,8 @@ func TestConvenientTransactions(t *testing.T) { return nil, err }) assert.NotNil(t, err, "expected WithTransaction error, got nil") - cmdErr, ok := err.(CommandError) - assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) + var cmdErr CommandError + assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err) assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult), "expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr) }) @@ -181,8 +182,8 @@ func TestConvenientTransactions(t *testing.T) { return nil, err }) assert.NotNil(t, err, "expected WithTransaction error, got nil") - cmdErr, ok := err.(CommandError) - assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err) + var cmdErr CommandError + assert.True(t, errors.As(err, &cmdErr), "expected error type %T, got %T", cmdErr, err) assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError), "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr) }) @@ -392,7 +393,8 @@ func TestConvenientTransactions(t *testing.T) { {"killAllSessions", bson.A{}}, }).Err() if err != nil { - if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { + var ce CommandError + if !errors.As(err, &ce) || ce.Code != errorInterrupted { t.Fatalf("killAllSessions error: %v", err) } } From beb9ad213b68c350fe08e96fd2a32a999489a1f7 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 5 Mar 2025 15:51:01 -0500 Subject: [PATCH 2/4] update tests --- internal/integration/client_side_encryption_prose_test.go | 2 +- mongo/errors.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/integration/client_side_encryption_prose_test.go b/internal/integration/client_side_encryption_prose_test.go index 8ab7e5f281..9e27bfbcbb 100644 --- a/internal/integration/client_side_encryption_prose_test.go +++ b/internal/integration/client_side_encryption_prose_test.go @@ -369,7 +369,7 @@ func TestClientSideEncryptionProse(t *testing.T) { assert.True(mt, strings.Contains(insertErr.Error(), "auth error"), "expected InsertOne auth error, got %v", insertErr) assert.True(mt, strings.Contains(encErr.Error(), "auth error"), - "expected Encrypt auth error, got %v", insertErr) + "expected Encrypt auth error, got %v", encErr) return } assert.Nil(mt, insertErr, "InsertOne error: %v", insertErr) diff --git a/mongo/errors.go b/mongo/errors.go index 6c95af3438..4d8cf99f11 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -64,7 +64,7 @@ func replaceErrors(err error) error { if errors.Is(err, topology.ErrTopologyClosed) { return ErrClientDisconnected } - if de := new(driver.Error); errors.As(err, de) { + if de, ok := err.(driver.Error); ok { return CommandError{ Code: de.Code, Message: de.Message, @@ -74,7 +74,7 @@ func replaceErrors(err error) error { Raw: bson.Raw(de.Raw), } } - if qe := new(driver.QueryFailureError); errors.As(err, qe) { + if qe, ok := err.(driver.QueryFailureError); ok { // qe.Message is "command failure" ce := CommandError{ Name: qe.Message, @@ -93,7 +93,7 @@ func replaceErrors(err error) error { return ce } - if me := new(mongocrypt.Error); errors.As(err, me) { + if me, ok := err.(mongocrypt.Error); ok { return MongocryptError{Code: me.Code, Message: me.Message} } @@ -101,7 +101,7 @@ func replaceErrors(err error) error { return ErrNilValue } - if marshalErr := new(codecutil.MarshalError); errors.As(err, marshalErr) { + if marshalErr, ok := err.(codecutil.MarshalError); ok { return MarshalError{ Value: marshalErr.Value, Err: marshalErr.Err, From b8fe8175c1cca6e65f3d96ae28df7a946b69c0f5 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Mon, 10 Mar 2025 16:06:34 -0400 Subject: [PATCH 3/4] revert change in internal/integration/client_side_encryption_prose_test.go --- internal/integration/client_side_encryption_prose_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/integration/client_side_encryption_prose_test.go b/internal/integration/client_side_encryption_prose_test.go index 9e27bfbcbb..8ab7e5f281 100644 --- a/internal/integration/client_side_encryption_prose_test.go +++ b/internal/integration/client_side_encryption_prose_test.go @@ -369,7 +369,7 @@ func TestClientSideEncryptionProse(t *testing.T) { assert.True(mt, strings.Contains(insertErr.Error(), "auth error"), "expected InsertOne auth error, got %v", insertErr) assert.True(mt, strings.Contains(encErr.Error(), "auth error"), - "expected Encrypt auth error, got %v", encErr) + "expected Encrypt auth error, got %v", insertErr) return } assert.Nil(mt, insertErr, "InsertOne error: %v", insertErr) From e7582d473f00a14baebc80f2891408ffc6096dc8 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 18 Mar 2025 15:42:48 -0400 Subject: [PATCH 4/4] remove `isNamespaceNotFoundError()` --- mongo/index_view.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/mongo/index_view.go b/mongo/index_view.go index e04f0c0e86..07b7b16b08 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -51,14 +51,6 @@ type IndexModel struct { Options *options.IndexOptionsBuilder } -func isNamespaceNotFoundError(err error) bool { - var de driver.Error - if errors.As(err, &de) { - return de.Code == 26 - } - return false -} - // List executes a listIndexes command and returns a cursor over the indexes in the collection. // // The opts parameter can be used to specify options for this operation (see the options.ListIndexesOptions @@ -121,7 +113,8 @@ func (iv IndexView) List(ctx context.Context, opts ...options.Lister[options.Lis if err != nil { // for namespaceNotFound errors, return an empty cursor and do not throw an error closeImplicitSession(sess) - if isNamespaceNotFoundError(err) { + var de driver.Error + if errors.As(err, &de) && de.NamespaceNotFound() { return newEmptyCursor(), nil }