Skip to content

Commit d7ac269

Browse files
authored
Merge pull request #658 from projectdiscovery/dwisiswant0/refactor/errorutil/impl-proper-err-wrapping-w-stdlib-compat
refactor(errorutil): impl proper err wrapping w/ stdlib compat
2 parents d87f17d + 3a47c10 commit d7ac269

File tree

4 files changed

+154
-34
lines changed

4 files changed

+154
-34
lines changed

errors/enriched.go

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package errorutil
22

33
import (
44
"bytes"
5+
"errors"
56
"fmt"
67
"runtime/debug"
78
"strings"
@@ -17,6 +18,7 @@ type ErrCallback func(level ErrorLevel, err string, tags ...string)
1718
// with tags, stacktrace and other methods
1819
type enrichedError struct {
1920
errString string
21+
wrappedErr error
2022
StackTrace string
2123
Tags []string
2224
Level ErrorLevel
@@ -41,6 +43,11 @@ func (e *enrichedError) WithLevel(level ErrorLevel) Error {
4143
return e
4244
}
4345

46+
// Unwrap returns the underlying error
47+
func (e *enrichedError) Unwrap() error {
48+
return e.wrappedErr
49+
}
50+
4451
// returns formated *enrichedError string
4552
func (e *enrichedError) Error() string {
4653
defer func() {
@@ -65,13 +72,33 @@ func (e *enrichedError) Wrap(err ...error) Error {
6572
if v == nil {
6673
continue
6774
}
68-
if ee, ok := v.(*enrichedError); ok {
69-
_ = e.Msgf("%s", ee.errString).WithLevel(ee.Level).WithTag(ee.Tags...)
70-
e.StackTrace += ee.StackTrace
75+
76+
if e.wrappedErr == nil {
77+
e.wrappedErr = v
7178
} else {
72-
_ = e.Msgf("%s", v.Error())
79+
// wraps the existing wrapped error (maintains the error chain)
80+
e.wrappedErr = &enrichedError{
81+
errString: v.Error(),
82+
wrappedErr: e.wrappedErr,
83+
Level: e.Level,
84+
}
85+
}
86+
87+
// preserve its props if it's an enriched one
88+
if ee, ok := v.(*enrichedError); ok {
89+
if len(ee.Tags) > 0 {
90+
if e.Tags == nil {
91+
e.Tags = make([]string, 0)
92+
}
93+
e.Tags = append(e.Tags, ee.Tags...)
94+
}
95+
96+
if ee.StackTrace != "" {
97+
e.StackTrace += ee.StackTrace
98+
}
7399
}
74100
}
101+
75102
return e
76103
}
77104

@@ -95,12 +122,18 @@ func (e *enrichedError) Equal(err ...error) bool {
95122
return true
96123
}
97124
} else {
98-
// not an enriched error but a simple eror
125+
// not an enriched error but a simple error
99126
if e.errString == v.Error() {
100127
return true
101128
}
102129
}
130+
131+
// also check if the err is in the wrapped chain
132+
if errors.Is(e, v) {
133+
return true
134+
}
103135
}
136+
104137
return false
105138
}
106139

@@ -130,11 +163,23 @@ func NewWithErr(err error) Error {
130163
if err == nil {
131164
return nil
132165
}
166+
133167
if ee, ok := err.(*enrichedError); ok {
134-
x := New("%s", ee.errString).WithTag(ee.Tags...).WithLevel(ee.Level)
135-
x.(*enrichedError).StackTrace = ee.StackTrace
168+
return &enrichedError{
169+
errString: ee.errString,
170+
wrappedErr: err,
171+
StackTrace: ee.StackTrace,
172+
Tags: append([]string{}, ee.Tags...),
173+
Level: ee.Level,
174+
OnError: ee.OnError,
175+
}
176+
}
177+
178+
return &enrichedError{
179+
errString: err.Error(),
180+
wrappedErr: err,
181+
Level: Runtime,
136182
}
137-
return New("%s", err.Error())
138183
}
139184

140185
// NewWithTag creates an error with tag

errors/err_test.go

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,46 @@
11
package errorutil_test
22

33
import (
4-
"fmt"
4+
"errors"
55
"strings"
66
"testing"
77

8-
errors "github.com/projectdiscovery/utils/errors"
8+
errorutil "github.com/projectdiscovery/utils/errors"
99
)
1010

11+
type customError struct {
12+
msg string
13+
}
14+
15+
func (c *customError) Error() string {
16+
return c.msg
17+
}
18+
1119
func TestErrorEqual(t *testing.T) {
12-
err1 := fmt.Errorf("error init x")
13-
err2 := errors.NewWithErr(err1)
14-
err3 := errors.NewWithTag("testing", "error init")
20+
err1 := errors.New("error init x")
21+
err2 := errorutil.NewWithErr(err1)
22+
err3 := errorutil.NewWithTag("testing", "error init")
1523
var errnil error
1624

17-
if !errors.IsAny(err1, err2, errnil) {
25+
if !errorutil.IsAny(err1, err2, errnil) {
1826
t.Errorf("expected errors to be equal")
1927
}
20-
if errors.IsAny(err1, err3, errnil) {
28+
if errorutil.IsAny(err1, err3, errnil) {
2129
t.Errorf("expected error to be not equal")
2230
}
2331
}
2432

2533
func TestWrapWithNil(t *testing.T) {
26-
err1 := errors.NewWithTag("niltest", "non nil error").WithLevel(errors.Fatal)
34+
err1 := errorutil.NewWithTag("niltest", "non nil error").WithLevel(errorutil.Fatal)
2735
var errx error
2836

29-
if errors.WrapwithNil(errx, err1) != nil {
37+
if errorutil.WrapwithNil(errx, err1) != nil {
3038
t.Errorf("when base error is nil ")
3139
}
3240
}
3341

3442
func TestStackTrace(t *testing.T) {
35-
err := errors.New("base error")
43+
err := errorutil.New("base error")
3644
relay := func(err error) error {
3745
return err
3846
}
@@ -42,7 +50,7 @@ func TestStackTrace(t *testing.T) {
4250
if strings.Contains(errx.Error(), "captureStack") {
4351
t.Errorf("stacktrace should be disabled by default")
4452
}
45-
errors.ShowStackTrace = true
53+
errorutil.ShowStackTrace = true
4654
if !strings.Contains(errx.Error(), "captureStack") {
4755
t.Errorf("missing stacktrace got %v", errx.Error())
4856
}
@@ -52,8 +60,8 @@ func TestStackTrace(t *testing.T) {
5260
func TestErrorCallback(t *testing.T) {
5361
callbackExecuted := false
5462

55-
err := errors.NewWithTag("callback", "got error").WithCallback(func(level errors.ErrorLevel, err string, tags ...string) {
56-
if level != errors.Runtime {
63+
err := errorutil.NewWithTag("callback", "got error").WithCallback(func(level errorutil.ErrorLevel, err string, tags ...string) {
64+
if level != errorutil.Runtime {
5765
t.Errorf("Default error level should be Runtime")
5866
}
5967
if tags[0] != "callback" {
@@ -72,3 +80,62 @@ func TestErrorCallback(t *testing.T) {
7280
t.Errorf("error callback failed to execute")
7381
}
7482
}
83+
84+
func TestErrorIs(t *testing.T) {
85+
var ErrTest = errors.New("test error")
86+
87+
err := errorutil.NewWithErr(ErrTest).Msgf("message %s", "test")
88+
89+
if !errors.Is(err, ErrTest) {
90+
t.Errorf("expected error to match ErrTest")
91+
}
92+
}
93+
94+
func TestUnwrap(t *testing.T) {
95+
// Test basic unwrapping
96+
baseErr := errors.New("base error")
97+
wrappedErr := errorutil.NewWithErr(baseErr)
98+
99+
if !errors.Is(wrappedErr, baseErr) {
100+
t.Errorf("expected wrapped error to match base error")
101+
}
102+
103+
// Test unwrapping thru error chain
104+
middleErr := errorutil.NewWithErr(baseErr).WithTag("middle")
105+
topErr := errorutil.NewWithErr(middleErr).WithTag("top")
106+
107+
if !errors.Is(topErr, baseErr) {
108+
t.Errorf("expected topErr to match baseErr through chain")
109+
}
110+
111+
if !errors.Is(topErr, middleErr) {
112+
t.Errorf("expected topErr to match middleErr")
113+
}
114+
115+
// Test direct unwrap method
116+
if unwrapped := errors.Unwrap(wrappedErr); unwrapped != baseErr {
117+
t.Errorf("expected direct unwrap to return baseErr, got %v", unwrapped)
118+
}
119+
120+
// Test unwrapping with Wrap method
121+
err1 := errors.New("first error")
122+
err2 := errors.New("second error")
123+
combined := errorutil.New("combined error").Wrap(err1, err2)
124+
125+
if !errors.Is(combined, err1) {
126+
t.Errorf("expected combined error to match err1")
127+
}
128+
129+
// Test errors.As functionality
130+
customErr := &customError{msg: "custom error"}
131+
wrappedCustom := errorutil.NewWithErr(customErr).WithTag("wrapped")
132+
133+
var targetCustom *customError
134+
if !errors.As(wrappedCustom, &targetCustom) {
135+
t.Errorf("expected errors.As to find custom error type")
136+
}
137+
138+
if targetCustom.msg != "custom error" {
139+
t.Errorf("expected custom error message 'custom error', got %s", targetCustom.msg)
140+
}
141+
}

errors/errinterface.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ type Error interface {
99
WithLevel(level ErrorLevel) Error
1010
// Error is interface method of 'error'
1111
Error() string
12+
// Unwrap returns the underlying error
13+
Unwrap() error
1214
// Wraps existing error with errors (skips if passed error is nil)
1315
Wrap(err ...error) Error
1416
// Msgf wraps error with given message

errors/errors.go

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,33 @@ import (
1111

1212
// IsAny checks if err is not nil and matches any one of errxx errors
1313
// if match successful returns true else false
14-
// Note: no unwrapping is done here
1514
func IsAny(err error, errxx ...error) bool {
1615
if err == nil {
1716
return false
1817
}
19-
if enrichedErr, ok := err.(Error); ok {
20-
for _, v := range errxx {
18+
19+
for _, v := range errxx {
20+
if v == nil {
21+
continue
22+
}
23+
24+
// Use stdlib errors.Is for proper err chain traversal
25+
// NOTE(dwisiswant0): Check both directions since either error could
26+
// wrap the other
27+
if errors.Is(err, v) || errors.Is(v, err) {
28+
return true
29+
}
30+
31+
// also check enriched error equality (backward-compatible)
32+
if enrichedErr, ok := err.(Error); ok {
2133
if enrichedErr.Equal(v) {
2234
return true
2335
}
2436
}
25-
} else {
26-
for _, v := range errxx {
27-
// check if v is an enriched error
28-
if ee, ok := v.(Error); ok && ee.Equal(err) {
29-
return true
30-
}
31-
// check standard error equality
32-
if strings.EqualFold(err.Error(), fmt.Sprint(v)) {
33-
return true
34-
}
37+
38+
// fallback to str cmp for non-enriched errors
39+
if strings.EqualFold(err.Error(), fmt.Sprint(v)) {
40+
return true
3541
}
3642
}
3743
return false

0 commit comments

Comments
 (0)