Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions errors/enriched.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package errorutil

import (
"bytes"
"errors"
"fmt"
"runtime/debug"
"strings"
Expand All @@ -17,6 +18,7 @@ type ErrCallback func(level ErrorLevel, err string, tags ...string)
// with tags, stacktrace and other methods
type enrichedError struct {
errString string
wrappedErr error
StackTrace string
Tags []string
Level ErrorLevel
Expand All @@ -41,6 +43,11 @@ func (e *enrichedError) WithLevel(level ErrorLevel) Error {
return e
}

// Unwrap returns the underlying error
func (e *enrichedError) Unwrap() error {
return e.wrappedErr
}

// returns formated *enrichedError string
func (e *enrichedError) Error() string {
defer func() {
Expand All @@ -65,13 +72,33 @@ func (e *enrichedError) Wrap(err ...error) Error {
if v == nil {
continue
}
if ee, ok := v.(*enrichedError); ok {
_ = e.Msgf("%s", ee.errString).WithLevel(ee.Level).WithTag(ee.Tags...)
e.StackTrace += ee.StackTrace

if e.wrappedErr == nil {
e.wrappedErr = v
} else {
_ = e.Msgf("%s", v.Error())
// wraps the existing wrapped error (maintains the error chain)
e.wrappedErr = &enrichedError{
errString: v.Error(),
wrappedErr: e.wrappedErr,
Level: e.Level,
}
}

// preserve its props if it's an enriched one
if ee, ok := v.(*enrichedError); ok {
if len(ee.Tags) > 0 {
if e.Tags == nil {
e.Tags = make([]string, 0)
}
e.Tags = append(e.Tags, ee.Tags...)
}

if ee.StackTrace != "" {
e.StackTrace += ee.StackTrace
}
}
}

return e
}

Expand All @@ -95,12 +122,18 @@ func (e *enrichedError) Equal(err ...error) bool {
return true
}
} else {
// not an enriched error but a simple eror
// not an enriched error but a simple error
if e.errString == v.Error() {
return true
}
}

// also check if the err is in the wrapped chain
if errors.Is(e, v) {
return true
}
}

return false
}

Expand Down Expand Up @@ -130,11 +163,23 @@ func NewWithErr(err error) Error {
if err == nil {
return nil
}

if ee, ok := err.(*enrichedError); ok {
x := New("%s", ee.errString).WithTag(ee.Tags...).WithLevel(ee.Level)
x.(*enrichedError).StackTrace = ee.StackTrace
return &enrichedError{
errString: ee.errString,
wrappedErr: err,
StackTrace: ee.StackTrace,
Tags: append([]string{}, ee.Tags...),
Level: ee.Level,
OnError: ee.OnError,
}
}

return &enrichedError{
errString: err.Error(),
wrappedErr: err,
Level: Runtime,
}
return New("%s", err.Error())
}

// NewWithTag creates an error with tag
Expand Down
93 changes: 80 additions & 13 deletions errors/err_test.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,46 @@
package errorutil_test

import (
"fmt"
"errors"
"strings"
"testing"

errors "github.com/projectdiscovery/utils/errors"
errorutil "github.com/projectdiscovery/utils/errors"
)

type customError struct {
msg string
}

func (c *customError) Error() string {
return c.msg
}

func TestErrorEqual(t *testing.T) {
err1 := fmt.Errorf("error init x")
err2 := errors.NewWithErr(err1)
err3 := errors.NewWithTag("testing", "error init")
err1 := errors.New("error init x")
err2 := errorutil.NewWithErr(err1)
err3 := errorutil.NewWithTag("testing", "error init")
var errnil error

if !errors.IsAny(err1, err2, errnil) {
if !errorutil.IsAny(err1, err2, errnil) {
t.Errorf("expected errors to be equal")
}
if errors.IsAny(err1, err3, errnil) {
if errorutil.IsAny(err1, err3, errnil) {
t.Errorf("expected error to be not equal")
}
}

func TestWrapWithNil(t *testing.T) {
err1 := errors.NewWithTag("niltest", "non nil error").WithLevel(errors.Fatal)
err1 := errorutil.NewWithTag("niltest", "non nil error").WithLevel(errorutil.Fatal)
var errx error

if errors.WrapwithNil(errx, err1) != nil {
if errorutil.WrapwithNil(errx, err1) != nil {
t.Errorf("when base error is nil ")
}
}

func TestStackTrace(t *testing.T) {
err := errors.New("base error")
err := errorutil.New("base error")
relay := func(err error) error {
return err
}
Expand All @@ -42,7 +50,7 @@ func TestStackTrace(t *testing.T) {
if strings.Contains(errx.Error(), "captureStack") {
t.Errorf("stacktrace should be disabled by default")
}
errors.ShowStackTrace = true
errorutil.ShowStackTrace = true
if !strings.Contains(errx.Error(), "captureStack") {
t.Errorf("missing stacktrace got %v", errx.Error())
}
Expand All @@ -52,8 +60,8 @@ func TestStackTrace(t *testing.T) {
func TestErrorCallback(t *testing.T) {
callbackExecuted := false

err := errors.NewWithTag("callback", "got error").WithCallback(func(level errors.ErrorLevel, err string, tags ...string) {
if level != errors.Runtime {
err := errorutil.NewWithTag("callback", "got error").WithCallback(func(level errorutil.ErrorLevel, err string, tags ...string) {
if level != errorutil.Runtime {
t.Errorf("Default error level should be Runtime")
}
if tags[0] != "callback" {
Expand All @@ -72,3 +80,62 @@ func TestErrorCallback(t *testing.T) {
t.Errorf("error callback failed to execute")
}
}

func TestErrorIs(t *testing.T) {
var ErrTest = errors.New("test error")

err := errorutil.NewWithErr(ErrTest).Msgf("message %s", "test")

if !errors.Is(err, ErrTest) {
t.Errorf("expected error to match ErrTest")
}
}

func TestUnwrap(t *testing.T) {
// Test basic unwrapping
baseErr := errors.New("base error")
wrappedErr := errorutil.NewWithErr(baseErr)

if !errors.Is(wrappedErr, baseErr) {
t.Errorf("expected wrapped error to match base error")
}

// Test unwrapping thru error chain
middleErr := errorutil.NewWithErr(baseErr).WithTag("middle")
topErr := errorutil.NewWithErr(middleErr).WithTag("top")

if !errors.Is(topErr, baseErr) {
t.Errorf("expected topErr to match baseErr through chain")
}

if !errors.Is(topErr, middleErr) {
t.Errorf("expected topErr to match middleErr")
}

// Test direct unwrap method
if unwrapped := errors.Unwrap(wrappedErr); unwrapped != baseErr {
t.Errorf("expected direct unwrap to return baseErr, got %v", unwrapped)
}

// Test unwrapping with Wrap method
err1 := errors.New("first error")
err2 := errors.New("second error")
combined := errorutil.New("combined error").Wrap(err1, err2)

if !errors.Is(combined, err1) {
t.Errorf("expected combined error to match err1")
}

// Test errors.As functionality
customErr := &customError{msg: "custom error"}
wrappedCustom := errorutil.NewWithErr(customErr).WithTag("wrapped")

var targetCustom *customError
if !errors.As(wrappedCustom, &targetCustom) {
t.Errorf("expected errors.As to find custom error type")
}

if targetCustom.msg != "custom error" {
t.Errorf("expected custom error message 'custom error', got %s", targetCustom.msg)
}
}
2 changes: 2 additions & 0 deletions errors/errinterface.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ type Error interface {
WithLevel(level ErrorLevel) Error
// Error is interface method of 'error'
Error() string
// Unwrap returns the underlying error
Unwrap() error
// Wraps existing error with errors (skips if passed error is nil)
Wrap(err ...error) Error
// Msgf wraps error with given message
Expand Down
32 changes: 19 additions & 13 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,33 @@ import (

// IsAny checks if err is not nil and matches any one of errxx errors
// if match successful returns true else false
// Note: no unwrapping is done here
func IsAny(err error, errxx ...error) bool {
if err == nil {
return false
}
if enrichedErr, ok := err.(Error); ok {
for _, v := range errxx {

for _, v := range errxx {
if v == nil {
continue
}

// Use stdlib errors.Is for proper err chain traversal
// NOTE(dwisiswant0): Check both directions since either error could
// wrap the other
if errors.Is(err, v) || errors.Is(v, err) {
return true
}

// also check enriched error equality (backward-compatible)
if enrichedErr, ok := err.(Error); ok {
if enrichedErr.Equal(v) {
return true
}
}
} else {
for _, v := range errxx {
// check if v is an enriched error
if ee, ok := v.(Error); ok && ee.Equal(err) {
return true
}
// check standard error equality
if strings.EqualFold(err.Error(), fmt.Sprint(v)) {
return true
}

// fallback to str cmp for non-enriched errors
if strings.EqualFold(err.Error(), fmt.Sprint(v)) {
return true
}
}
return false
Expand Down
Loading