diff --git a/mongo/errors.go b/mongo/errors.go index 86bc3310a3..004d73657b 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -252,9 +252,25 @@ type ServerError interface { // HasErrorCodeWithMessage returns true if any of the contained errors have the specified code and message. HasErrorCodeWithMessage(int, string) bool + // ErrorCodes returns all error codes (unsorted) in the server’s response. + // This would include nested errors (e.g., write concern errors) for + // supporting implementations (e.g., BulkWriteException) as well as the + // top-level error code. + ErrorCodes() []int + serverError() } +func hasErrorCode(srvErr ServerError, code int) bool { + for _, srvErrCode := range srvErr.ErrorCodes() { + if code == srvErrCode { + return true + } + } + + return false +} + var _ ServerError = CommandError{} var _ ServerError = WriteError{} var _ ServerError = WriteException{} @@ -290,6 +306,11 @@ func (e CommandError) HasErrorCode(code int) bool { return int(e.Code) == code } +// ErrorCodes returns a list of error codes returned by the server. +func (e CommandError) ErrorCodes() []int { + return []int{int(e.Code)} +} + // HasErrorLabel returns true if the error contains the specified label. func (e CommandError) HasErrorLabel(label string) bool { for _, l := range e.Labels { @@ -345,6 +366,11 @@ func (we WriteError) HasErrorCode(code int) bool { return we.Code == code } +// ErrorCodes returns a list of error codes returned by the server. +func (we WriteError) ErrorCodes() []int { + return []int{we.Code} +} + // HasErrorLabel returns true if the error contains the specified label. WriteErrors do not contain labels, // so we always return false. func (we WriteError) HasErrorLabel(string) bool { @@ -451,15 +477,21 @@ func (mwe WriteException) Error() string { // HasErrorCode returns true if the error has the specified code. func (mwe WriteException) HasErrorCode(code int) bool { - if mwe.WriteConcernError != nil && mwe.WriteConcernError.Code == code { - return true + return hasErrorCode(mwe, code) +} + +// ErrorCodes returns a list of error codes returned by the server. +func (mwe WriteException) ErrorCodes() []int { + errorCodes := []int{} + for _, writeError := range mwe.WriteErrors { + errorCodes = append(errorCodes, writeError.Code) } - for _, we := range mwe.WriteErrors { - if we.Code == code { - return true - } + + if mwe.WriteConcernError != nil { + errorCodes = append(errorCodes, mwe.WriteConcernError.Code) } - return false + + return errorCodes } // HasErrorLabel returns true if the error contains the specified label. @@ -563,15 +595,21 @@ func (bwe BulkWriteException) Error() string { // HasErrorCode returns true if any of the errors have the specified code. func (bwe BulkWriteException) HasErrorCode(code int) bool { - if bwe.WriteConcernError != nil && bwe.WriteConcernError.Code == code { - return true + return hasErrorCode(bwe, code) +} + +// ErrorCodes returns a list of error codes returned by the server. +func (bwe BulkWriteException) ErrorCodes() []int { + errorCodes := []int{} + for _, writeError := range bwe.WriteErrors { + errorCodes = append(errorCodes, writeError.Code) } - for _, we := range bwe.WriteErrors { - if we.Code == code { - return true - } + + if bwe.WriteConcernError != nil { + errorCodes = append(errorCodes, bwe.WriteConcernError.Code) } - return false + + return errorCodes } // HasErrorLabel returns true if the error contains the specified label. diff --git a/mongo/errors_test.go b/mongo/errors_test.go index c39d409ac5..2ff04c4dd2 100644 --- a/mongo/errors_test.go +++ b/mongo/errors_test.go @@ -679,6 +679,70 @@ func TestIsTimeout(t *testing.T) { } } +func TestServerError_ErrorCodes(t *testing.T) { + tests := []struct { + name string + error ServerError + want []int + }{ + { + name: "CommandError", + error: CommandError{Code: 1}, + want: []int{1}, + }, + { + name: "WriteError", + error: WriteError{Code: 1}, + want: []int{1}, + }, + { + name: "WriteException single", + error: WriteException{WriteErrors: []WriteError{{Code: 1}}}, + want: []int{1}, + }, + { + name: "WriteException multiple", + error: WriteException{WriteErrors: []WriteError{{Code: 1}, {Code: 2}}}, + want: []int{1, 2}, + }, + { + name: "WriteException duplicates", + error: WriteException{WriteErrors: []WriteError{{Code: 1}, {Code: 2}, {Code: 2}}}, + want: []int{1, 2, 2}, + }, + { + name: "BulkWriteException single", + error: BulkWriteException{WriteErrors: []BulkWriteError{{WriteError: WriteError{Code: 1}}}}, + want: []int{1}, + }, + { + name: "BulkWriteException multiple", + error: BulkWriteException{WriteErrors: []BulkWriteError{ + {WriteError: WriteError{Code: 1}}, + {WriteError: WriteError{Code: 2}}, + }}, + want: []int{1, 2}, + }, + { + name: "BulkWriteException duplicates", + error: BulkWriteException{WriteErrors: []BulkWriteError{ + {WriteError: WriteError{Code: 1}}, + {WriteError: WriteError{Code: 2}}, + {WriteError: WriteError{Code: 2}}, + }}, + want: []int{1, 2, 2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := test.error.ErrorCodes() + + assert.ElementsMatch(t, got, test.want) + }) + } +} + type netErr struct { timeout bool }