diff --git a/src/errors/example_test.go b/src/errors/example_test.go index 92ef36b1010edb..77193a05b95bcf 100644 --- a/src/errors/example_test.go +++ b/src/errors/example_test.go @@ -123,3 +123,36 @@ func ExampleUnwrap() { // error2: [error1] // error1 } + +func ExampleIsAny() { + if _, err := os.Open("non-existing"); err != nil { + if errors.IsAny(err, fs.ErrNotExist, fs.ErrInvalid) { + fmt.Println("file does not exist") + } else { + fmt.Println(err) + } + } + // Output: + // file does not exist +} + +func ExampleMatch() { + _, err := os.Open("non-existing") + + matched := errors.Match(err, fs.ErrNotExist, fs.ErrInvalid) + if matched != nil { + fmt.Println("matched error:", matched) + } else { + fmt.Println("no match") + } + + switch matched { + case fs.ErrNotExist: + fmt.Println("file does not exist") + case fs.ErrInvalid: + fmt.Println("invalid argument") + } + // Output: + // matched error: file does not exist + // file does not exist +} diff --git a/src/errors/wrap.go b/src/errors/wrap.go index 2ebb951f1de93d..87e82600b4873e 100644 --- a/src/errors/wrap.go +++ b/src/errors/wrap.go @@ -206,3 +206,98 @@ func asType[E error](err error, ppe **E) (_ E, _ bool) { } } } + +// IsAny reports whether any error in err's tree matches any of the target errors. +// +// The tree consists of err itself, followed by the errors obtained by repeatedly +// calling its Unwrap() error or Unwrap() []error method. When err wraps multiple +// errors, IsAny examines err followed by a depth-first traversal of its children. +func IsAny(err error, targets ...error) bool { + _, found := match(err, targets) + + return found +} + +// Match returns the first target error from targets that matches any error in err's tree. +// +// The tree consists of err itself, followed by the errors obtained by repeatedly +// calling its Unwrap() error or Unwrap() []error method. When err wraps multiple +// errors, Match examines err followed by a depth-first traversal of its children. +// +// Match returns the first target from targets if an err is equal to that target or if +// it implements a method Is(error) bool such that Is(target) returns true. +// If no target matches the err, Match returns nil. +func Match(err error, targets ...error) error { + matched, _ := match(err, targets) + + return matched +} + +func match(err error, targets []error) (error, bool) { + if err == nil { + for _, target := range targets { + if target == nil { + return nil, true + } + } + return nil, false + } + + if len(targets) == 0 { + return nil, false + } else if len(targets) == 1 { + if Is(err, targets[0]) { + return targets[0], true + } + + return nil, false + } + + targetMap := make(map[error]struct{}, len(targets)) + for _, target := range targets { + if target != nil && reflectlite.TypeOf(target).Comparable() { + targetMap[target] = struct{}{} + } + } + + return matching(err, targets, targetMap) +} + +func matching(err error, targets []error, targetMap map[error]struct{}) (error, bool) { + isErrComparable := reflectlite.TypeOf(err).Comparable() + for { + if isErrComparable && len(targetMap) > 0 { + if _, ok := targetMap[err]; ok { + return err, true + } + } + + if x, ok := err.(interface{ Is(error) bool }); ok { + for _, target := range targets { + if target != nil && x.Is(target) { + return target, true + } + } + } + + switch x := err.(type) { + case interface{ Unwrap() error }: + err = x.Unwrap() + if err == nil { + return nil, false + } + isErrComparable = reflectlite.TypeOf(err).Comparable() + case interface{ Unwrap() []error }: + for _, err := range x.Unwrap() { + if err != nil { + if matched, found := matching(err, targets, targetMap); matched != nil { + return matched, found + } + } + } + return nil, false + default: + return nil, false + } + } +} diff --git a/src/errors/wrap_test.go b/src/errors/wrap_test.go index 81c795a6bb8b18..22da8f0b3ca34e 100644 --- a/src/errors/wrap_test.go +++ b/src/errors/wrap_test.go @@ -436,3 +436,217 @@ func (errorUncomparable) Is(target error) bool { _, ok := target.(errorUncomparable) return ok } + +func TestIsAny(t *testing.T) { + err1 := errors.New("1") + err2 := errors.New("2") + err3 := errors.New("3") + erra := wrapped{"wrap a", err1} + errb := wrapped{"wrap b", err2} + + poser := &poser{"either 1 or 3", func(err error) bool { + return err == err1 || err == err3 + }} + + testCases := []struct { + err error + targets []error + match bool + }{ + // Basic cases + {nil, []error{nil}, true}, + {nil, []error{err1}, false}, + {err1, []error{nil}, false}, + {err1, []error{err1}, true}, + {err1, []error{err2}, false}, + {err1, []error{err1, err2}, true}, + {err1, []error{err2, err1}, true}, + {err1, []error{err2, err3}, false}, + + // Wrapped errors + {erra, []error{err1}, true}, + {erra, []error{err2}, false}, + {erra, []error{err1, err2}, true}, + {erra, []error{err2, err1}, true}, + {erra, []error{err2, err3}, false}, + + // Multiple targets with wrapped errors + {errb, []error{err1, err2, err3}, true}, + {errb, []error{err1, err3}, false}, + + // Posers + {poser, []error{err1}, true}, + {poser, []error{err3}, true}, + {poser, []error{err2}, false}, + {poser, []error{err1, err2}, true}, + {poser, []error{err2, err3}, true}, + {poser, []error{err2, erra}, false}, + + // Multi errors + {multiErr{}, []error{err1}, false}, + {multiErr{err1, err2}, []error{err1}, true}, + {multiErr{err1, err2}, []error{err2}, true}, + {multiErr{err1, err2}, []error{err3}, false}, + {multiErr{err1, err2}, []error{err3, err1}, true}, + {multiErr{err1, err2}, []error{err3, erra}, false}, + {multiErr{erra, errb}, []error{err1, err2}, true}, + {multiErr{erra, errb}, []error{err3, err1}, true}, + + // Empty targets + {err1, []error{}, false}, + {nil, []error{}, false}, + + // Uncomparable errors + {errorUncomparable{}, []error{errorUncomparable{}}, true}, + {&errorUncomparable{}, []error{errorUncomparable{}}, true}, + {errorUncomparable{}, []error{err1, errorUncomparable{}}, true}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) { + if got := errors.IsAny(tc.err, tc.targets...); got != tc.match { + t.Errorf("IsAny(%v, %v) = %v, want %v", tc.err, tc.targets, got, tc.match) + } + }) + } +} + +func TestMatch(t *testing.T) { + err1 := errors.New("1") + err2 := errors.New("2") + err3 := errors.New("3") + erra := wrapped{"wrap a", err1} + + poser := &poser{"either 1 or 3", func(err error) bool { + return err == err1 || err == err3 + }} + + testCases := []struct { + err error + targets []error + want error // the expected matched error + }{ + {err1, []error{err1}, err1}, + {err1, []error{err2}, nil}, + {err1, []error{err1, err2}, err1}, + {err1, []error{err2, err1}, err1}, // Returns first match (err1) + {err1, []error{err2, err3}, nil}, + {erra, []error{err1, err2}, err1}, + {erra, []error{err2, err1}, err1}, // erra wraps err1, so matches err1 + {erra, []error{err2, err3}, nil}, + {nil, []error{nil}, nil}, + {nil, []error{err1}, nil}, + {err1, []error{}, nil}, + + // Posers - note that the poser matches err1 or err3 + {poser, []error{err1}, err1}, + {poser, []error{err3}, err3}, + {poser, []error{err2}, nil}, + {poser, []error{err2, err1}, err1}, + {poser, []error{err1, err3}, err1}, // Returns first match + + // Multi errors + {multiErr{err1, err2}, []error{err1}, err1}, + {multiErr{err1, err2}, []error{err2}, err2}, + {multiErr{err1, err2}, []error{err3}, nil}, + {multiErr{err1, err2}, []error{err3, err2}, err2}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) { + got := errors.Match(tc.err, tc.targets...) + if got != tc.want { + t.Errorf("Match(%v, %v) = %v, want %v", tc.err, tc.targets, got, tc.want) + } + }) + } +} + +// isAnySlow is a naive implementation of IsAny for benchmarking purposes. +func isAnySlow(err error, targets ...error) bool { + for _, target := range targets { + if errors.Is(err, target) { + return true + } + } + + return false +} + +func BenchmarkIsAny(b *testing.B) { + err1 := errors.New("1") + err2 := errors.New("2") + err3 := errors.New("3") + err := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}} + + testCases := []struct { + name string + fn func(error, ...error) bool + }{ + { + name: "IsAny", + fn: errors.IsAny, + }, + { + name: "isAnySlow", + fn: isAnySlow, + }, + } + + for _, tc := range testCases { + b.Run(tc.name+"_one_target", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if !tc.fn(err, err1) { + b.Fatal(tc.name, "failed") + } + } + }) + + b.Run(tc.name+"three_targets", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if !tc.fn(err, err2, err3, err1) { + b.Fatal(tc.name, "failed") + } + } + }) + + b.Run(tc.name+"no_match", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if tc.fn(err, err2, err3) { + b.Fatal(tc.name, "should not match") + } + } + }) + } +} + +func BenchmarkMatch(b *testing.B) { + err1 := errors.New("1") + err2 := errors.New("2") + err3 := errors.New("3") + err := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}} + + b.Run("one_target", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if errors.Match(err, err1) != err1 { + b.Fatal("Match failed") + } + } + }) + + b.Run("three_targets", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if errors.Match(err, err2, err3, err1) != err1 { + b.Fatal("Match failed") + } + } + }) + + b.Run("no_match", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if errors.Match(err, err2, err3) != nil { + b.Fatal("Match should not match") + } + } + }) +}