diff --git a/errors/enriched.go b/errors/enriched.go index b7a7b5e7..62145390 100644 --- a/errors/enriched.go +++ b/errors/enriched.go @@ -2,6 +2,7 @@ package errorutil import ( "bytes" + "errors" "fmt" "runtime/debug" "strings" @@ -17,6 +18,7 @@ type ErrCallback func(level ErrorLevel, err string, tags ...string) // with tags, stacktrace and other methods type enrichedError struct { errString string + wrappedErr error StackTrace string Tags []string Level ErrorLevel @@ -41,6 +43,11 @@ func (e *enrichedError) WithLevel(level ErrorLevel) Error { return e } +// Unwrap returns the underlying error +func (e *enrichedError) Unwrap() error { + return e.wrappedErr +} + // returns formated *enrichedError string func (e *enrichedError) Error() string { defer func() { @@ -65,13 +72,33 @@ func (e *enrichedError) Wrap(err ...error) Error { if v == nil { continue } - if ee, ok := v.(*enrichedError); ok { - _ = e.Msgf("%s", ee.errString).WithLevel(ee.Level).WithTag(ee.Tags...) - e.StackTrace += ee.StackTrace + + if e.wrappedErr == nil { + e.wrappedErr = v } else { - _ = e.Msgf("%s", v.Error()) + // wraps the existing wrapped error (maintains the error chain) + e.wrappedErr = &enrichedError{ + errString: v.Error(), + wrappedErr: e.wrappedErr, + Level: e.Level, + } + } + + // preserve its props if it's an enriched one + if ee, ok := v.(*enrichedError); ok { + if len(ee.Tags) > 0 { + if e.Tags == nil { + e.Tags = make([]string, 0) + } + e.Tags = append(e.Tags, ee.Tags...) + } + + if ee.StackTrace != "" { + e.StackTrace += ee.StackTrace + } } } + return e } @@ -95,12 +122,18 @@ func (e *enrichedError) Equal(err ...error) bool { return true } } else { - // not an enriched error but a simple eror + // not an enriched error but a simple error if e.errString == v.Error() { return true } } + + // also check if the err is in the wrapped chain + if errors.Is(e, v) { + return true + } } + return false } @@ -130,11 +163,23 @@ func NewWithErr(err error) Error { if err == nil { return nil } + if ee, ok := err.(*enrichedError); ok { - x := New("%s", ee.errString).WithTag(ee.Tags...).WithLevel(ee.Level) - x.(*enrichedError).StackTrace = ee.StackTrace + return &enrichedError{ + errString: ee.errString, + wrappedErr: err, + StackTrace: ee.StackTrace, + Tags: append([]string{}, ee.Tags...), + Level: ee.Level, + OnError: ee.OnError, + } + } + + return &enrichedError{ + errString: err.Error(), + wrappedErr: err, + Level: Runtime, } - return New("%s", err.Error()) } // NewWithTag creates an error with tag diff --git a/errors/err_test.go b/errors/err_test.go index 31cfb93d..6d6341c1 100644 --- a/errors/err_test.go +++ b/errors/err_test.go @@ -1,38 +1,46 @@ package errorutil_test import ( - "fmt" + "errors" "strings" "testing" - errors "github.com/projectdiscovery/utils/errors" + errorutil "github.com/projectdiscovery/utils/errors" ) +type customError struct { + msg string +} + +func (c *customError) Error() string { + return c.msg +} + func TestErrorEqual(t *testing.T) { - err1 := fmt.Errorf("error init x") - err2 := errors.NewWithErr(err1) - err3 := errors.NewWithTag("testing", "error init") + err1 := errors.New("error init x") + err2 := errorutil.NewWithErr(err1) + err3 := errorutil.NewWithTag("testing", "error init") var errnil error - if !errors.IsAny(err1, err2, errnil) { + if !errorutil.IsAny(err1, err2, errnil) { t.Errorf("expected errors to be equal") } - if errors.IsAny(err1, err3, errnil) { + if errorutil.IsAny(err1, err3, errnil) { t.Errorf("expected error to be not equal") } } func TestWrapWithNil(t *testing.T) { - err1 := errors.NewWithTag("niltest", "non nil error").WithLevel(errors.Fatal) + err1 := errorutil.NewWithTag("niltest", "non nil error").WithLevel(errorutil.Fatal) var errx error - if errors.WrapwithNil(errx, err1) != nil { + if errorutil.WrapwithNil(errx, err1) != nil { t.Errorf("when base error is nil ") } } func TestStackTrace(t *testing.T) { - err := errors.New("base error") + err := errorutil.New("base error") relay := func(err error) error { return err } @@ -42,7 +50,7 @@ func TestStackTrace(t *testing.T) { if strings.Contains(errx.Error(), "captureStack") { t.Errorf("stacktrace should be disabled by default") } - errors.ShowStackTrace = true + errorutil.ShowStackTrace = true if !strings.Contains(errx.Error(), "captureStack") { t.Errorf("missing stacktrace got %v", errx.Error()) } @@ -52,8 +60,8 @@ func TestStackTrace(t *testing.T) { func TestErrorCallback(t *testing.T) { callbackExecuted := false - err := errors.NewWithTag("callback", "got error").WithCallback(func(level errors.ErrorLevel, err string, tags ...string) { - if level != errors.Runtime { + err := errorutil.NewWithTag("callback", "got error").WithCallback(func(level errorutil.ErrorLevel, err string, tags ...string) { + if level != errorutil.Runtime { t.Errorf("Default error level should be Runtime") } if tags[0] != "callback" { @@ -72,3 +80,62 @@ func TestErrorCallback(t *testing.T) { t.Errorf("error callback failed to execute") } } + +func TestErrorIs(t *testing.T) { + var ErrTest = errors.New("test error") + + err := errorutil.NewWithErr(ErrTest).Msgf("message %s", "test") + + if !errors.Is(err, ErrTest) { + t.Errorf("expected error to match ErrTest") + } +} + +func TestUnwrap(t *testing.T) { + // Test basic unwrapping + baseErr := errors.New("base error") + wrappedErr := errorutil.NewWithErr(baseErr) + + if !errors.Is(wrappedErr, baseErr) { + t.Errorf("expected wrapped error to match base error") + } + + // Test unwrapping thru error chain + middleErr := errorutil.NewWithErr(baseErr).WithTag("middle") + topErr := errorutil.NewWithErr(middleErr).WithTag("top") + + if !errors.Is(topErr, baseErr) { + t.Errorf("expected topErr to match baseErr through chain") + } + + if !errors.Is(topErr, middleErr) { + t.Errorf("expected topErr to match middleErr") + } + + // Test direct unwrap method + if unwrapped := errors.Unwrap(wrappedErr); unwrapped != baseErr { + t.Errorf("expected direct unwrap to return baseErr, got %v", unwrapped) + } + + // Test unwrapping with Wrap method + err1 := errors.New("first error") + err2 := errors.New("second error") + combined := errorutil.New("combined error").Wrap(err1, err2) + + if !errors.Is(combined, err1) { + t.Errorf("expected combined error to match err1") + } + + // Test errors.As functionality + customErr := &customError{msg: "custom error"} + wrappedCustom := errorutil.NewWithErr(customErr).WithTag("wrapped") + + var targetCustom *customError + if !errors.As(wrappedCustom, &targetCustom) { + t.Errorf("expected errors.As to find custom error type") + } + + if targetCustom.msg != "custom error" { + t.Errorf("expected custom error message 'custom error', got %s", targetCustom.msg) + } +} diff --git a/errors/errinterface.go b/errors/errinterface.go index f7088f80..baf6fa90 100644 --- a/errors/errinterface.go +++ b/errors/errinterface.go @@ -9,6 +9,8 @@ type Error interface { WithLevel(level ErrorLevel) Error // Error is interface method of 'error' Error() string + // Unwrap returns the underlying error + Unwrap() error // Wraps existing error with errors (skips if passed error is nil) Wrap(err ...error) Error // Msgf wraps error with given message diff --git a/errors/errors.go b/errors/errors.go index 8d07b221..1f888b09 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -11,27 +11,33 @@ import ( // IsAny checks if err is not nil and matches any one of errxx errors // if match successful returns true else false -// Note: no unwrapping is done here func IsAny(err error, errxx ...error) bool { if err == nil { return false } - if enrichedErr, ok := err.(Error); ok { - for _, v := range errxx { + + for _, v := range errxx { + if v == nil { + continue + } + + // Use stdlib errors.Is for proper err chain traversal + // NOTE(dwisiswant0): Check both directions since either error could + // wrap the other + if errors.Is(err, v) || errors.Is(v, err) { + return true + } + + // also check enriched error equality (backward-compatible) + if enrichedErr, ok := err.(Error); ok { if enrichedErr.Equal(v) { return true } } - } else { - for _, v := range errxx { - // check if v is an enriched error - if ee, ok := v.(Error); ok && ee.Equal(err) { - return true - } - // check standard error equality - if strings.EqualFold(err.Error(), fmt.Sprint(v)) { - return true - } + + // fallback to str cmp for non-enriched errors + if strings.EqualFold(err.Error(), fmt.Sprint(v)) { + return true } } return false