Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/errors/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,36 @@ func ExampleUnwrap() {
// error2: [error1]
// error1
}

func ExampleIsAny() {
if _, err := os.Open("non-existing"); err != nil {
if errors.IsAny(err, fs.ErrNotExist, fs.ErrInvalid) {
fmt.Println("file does not exist")
} else {
fmt.Println(err)
}
}
// Output:
// file does not exist
}

func ExampleMatch() {
_, err := os.Open("non-existing")

matched := errors.Match(err, fs.ErrNotExist, fs.ErrInvalid)
if matched != nil {
fmt.Println("matched error:", matched)
} else {
fmt.Println("no match")
}

switch matched {
case fs.ErrNotExist:
fmt.Println("file does not exist")
case fs.ErrInvalid:
fmt.Println("invalid argument")
}
// Output:
// matched error: file does not exist
// file does not exist
}
95 changes: 95 additions & 0 deletions src/errors/wrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,98 @@ func asType[E error](err error, ppe **E) (_ E, _ bool) {
}
}
}

// IsAny reports whether any error in err's tree matches any of the target errors.
//
// The tree consists of err itself, followed by the errors obtained by repeatedly
// calling its Unwrap() error or Unwrap() []error method. When err wraps multiple
// errors, IsAny examines err followed by a depth-first traversal of its children.
func IsAny(err error, targets ...error) bool {
_, found := match(err, targets)

return found
}

// Match returns the first target error from targets that matches any error in err's tree.
//
// The tree consists of err itself, followed by the errors obtained by repeatedly
// calling its Unwrap() error or Unwrap() []error method. When err wraps multiple
// errors, Match examines err followed by a depth-first traversal of its children.
//
// Match returns the first target from targets if an err is equal to that target or if
// it implements a method Is(error) bool such that Is(target) returns true.
// If no target matches the err, Match returns nil.
func Match(err error, targets ...error) error {
matched, _ := match(err, targets)

return matched
}

func match(err error, targets []error) (error, bool) {
if err == nil {
for _, target := range targets {
if target == nil {
return nil, true
}
}
return nil, false
}

if len(targets) == 0 {
return nil, false
} else if len(targets) == 1 {
if Is(err, targets[0]) {
return targets[0], true
}

return nil, false
}

targetMap := make(map[error]struct{}, len(targets))
for _, target := range targets {
if target != nil && reflectlite.TypeOf(target).Comparable() {
targetMap[target] = struct{}{}
}
}

return matching(err, targets, targetMap)
}

func matching(err error, targets []error, targetMap map[error]struct{}) (error, bool) {
isErrComparable := reflectlite.TypeOf(err).Comparable()
for {
if isErrComparable && len(targetMap) > 0 {
if _, ok := targetMap[err]; ok {
return err, true
}
}

if x, ok := err.(interface{ Is(error) bool }); ok {
for _, target := range targets {
if target != nil && x.Is(target) {
return target, true
}
}
}

switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return nil, false
}
isErrComparable = reflectlite.TypeOf(err).Comparable()
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
if err != nil {
if matched, found := matching(err, targets, targetMap); matched != nil {
return matched, found
}
}
}
return nil, false
default:
return nil, false
}
}
}
214 changes: 214 additions & 0 deletions src/errors/wrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,217 @@ func (errorUncomparable) Is(target error) bool {
_, ok := target.(errorUncomparable)
return ok
}

func TestIsAny(t *testing.T) {
err1 := errors.New("1")
err2 := errors.New("2")
err3 := errors.New("3")
erra := wrapped{"wrap a", err1}
errb := wrapped{"wrap b", err2}

poser := &poser{"either 1 or 3", func(err error) bool {
return err == err1 || err == err3
}}

testCases := []struct {
err error
targets []error
match bool
}{
// Basic cases
{nil, []error{nil}, true},
{nil, []error{err1}, false},
{err1, []error{nil}, false},
{err1, []error{err1}, true},
{err1, []error{err2}, false},
{err1, []error{err1, err2}, true},
{err1, []error{err2, err1}, true},
{err1, []error{err2, err3}, false},

// Wrapped errors
{erra, []error{err1}, true},
{erra, []error{err2}, false},
{erra, []error{err1, err2}, true},
{erra, []error{err2, err1}, true},
{erra, []error{err2, err3}, false},

// Multiple targets with wrapped errors
{errb, []error{err1, err2, err3}, true},
{errb, []error{err1, err3}, false},

// Posers
{poser, []error{err1}, true},
{poser, []error{err3}, true},
{poser, []error{err2}, false},
{poser, []error{err1, err2}, true},
{poser, []error{err2, err3}, true},
{poser, []error{err2, erra}, false},

// Multi errors
{multiErr{}, []error{err1}, false},
{multiErr{err1, err2}, []error{err1}, true},
{multiErr{err1, err2}, []error{err2}, true},
{multiErr{err1, err2}, []error{err3}, false},
{multiErr{err1, err2}, []error{err3, err1}, true},
{multiErr{err1, err2}, []error{err3, erra}, false},
{multiErr{erra, errb}, []error{err1, err2}, true},
{multiErr{erra, errb}, []error{err3, err1}, true},

// Empty targets
{err1, []error{}, false},
{nil, []error{}, false},

// Uncomparable errors
{errorUncomparable{}, []error{errorUncomparable{}}, true},
{&errorUncomparable{}, []error{errorUncomparable{}}, true},
{errorUncomparable{}, []error{err1, errorUncomparable{}}, true},
}

for i, tc := range testCases {
t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
if got := errors.IsAny(tc.err, tc.targets...); got != tc.match {
t.Errorf("IsAny(%v, %v) = %v, want %v", tc.err, tc.targets, got, tc.match)
}
})
}
}

func TestMatch(t *testing.T) {
err1 := errors.New("1")
err2 := errors.New("2")
err3 := errors.New("3")
erra := wrapped{"wrap a", err1}

poser := &poser{"either 1 or 3", func(err error) bool {
return err == err1 || err == err3
}}

testCases := []struct {
err error
targets []error
want error // the expected matched error
}{
{err1, []error{err1}, err1},
{err1, []error{err2}, nil},
{err1, []error{err1, err2}, err1},
{err1, []error{err2, err1}, err1}, // Returns first match (err1)
{err1, []error{err2, err3}, nil},
{erra, []error{err1, err2}, err1},
{erra, []error{err2, err1}, err1}, // erra wraps err1, so matches err1
{erra, []error{err2, err3}, nil},
{nil, []error{nil}, nil},
{nil, []error{err1}, nil},
{err1, []error{}, nil},

// Posers - note that the poser matches err1 or err3
{poser, []error{err1}, err1},
{poser, []error{err3}, err3},
{poser, []error{err2}, nil},
{poser, []error{err2, err1}, err1},
{poser, []error{err1, err3}, err1}, // Returns first match

// Multi errors
{multiErr{err1, err2}, []error{err1}, err1},
{multiErr{err1, err2}, []error{err2}, err2},
{multiErr{err1, err2}, []error{err3}, nil},
{multiErr{err1, err2}, []error{err3, err2}, err2},
}

for i, tc := range testCases {
t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
got := errors.Match(tc.err, tc.targets...)
if got != tc.want {
t.Errorf("Match(%v, %v) = %v, want %v", tc.err, tc.targets, got, tc.want)
}
})
}
}

// isAnySlow is a naive implementation of IsAny for benchmarking purposes.
func isAnySlow(err error, targets ...error) bool {
for _, target := range targets {
if errors.Is(err, target) {
return true
}
}

return false
}

func BenchmarkIsAny(b *testing.B) {
err1 := errors.New("1")
err2 := errors.New("2")
err3 := errors.New("3")
err := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}}

testCases := []struct {
name string
fn func(error, ...error) bool
}{
{
name: "IsAny",
fn: errors.IsAny,
},
{
name: "isAnySlow",
fn: isAnySlow,
},
}

for _, tc := range testCases {
b.Run(tc.name+"_one_target", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if !tc.fn(err, err1) {
b.Fatal(tc.name, "failed")
}
}
})

b.Run(tc.name+"three_targets", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if !tc.fn(err, err2, err3, err1) {
b.Fatal(tc.name, "failed")
}
}
})

b.Run(tc.name+"no_match", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if tc.fn(err, err2, err3) {
b.Fatal(tc.name, "should not match")
}
}
})
}
}

func BenchmarkMatch(b *testing.B) {
err1 := errors.New("1")
err2 := errors.New("2")
err3 := errors.New("3")
err := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}}

b.Run("one_target", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if errors.Match(err, err1) != err1 {
b.Fatal("Match failed")
}
}
})

b.Run("three_targets", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if errors.Match(err, err2, err3, err1) != err1 {
b.Fatal("Match failed")
}
}
})

b.Run("no_match", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if errors.Match(err, err2, err3) != nil {
b.Fatal("Match should not match")
}
}
})
}