Skip to content

Commit cd80b0a

Browse files
authored
refactor: Use errors package to compare and assert error types (#3739)
1 parent fd54574 commit cd80b0a

15 files changed

+107
-84
lines changed

example/basicauth/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package main
1717
import (
1818
"bufio"
1919
"context"
20+
"errors"
2021
"fmt"
2122
"os"
2223
"strings"
@@ -43,7 +44,7 @@ func main() {
4344
user, _, err := client.Users.Get(ctx, "")
4445

4546
// Is this a two-factor auth error? If so, prompt for OTP and try again.
46-
if _, ok := err.(*github.TwoFactorAuthError); ok {
47+
if errors.As(err, new(*github.TwoFactorAuthError)) {
4748
fmt.Print("\nGitHub OTP: ")
4849
otp, _ := r.ReadString('\n')
4950
tp.OTP = strings.TrimSpace(otp)

github/actions_workflow_runs_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package github
88
import (
99
"context"
1010
"encoding/json"
11+
"errors"
1112
"fmt"
1213
"net/http"
1314
"net/url"
@@ -482,7 +483,7 @@ func TestActionsService_CancelWorkflowRunByID(t *testing.T) {
482483

483484
ctx := context.Background()
484485
resp, err := client.Actions.CancelWorkflowRunByID(ctx, "o", "r", 3434)
485-
if _, ok := err.(*AcceptedError); !ok {
486+
if !errors.As(err, new(*AcceptedError)) {
486487
t.Errorf("Actions.CancelWorkflowRunByID returned error: %v (want AcceptedError)", err)
487488
}
488489
if resp.StatusCode != http.StatusAccepted {

github/github.go

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,8 @@ func (c *Client) bareDo(ctx context.Context, caller *http.Client, req *http.Requ
896896
}
897897

898898
// If the error type is *url.Error, sanitize its URL before returning.
899-
if e, ok := err.(*url.Error); ok {
899+
var e *url.Error
900+
if errors.As(err, &e) {
900901
if url, err := url.Parse(e.URL); err == nil {
901902
e.URL = sanitizeURL(url).String()
902903
return response, e
@@ -923,8 +924,8 @@ func (c *Client) bareDo(ctx context.Context, caller *http.Client, req *http.Requ
923924
// added to the AcceptedError and returned.
924925
//
925926
// Issue #1022
926-
aerr, ok := err.(*AcceptedError)
927-
if ok {
927+
var aerr *AcceptedError
928+
if errors.As(err, &aerr) {
928929
b, readErr := io.ReadAll(resp.Body)
929930
if readErr != nil {
930931
return response, readErr
@@ -934,8 +935,9 @@ func (c *Client) bareDo(ctx context.Context, caller *http.Client, req *http.Requ
934935
err = aerr
935936
}
936937

937-
rateLimitError, ok := err.(*RateLimitError)
938-
if ok && req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil {
938+
var rateLimitError *RateLimitError
939+
if errors.As(err, &rateLimitError) &&
940+
req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil {
939941
if err := sleepUntilResetWithBuffer(req.Context(), rateLimitError.Rate.Reset.Time); err != nil {
940942
return response, err
941943
}
@@ -944,8 +946,8 @@ func (c *Client) bareDo(ctx context.Context, caller *http.Client, req *http.Requ
944946
}
945947

946948
// Update the secondary rate limit if we hit it.
947-
rerr, ok := err.(*AbuseRateLimitError)
948-
if ok && rerr.RetryAfter != nil {
949+
var rerr *AbuseRateLimitError
950+
if errors.As(err, &rerr) && rerr.RetryAfter != nil {
949951
// if a max duration is specified, make sure that we are waiting at most this duration
950952
if c.MaxSecondaryRateLimitRetryAfterDuration > 0 && rerr.GetRetryAfter() > c.MaxSecondaryRateLimitRetryAfterDuration {
951953
rerr.RetryAfter = &c.MaxSecondaryRateLimitRetryAfterDuration
@@ -992,8 +994,8 @@ var errInvalidLocation = errors.New("invalid or empty Location header in redirec
992994
func (c *Client) bareDoUntilFound(ctx context.Context, req *http.Request, maxRedirects int) (*url.URL, *Response, error) {
993995
response, err := c.bareDoIgnoreRedirects(ctx, req)
994996
if err != nil {
995-
rerr, ok := err.(*RedirectionError)
996-
if ok {
997+
var rerr *RedirectionError
998+
if errors.As(err, &rerr) {
997999
// If we receive a 302, transform potential relative locations into absolute and return it.
9981000
if rerr.StatusCode == http.StatusFound {
9991001
if rerr.Location == nil {
@@ -1181,8 +1183,8 @@ func (r *ErrorResponse) Error() string {
11811183

11821184
// Is returns whether the provided error equals this error.
11831185
func (r *ErrorResponse) Is(target error) bool {
1184-
v, ok := target.(*ErrorResponse)
1185-
if !ok {
1186+
var v *ErrorResponse
1187+
if !errors.As(target, &v) {
11861188
return false
11871189
}
11881190

@@ -1246,8 +1248,8 @@ func (r *RateLimitError) Error() string {
12461248

12471249
// Is returns whether the provided error equals this error.
12481250
func (r *RateLimitError) Is(target error) bool {
1249-
v, ok := target.(*RateLimitError)
1250-
if !ok {
1251+
var v *RateLimitError
1252+
if !errors.As(target, &v) {
12511253
return false
12521254
}
12531255

@@ -1273,8 +1275,8 @@ func (*AcceptedError) Error() string {
12731275

12741276
// Is returns whether the provided error equals this error.
12751277
func (ae *AcceptedError) Is(target error) bool {
1276-
v, ok := target.(*AcceptedError)
1277-
if !ok {
1278+
var v *AcceptedError
1279+
if !errors.As(target, &v) {
12781280
return false
12791281
}
12801282
return bytes.Equal(ae.Raw, v.Raw)
@@ -1300,8 +1302,8 @@ func (r *AbuseRateLimitError) Error() string {
13001302

13011303
// Is returns whether the provided error equals this error.
13021304
func (r *AbuseRateLimitError) Is(target error) bool {
1303-
v, ok := target.(*AbuseRateLimitError)
1304-
if !ok {
1305+
var v *AbuseRateLimitError
1306+
if !errors.As(target, &v) {
13051307
return false
13061308
}
13071309

@@ -1334,8 +1336,8 @@ func (r *RedirectionError) Error() string {
13341336

13351337
// Is returns whether the provided error equals this error.
13361338
func (r *RedirectionError) Is(target error) bool {
1337-
v, ok := target.(*RedirectionError)
1338-
if !ok {
1339+
var v *RedirectionError
1340+
if !errors.As(target, &v) {
13391341
return false
13401342
}
13411343

@@ -1486,7 +1488,8 @@ func parseBoolResponse(err error) (bool, error) {
14861488
return true, nil
14871489
}
14881490

1489-
if err, ok := err.(*ErrorResponse); ok && err.Response.StatusCode == http.StatusNotFound {
1491+
var rerr *ErrorResponse
1492+
if errors.As(err, &rerr) && rerr.Response.StatusCode == http.StatusNotFound {
14901493
// Simply false. In this one case, we do not pass the error through.
14911494
return false, nil
14921495
}

0 commit comments

Comments
 (0)