From 9ea83e7d89885e65eae341adf6b717d5421ffbe8 Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar Date: Thu, 28 Nov 2024 01:11:45 +0530 Subject: [PATCH 1/2] more ptr utils --- env/env.go | 2 +- errkit/errors.go | 43 ++++++++++++++++++++++++++++++------------- ptr/ptr.go | 30 ++++++++++++++++++++++++++++++ ptr/ptr_test.go | 3 ++- 4 files changed, 63 insertions(+), 15 deletions(-) diff --git a/env/env.go b/env/env.go index 69ed094b..20e7b84e 100644 --- a/env/env.go +++ b/env/env.go @@ -25,7 +25,7 @@ func ExpandWithEnv(variables ...*string) { // EnvType is a type that can be used as a type for environment variables. type EnvType interface { - ~string | ~int | ~bool | ~float64 | time.Duration + ~string | ~int | ~bool | ~float64 | time.Duration | ~rune } // GetEnvOrDefault returns the value of the environment variable or the default value if the variable is not set. diff --git a/errkit/errors.go b/errkit/errors.go index f2f7c3b0..30d17d9a 100644 --- a/errkit/errors.go +++ b/errkit/errors.go @@ -7,9 +7,11 @@ import ( "errors" "fmt" "log/slog" + "strconv" "strings" "github.com/projectdiscovery/utils/env" + mapsutil "github.com/projectdiscovery/utils/maps" "golang.org/x/exp/maps" ) @@ -24,14 +26,18 @@ const ( DelimMultiLine = "\n - " // MultiLinePrefix is the prefix used for multiline errors MultiLineErrPrefix = "the following errors occurred:" + // Tab + Tab = '\t' ) var ( // MaxErrorDepth is the maximum depth of errors to be unwrapped or maintained // all errors beyond this depth will be ignored MaxErrorDepth = env.GetEnvOrDefault("MAX_ERROR_DEPTH", 3) - // ErrorSeperator is the seperator used to join errors - ErrorSeperator = env.GetEnvOrDefault("ERROR_SEPERATOR", "; ") + // FieldSeperator + ErrFieldSeparator = env.GetEnvOrDefault("ERR_FIELD_SEPERATOR", Tab) + // ErrChainSeperator + ErrChainSeperator = env.GetEnvOrDefault("ERR_CHAIN_SEPERATOR", DelimSemiColon) ) // ErrorX is a custom error type that can handle all known types of errors @@ -118,20 +124,25 @@ func (e *ErrorX) Is(err error) bool { // Error returns the error string func (e *ErrorX) Error() string { var sb strings.Builder - if e.kind != nil && e.kind.String() != "" { - sb.WriteString("errKind=") - sb.WriteString(e.kind.String()) - sb.WriteString(" ") - } + sb.WriteString("cause=") + sb.WriteString(strconv.Quote(e.errs[0].Error())) if len(e.attrs) > 0 { - sb.WriteString(slog.GroupValue(maps.Values(e.attrs)...).String()) - sb.WriteString(" ") + values := []string{} + for _, key := range mapsutil.GetSortedKeys(e.attrs) { + values = append(values, key+"="+strconv.Quote(e.attrs[key].String())) + } + sb.WriteRune(ErrFieldSeparator) + sb.WriteString(strings.Join(values, " ")) } - for _, err := range e.errs { - sb.WriteString(err.Error()) - sb.WriteString(ErrorSeperator) + if len(e.errs) > 1 { + chain := []string{} + for _, value := range e.errs[1:] { + chain = append(chain, strings.TrimSpace(value.Error())) + } + sb.WriteRune(ErrFieldSeparator) + sb.WriteString("chain=" + strconv.Quote(strings.Join(chain, ErrChainSeperator))) } - return strings.TrimSuffix(sb.String(), ErrorSeperator) + return sb.String() } // Cause return the original error that caused this without any wrapping @@ -165,6 +176,9 @@ func FromError(err error) *ErrorX { // New creates a new error with the given message func New(format string, args ...interface{}) *ErrorX { e := &ErrorX{} + if len(args) == 0 { + e.append(errors.New(format)) + } e.append(fmt.Errorf(format, args...)) return e } @@ -174,6 +188,9 @@ func (e *ErrorX) Msgf(format string, args ...interface{}) { if e == nil { return } + if len(args) == 0 { + e.append(errors.New(format)) + } e.append(fmt.Errorf(format, args...)) } diff --git a/ptr/ptr.go b/ptr/ptr.go index 902450bc..636cb927 100644 --- a/ptr/ptr.go +++ b/ptr/ptr.go @@ -3,9 +3,39 @@ package ptr // Safe dereferences safely a pointer // - if the pointer is nil => returns the zero value of the type of the pointer if nil // - if the pointer is not nil => returns the dereferenced pointer +// +// Example: +// +// var v *int +// var x = ptr.Safe(v) func Safe[T any](v *T) T { if v == nil { return *new(T) } return *v } + +// Of returns pointer of a given generic type +// +// Example: +// +// var v int +// var p = ptr.Of(v) +func Of[T any](v T) *T { + return &v +} + +// When returns pointer of a given generic type +// - if the condition is false => returns nil +// - if the condition is true => returns pointer of the value +// +// Example: +// +// var v bool +// var p = ptr.When(v, v != false) +func When[T any](v T, condition bool) *T { + if !condition { + return nil + } + return &v +} diff --git a/ptr/ptr_test.go b/ptr/ptr_test.go index 7a7efb83..2131f411 100644 --- a/ptr/ptr_test.go +++ b/ptr/ptr_test.go @@ -1,8 +1,9 @@ package ptr import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestSafe(t *testing.T) { From 82c968b3ccc43ac051261c95c65887ddc049bd81 Mon Sep 17 00:00:00 2001 From: Tarun Koyalwar Date: Thu, 28 Nov 2024 02:13:12 +0530 Subject: [PATCH 2/2] errkit: stick to slog standards + format improvements --- errkit/errors.go | 178 +++++++++++++++++++++++++++++++----------- errkit/errors_test.go | 12 +++ errkit/helpers.go | 14 ++-- errkit/interfaces.go | 6 ++ 4 files changed, 160 insertions(+), 50 deletions(-) diff --git a/errkit/errors.go b/errkit/errors.go index 30d17d9a..f2841010 100644 --- a/errkit/errors.go +++ b/errkit/errors.go @@ -7,12 +7,12 @@ import ( "errors" "fmt" "log/slog" + "runtime" "strconv" "strings" + "time" "github.com/projectdiscovery/utils/env" - mapsutil "github.com/projectdiscovery/utils/maps" - "golang.org/x/exp/maps" ) const ( @@ -26,8 +26,8 @@ const ( DelimMultiLine = "\n - " // MultiLinePrefix is the prefix used for multiline errors MultiLineErrPrefix = "the following errors occurred:" - // Tab - Tab = '\t' + // Space is the identifier used for indentation + Space = " " ) var ( @@ -35,33 +35,69 @@ var ( // all errors beyond this depth will be ignored MaxErrorDepth = env.GetEnvOrDefault("MAX_ERROR_DEPTH", 3) // FieldSeperator - ErrFieldSeparator = env.GetEnvOrDefault("ERR_FIELD_SEPERATOR", Tab) + ErrFieldSeparator = env.GetEnvOrDefault("ERR_FIELD_SEPERATOR", Space) // ErrChainSeperator ErrChainSeperator = env.GetEnvOrDefault("ERR_CHAIN_SEPERATOR", DelimSemiColon) + // EnableTimestamp controls whether error timestamps are included + EnableTimestamp = env.GetEnvOrDefault("ENABLE_ERR_TIMESTAMP", false) + // EnableTrace controls whether error stack traces are included + EnableTrace = env.GetEnvOrDefault("ENABLE_ERR_TRACE", false) ) // ErrorX is a custom error type that can handle all known types of errors // wrapping and joining strategies including custom ones and it supports error class // which can be shown to client/users in more meaningful way type ErrorX struct { - kind ErrKind - attrs map[string]slog.Attr - errs []error - uniqErrs map[string]struct{} + kind ErrKind + record *slog.Record + source *slog.Source + errs []error +} + +func (e *ErrorX) init(skipStack ...int) { + // initializes if necessary + if e.record == nil { + e.record = &slog.Record{} + if EnableTimestamp { + e.record.Time = time.Now() + } + if EnableTrace { + // get fn name + var pcs [1]uintptr + // skip [runtime.Callers, ErrorX.init, parent] + skip := 3 + if len(skipStack) > 0 { + skip = skipStack[0] + } + runtime.Callers(skip, pcs[:]) + pc := pcs[0] + fs := runtime.CallersFrames([]uintptr{pc}) + f, _ := fs.Next() + e.source = &slog.Source{ + Function: f.Function, + File: f.File, + Line: f.Line, + } + } + } } // append is internal method to append given // error to error slice , it removes duplicates +// earlier it used map which causes more allocations that necessary func (e *ErrorX) append(errs ...error) { - if e.uniqErrs == nil { - e.uniqErrs = make(map[string]struct{}) - } - for _, err := range errs { - if _, ok := e.uniqErrs[err.Error()]; ok { - continue + for _, nerr := range errs { + found := false + new: + for _, oerr := range e.errs { + if oerr.Error() == nerr.Error() { + found = true + break new + } + } + if !found { + e.errs = append(e.errs, nerr) } - e.uniqErrs[err.Error()] = struct{}{} - e.errs = append(e.errs, err) } } @@ -77,8 +113,11 @@ func (e ErrorX) MarshalJSON() ([]byte, error) { "kind": e.kind.String(), "errors": tmp, } - if len(e.attrs) > 0 { - m["attrs"] = slog.GroupValue(maps.Values(e.attrs)...) + if e.record != nil && e.record.NumAttrs() > 0 { + m["attrs"] = slog.GroupValue(e.Attrs()...) + } + if e.source != nil { + m["source"] = e.source } return json.Marshal(m) } @@ -90,10 +129,15 @@ func (e *ErrorX) Errors() []error { // Attrs returns all attributes associated with the error func (e *ErrorX) Attrs() []slog.Attr { - if e.attrs == nil { + if e.record == nil || e.record.NumAttrs() == 0 { return nil } - return maps.Values(e.attrs) + values := []slog.Attr{} + e.record.Attrs(func(a slog.Attr) bool { + values = append(values, a) + return true + }) + return values } // Build returns the object as error interface @@ -109,6 +153,7 @@ func (e *ErrorX) Unwrap() []error { // Is checks if current error contains given error func (e *ErrorX) Is(err error) bool { x := &ErrorX{} + x.init() parseError(x, err) // even one submatch is enough for _, orig := range e.errs { @@ -126,12 +171,13 @@ func (e *ErrorX) Error() string { var sb strings.Builder sb.WriteString("cause=") sb.WriteString(strconv.Quote(e.errs[0].Error())) - if len(e.attrs) > 0 { + if e.record != nil && e.record.NumAttrs() > 0 { values := []string{} - for _, key := range mapsutil.GetSortedKeys(e.attrs) { - values = append(values, key+"="+strconv.Quote(e.attrs[key].String())) - } - sb.WriteRune(ErrFieldSeparator) + e.record.Attrs(func(a slog.Attr) bool { + values = append(values, a.String()) + return true + }) + sb.WriteString(Space) sb.WriteString(strings.Join(values, " ")) } if len(e.errs) > 1 { @@ -139,7 +185,7 @@ func (e *ErrorX) Error() string { for _, value := range e.errs[1:] { chain = append(chain, strings.TrimSpace(value.Error())) } - sb.WriteRune(ErrFieldSeparator) + sb.WriteString(Space) sb.WriteString("chain=" + strconv.Quote(strings.Join(chain, ErrChainSeperator))) } return sb.String() @@ -169,21 +215,47 @@ func FromError(err error) *ErrorX { return nil } nucleiErr := &ErrorX{} + nucleiErr.init() parseError(nucleiErr, err) return nucleiErr } // New creates a new error with the given message -func New(format string, args ...interface{}) *ErrorX { +// it follows slog pattern of adding and expects in the same way +// +// Example: +// +// this is correct (√) +// errkit.New("this is a nuclei error","address",host) +// +// this is not readable/recommended (x) +// errkit.New("this is a nuclei error",slog.String("address",host)) +// +// this is wrong (x) +// errkit.New("this is a nuclei error %s",host) +func New(msg string, args ...interface{}) *ErrorX { e := &ErrorX{} - if len(args) == 0 { - e.append(errors.New(format)) + e.init() + if len(args) > 0 { + e.record.Add(args...) } - e.append(fmt.Errorf(format, args...)) + e.append(errors.New(msg)) return e } // Msgf adds a message to the error +// it follows slog pattern of adding and expects in the same way +// +// Example: +// +// this is correct (√) +// myError.Msgf("dial error","network","tcp") +// +// this is not readable/recommended (x) +// myError.Msgf(slog.String("address",host)) +// +// this is wrong (x) +// myError.Msgf("this is a nuclei error %s",host) func (e *ErrorX) Msgf(format string, args ...interface{}) { if e == nil { return @@ -197,6 +269,11 @@ func (e *ErrorX) Msgf(format string, args ...interface{}) { // SetClass sets the class of the error // if underlying error class was already set, then it is given preference // when generating final error msg +// +// Example: +// +// this is correct (√) +// myError.SetKind(errkit.ErrKindNetworkPermanent) func (e *ErrorX) SetKind(kind ErrKind) *ErrorX { if e.kind == nil { e.kind = kind @@ -206,23 +283,30 @@ func (e *ErrorX) SetKind(kind ErrKind) *ErrorX { return e } +// ResetKind resets the error class of the error +// +// Example: +// +// myError.ResetKind() func (e *ErrorX) ResetKind() *ErrorX { e.kind = nil return e } +// Deprecated: use Attrs instead +// // SetAttr sets additional attributes to a given error // it only adds unique attributes and ignores duplicates // Note: only key is checked for uniqueness +// +// Example: +// +// this is correct (√) +// myError.SetAttr(slog.String("address",host)) func (e *ErrorX) SetAttr(s ...slog.Attr) *ErrorX { + e.init() for _, attr := range s { - if e.attrs == nil { - e.attrs = make(map[string]slog.Attr) - } - // check if this exists - if _, ok := e.attrs[attr.Key]; !ok && len(e.attrs) < MaxErrorDepth { - e.attrs[attr.Key] = attr - } + e.record.Add(attr) } return e } @@ -234,6 +318,7 @@ func parseError(to *ErrorX, err error) { } if to == nil { to = &ErrorX{} + to.init(4) } if len(to.errs) >= MaxErrorDepth { return @@ -242,6 +327,17 @@ func parseError(to *ErrorX, err error) { switch v := err.(type) { case *ErrorX: to.append(v.errs...) + if to.record == nil { + to.record = v.record + } else { + v.record.Attrs(func(a slog.Attr) bool { + to.record.Add(a) + return true + }) + } + if to.source == nil { + to.source = v.source + } to.kind = CombineErrKinds(to.kind, v.kind) case JoinedError: foundAny := false @@ -300,9 +396,3 @@ func parseError(to *ErrorX, err error) { } } } - -// WrappedError is implemented by errors that are wrapped -type WrappedError interface { - // Unwrap returns the underlying error - Unwrap() error -} diff --git a/errkit/errors_test.go b/errkit/errors_test.go index 47c5dad6..a27cb168 100644 --- a/errkit/errors_test.go +++ b/errkit/errors_test.go @@ -114,3 +114,15 @@ func TestMarshalError(t *testing.T) { require.NoError(t, err, "expected to be able to marshal the error") require.Equal(t, `{"errors":["port closed or filtered","this is a wrapped error"],"kind":"network-permanent-error"}`, string(marshalled)) } + +func TestErrorString(t *testing.T) { + var x error = New("i/o timeout") + x = With(x, "ip", "10.0.0.1", "port", 80) + x = WithMessage(x, "tcp dial error") + x = Append(x, errors.New("some other error")) + + require.Equal(t, + `cause="i/o timeout" ip=10.0.0.1 port=80 chain="tcp dial error; some other error"`, + x.Error(), + ) +} diff --git a/errkit/helpers.go b/errkit/helpers.go index d385ad89..8bff426a 100644 --- a/errkit/helpers.go +++ b/errkit/helpers.go @@ -193,19 +193,21 @@ func IsNetworkPermanentErr(err error) bool { return isNetworkPermanentErr(x) } -// WithAttr wraps error with given attributes +// With adds extra attributes to the error // -// err = errkit.WithAttr(err,slog.Any("resource",domain)) -func WithAttr(err error, attrs ...slog.Attr) error { +// err = errkit.With(err,"resource",domain) +func With(err error, args ...any) error { if err == nil { return nil } - if len(attrs) == 0 { + if len(args) == 0 { return err } x := &ErrorX{} + x.init() parseError(x, err) - return x.SetAttr(attrs...) + x.record.Add(args...) + return x } // GetAttr returns all attributes of given error if it has any @@ -271,7 +273,7 @@ func GetAttrValue(err error, key string) slog.Value { } x := &ErrorX{} parseError(x, err) - for _, attr := range x.attrs { + for _, attr := range x.Attrs() { if attr.Key == key { return attr.Value } diff --git a/errkit/interfaces.go b/errkit/interfaces.go index 950efb7b..46c12fb2 100644 --- a/errkit/interfaces.go +++ b/errkit/interfaces.go @@ -30,3 +30,9 @@ type ComparableError interface { // Is checks if current error contains given error Is(err error) bool } + +// WrappedError is implemented by errors that are wrapped +type WrappedError interface { + // Unwrap returns the underlying error + Unwrap() error +}