Skip to content

Commit 7ba513a

Browse files
more ptr utils + errkit improvements (#573)
* more ptr utils * errkit: stick to slog standards + format improvements
1 parent cebafa1 commit 7ba513a

File tree

7 files changed

+212
-54
lines changed

7 files changed

+212
-54
lines changed

env/env.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func ExpandWithEnv(variables ...*string) {
2525

2626
// EnvType is a type that can be used as a type for environment variables.
2727
type EnvType interface {
28-
~string | ~int | ~bool | ~float64 | time.Duration
28+
~string | ~int | ~bool | ~float64 | time.Duration | ~rune
2929
}
3030

3131
// GetEnvOrDefault returns the value of the environment variable or the default value if the variable is not set.

errkit/errors.go

Lines changed: 153 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ import (
77
"errors"
88
"fmt"
99
"log/slog"
10+
"runtime"
11+
"strconv"
1012
"strings"
13+
"time"
1114

1215
"github.com/projectdiscovery/utils/env"
13-
"golang.org/x/exp/maps"
1416
)
1517

1618
const (
@@ -24,38 +26,78 @@ const (
2426
DelimMultiLine = "\n - "
2527
// MultiLinePrefix is the prefix used for multiline errors
2628
MultiLineErrPrefix = "the following errors occurred:"
29+
// Space is the identifier used for indentation
30+
Space = " "
2731
)
2832

2933
var (
3034
// MaxErrorDepth is the maximum depth of errors to be unwrapped or maintained
3135
// all errors beyond this depth will be ignored
3236
MaxErrorDepth = env.GetEnvOrDefault("MAX_ERROR_DEPTH", 3)
33-
// ErrorSeperator is the seperator used to join errors
34-
ErrorSeperator = env.GetEnvOrDefault("ERROR_SEPERATOR", "; ")
37+
// FieldSeperator
38+
ErrFieldSeparator = env.GetEnvOrDefault("ERR_FIELD_SEPERATOR", Space)
39+
// ErrChainSeperator
40+
ErrChainSeperator = env.GetEnvOrDefault("ERR_CHAIN_SEPERATOR", DelimSemiColon)
41+
// EnableTimestamp controls whether error timestamps are included
42+
EnableTimestamp = env.GetEnvOrDefault("ENABLE_ERR_TIMESTAMP", false)
43+
// EnableTrace controls whether error stack traces are included
44+
EnableTrace = env.GetEnvOrDefault("ENABLE_ERR_TRACE", false)
3545
)
3646

3747
// ErrorX is a custom error type that can handle all known types of errors
3848
// wrapping and joining strategies including custom ones and it supports error class
3949
// which can be shown to client/users in more meaningful way
4050
type ErrorX struct {
41-
kind ErrKind
42-
attrs map[string]slog.Attr
43-
errs []error
44-
uniqErrs map[string]struct{}
51+
kind ErrKind
52+
record *slog.Record
53+
source *slog.Source
54+
errs []error
55+
}
56+
57+
func (e *ErrorX) init(skipStack ...int) {
58+
// initializes if necessary
59+
if e.record == nil {
60+
e.record = &slog.Record{}
61+
if EnableTimestamp {
62+
e.record.Time = time.Now()
63+
}
64+
if EnableTrace {
65+
// get fn name
66+
var pcs [1]uintptr
67+
// skip [runtime.Callers, ErrorX.init, parent]
68+
skip := 3
69+
if len(skipStack) > 0 {
70+
skip = skipStack[0]
71+
}
72+
runtime.Callers(skip, pcs[:])
73+
pc := pcs[0]
74+
fs := runtime.CallersFrames([]uintptr{pc})
75+
f, _ := fs.Next()
76+
e.source = &slog.Source{
77+
Function: f.Function,
78+
File: f.File,
79+
Line: f.Line,
80+
}
81+
}
82+
}
4583
}
4684

4785
// append is internal method to append given
4886
// error to error slice , it removes duplicates
87+
// earlier it used map which causes more allocations that necessary
4988
func (e *ErrorX) append(errs ...error) {
50-
if e.uniqErrs == nil {
51-
e.uniqErrs = make(map[string]struct{})
52-
}
53-
for _, err := range errs {
54-
if _, ok := e.uniqErrs[err.Error()]; ok {
55-
continue
89+
for _, nerr := range errs {
90+
found := false
91+
new:
92+
for _, oerr := range e.errs {
93+
if oerr.Error() == nerr.Error() {
94+
found = true
95+
break new
96+
}
97+
}
98+
if !found {
99+
e.errs = append(e.errs, nerr)
56100
}
57-
e.uniqErrs[err.Error()] = struct{}{}
58-
e.errs = append(e.errs, err)
59101
}
60102
}
61103

@@ -71,8 +113,11 @@ func (e ErrorX) MarshalJSON() ([]byte, error) {
71113
"kind": e.kind.String(),
72114
"errors": tmp,
73115
}
74-
if len(e.attrs) > 0 {
75-
m["attrs"] = slog.GroupValue(maps.Values(e.attrs)...)
116+
if e.record != nil && e.record.NumAttrs() > 0 {
117+
m["attrs"] = slog.GroupValue(e.Attrs()...)
118+
}
119+
if e.source != nil {
120+
m["source"] = e.source
76121
}
77122
return json.Marshal(m)
78123
}
@@ -84,10 +129,15 @@ func (e *ErrorX) Errors() []error {
84129

85130
// Attrs returns all attributes associated with the error
86131
func (e *ErrorX) Attrs() []slog.Attr {
87-
if e.attrs == nil {
132+
if e.record == nil || e.record.NumAttrs() == 0 {
88133
return nil
89134
}
90-
return maps.Values(e.attrs)
135+
values := []slog.Attr{}
136+
e.record.Attrs(func(a slog.Attr) bool {
137+
values = append(values, a)
138+
return true
139+
})
140+
return values
91141
}
92142

93143
// Build returns the object as error interface
@@ -103,6 +153,7 @@ func (e *ErrorX) Unwrap() []error {
103153
// Is checks if current error contains given error
104154
func (e *ErrorX) Is(err error) bool {
105155
x := &ErrorX{}
156+
x.init()
106157
parseError(x, err)
107158
// even one submatch is enough
108159
for _, orig := range e.errs {
@@ -118,20 +169,26 @@ func (e *ErrorX) Is(err error) bool {
118169
// Error returns the error string
119170
func (e *ErrorX) Error() string {
120171
var sb strings.Builder
121-
if e.kind != nil && e.kind.String() != "" {
122-
sb.WriteString("errKind=")
123-
sb.WriteString(e.kind.String())
124-
sb.WriteString(" ")
125-
}
126-
if len(e.attrs) > 0 {
127-
sb.WriteString(slog.GroupValue(maps.Values(e.attrs)...).String())
128-
sb.WriteString(" ")
172+
sb.WriteString("cause=")
173+
sb.WriteString(strconv.Quote(e.errs[0].Error()))
174+
if e.record != nil && e.record.NumAttrs() > 0 {
175+
values := []string{}
176+
e.record.Attrs(func(a slog.Attr) bool {
177+
values = append(values, a.String())
178+
return true
179+
})
180+
sb.WriteString(Space)
181+
sb.WriteString(strings.Join(values, " "))
129182
}
130-
for _, err := range e.errs {
131-
sb.WriteString(err.Error())
132-
sb.WriteString(ErrorSeperator)
183+
if len(e.errs) > 1 {
184+
chain := []string{}
185+
for _, value := range e.errs[1:] {
186+
chain = append(chain, strings.TrimSpace(value.Error()))
187+
}
188+
sb.WriteString(Space)
189+
sb.WriteString("chain=" + strconv.Quote(strings.Join(chain, ErrChainSeperator)))
133190
}
134-
return strings.TrimSuffix(sb.String(), ErrorSeperator)
191+
return sb.String()
135192
}
136193

137194
// Cause return the original error that caused this without any wrapping
@@ -158,28 +215,65 @@ func FromError(err error) *ErrorX {
158215
return nil
159216
}
160217
nucleiErr := &ErrorX{}
218+
nucleiErr.init()
161219
parseError(nucleiErr, err)
162220
return nucleiErr
163221
}
164222

165223
// New creates a new error with the given message
166-
func New(format string, args ...interface{}) *ErrorX {
224+
// it follows slog pattern of adding and expects in the same way
225+
//
226+
// Example:
227+
//
228+
// this is correct (√)
229+
// errkit.New("this is a nuclei error","address",host)
230+
//
231+
// this is not readable/recommended (x)
232+
// errkit.New("this is a nuclei error",slog.String("address",host))
233+
//
234+
// this is wrong (x)
235+
// errkit.New("this is a nuclei error %s",host)
236+
func New(msg string, args ...interface{}) *ErrorX {
167237
e := &ErrorX{}
168-
e.append(fmt.Errorf(format, args...))
238+
e.init()
239+
if len(args) > 0 {
240+
e.record.Add(args...)
241+
}
242+
e.append(errors.New(msg))
169243
return e
170244
}
171245

172246
// Msgf adds a message to the error
247+
// it follows slog pattern of adding and expects in the same way
248+
//
249+
// Example:
250+
//
251+
// this is correct (√)
252+
// myError.Msgf("dial error","network","tcp")
253+
//
254+
// this is not readable/recommended (x)
255+
// myError.Msgf(slog.String("address",host))
256+
//
257+
// this is wrong (x)
258+
// myError.Msgf("this is a nuclei error %s",host)
173259
func (e *ErrorX) Msgf(format string, args ...interface{}) {
174260
if e == nil {
175261
return
176262
}
263+
if len(args) == 0 {
264+
e.append(errors.New(format))
265+
}
177266
e.append(fmt.Errorf(format, args...))
178267
}
179268

180269
// SetClass sets the class of the error
181270
// if underlying error class was already set, then it is given preference
182271
// when generating final error msg
272+
//
273+
// Example:
274+
//
275+
// this is correct (√)
276+
// myError.SetKind(errkit.ErrKindNetworkPermanent)
183277
func (e *ErrorX) SetKind(kind ErrKind) *ErrorX {
184278
if e.kind == nil {
185279
e.kind = kind
@@ -189,23 +283,30 @@ func (e *ErrorX) SetKind(kind ErrKind) *ErrorX {
189283
return e
190284
}
191285

286+
// ResetKind resets the error class of the error
287+
//
288+
// Example:
289+
//
290+
// myError.ResetKind()
192291
func (e *ErrorX) ResetKind() *ErrorX {
193292
e.kind = nil
194293
return e
195294
}
196295

296+
// Deprecated: use Attrs instead
297+
//
197298
// SetAttr sets additional attributes to a given error
198299
// it only adds unique attributes and ignores duplicates
199300
// Note: only key is checked for uniqueness
301+
//
302+
// Example:
303+
//
304+
// this is correct (√)
305+
// myError.SetAttr(slog.String("address",host))
200306
func (e *ErrorX) SetAttr(s ...slog.Attr) *ErrorX {
307+
e.init()
201308
for _, attr := range s {
202-
if e.attrs == nil {
203-
e.attrs = make(map[string]slog.Attr)
204-
}
205-
// check if this exists
206-
if _, ok := e.attrs[attr.Key]; !ok && len(e.attrs) < MaxErrorDepth {
207-
e.attrs[attr.Key] = attr
208-
}
309+
e.record.Add(attr)
209310
}
210311
return e
211312
}
@@ -217,6 +318,7 @@ func parseError(to *ErrorX, err error) {
217318
}
218319
if to == nil {
219320
to = &ErrorX{}
321+
to.init(4)
220322
}
221323
if len(to.errs) >= MaxErrorDepth {
222324
return
@@ -225,6 +327,17 @@ func parseError(to *ErrorX, err error) {
225327
switch v := err.(type) {
226328
case *ErrorX:
227329
to.append(v.errs...)
330+
if to.record == nil {
331+
to.record = v.record
332+
} else {
333+
v.record.Attrs(func(a slog.Attr) bool {
334+
to.record.Add(a)
335+
return true
336+
})
337+
}
338+
if to.source == nil {
339+
to.source = v.source
340+
}
228341
to.kind = CombineErrKinds(to.kind, v.kind)
229342
case JoinedError:
230343
foundAny := false
@@ -283,9 +396,3 @@ func parseError(to *ErrorX, err error) {
283396
}
284397
}
285398
}
286-
287-
// WrappedError is implemented by errors that are wrapped
288-
type WrappedError interface {
289-
// Unwrap returns the underlying error
290-
Unwrap() error
291-
}

errkit/errors_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,15 @@ func TestMarshalError(t *testing.T) {
114114
require.NoError(t, err, "expected to be able to marshal the error")
115115
require.Equal(t, `{"errors":["port closed or filtered","this is a wrapped error"],"kind":"network-permanent-error"}`, string(marshalled))
116116
}
117+
118+
func TestErrorString(t *testing.T) {
119+
var x error = New("i/o timeout")
120+
x = With(x, "ip", "10.0.0.1", "port", 80)
121+
x = WithMessage(x, "tcp dial error")
122+
x = Append(x, errors.New("some other error"))
123+
124+
require.Equal(t,
125+
`cause="i/o timeout" ip=10.0.0.1 port=80 chain="tcp dial error; some other error"`,
126+
x.Error(),
127+
)
128+
}

errkit/helpers.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,19 +193,21 @@ func IsNetworkPermanentErr(err error) bool {
193193
return isNetworkPermanentErr(x)
194194
}
195195

196-
// WithAttr wraps error with given attributes
196+
// With adds extra attributes to the error
197197
//
198-
// err = errkit.WithAttr(err,slog.Any("resource",domain))
199-
func WithAttr(err error, attrs ...slog.Attr) error {
198+
// err = errkit.With(err,"resource",domain)
199+
func With(err error, args ...any) error {
200200
if err == nil {
201201
return nil
202202
}
203-
if len(attrs) == 0 {
203+
if len(args) == 0 {
204204
return err
205205
}
206206
x := &ErrorX{}
207+
x.init()
207208
parseError(x, err)
208-
return x.SetAttr(attrs...)
209+
x.record.Add(args...)
210+
return x
209211
}
210212

211213
// GetAttr returns all attributes of given error if it has any
@@ -271,7 +273,7 @@ func GetAttrValue(err error, key string) slog.Value {
271273
}
272274
x := &ErrorX{}
273275
parseError(x, err)
274-
for _, attr := range x.attrs {
276+
for _, attr := range x.Attrs() {
275277
if attr.Key == key {
276278
return attr.Value
277279
}

errkit/interfaces.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,9 @@ type ComparableError interface {
3030
// Is checks if current error contains given error
3131
Is(err error) bool
3232
}
33+
34+
// WrappedError is implemented by errors that are wrapped
35+
type WrappedError interface {
36+
// Unwrap returns the underlying error
37+
Unwrap() error
38+
}

0 commit comments

Comments
 (0)