Skip to content

Commit 6480126

Browse files
authored
Merge pull request #153 from cockroachdb/jeffswenson-optimize-errors-is
errorbase: optimize errors.Is
2 parents 2008f7c + 3a21e3d commit 6480126

File tree

3 files changed

+246
-77
lines changed

3 files changed

+246
-77
lines changed

benchmark_test.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package errors_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"testing"
8+
9+
"github.com/cockroachdb/errors"
10+
)
11+
12+
func BenchmarkErrorsIs(b *testing.B) {
13+
b.Run("NilError", func(b *testing.B) {
14+
var err error
15+
for range b.N {
16+
errors.Is(err, context.Canceled)
17+
}
18+
})
19+
20+
b.Run("SimpleError", func(b *testing.B) {
21+
err := errors.New("test")
22+
for range b.N {
23+
errors.Is(err, context.Canceled)
24+
}
25+
})
26+
27+
b.Run("WrappedError", func(b *testing.B) {
28+
baseErr := errors.New("test")
29+
err := errors.Wrap(baseErr, "wrapped error")
30+
for range b.N {
31+
errors.Is(err, context.Canceled)
32+
}
33+
})
34+
35+
b.Run("WrappedWithStack", func(b *testing.B) {
36+
baseErr := errors.New("test")
37+
err := errors.WithStack(baseErr)
38+
for range b.N {
39+
errors.Is(err, context.Canceled)
40+
}
41+
})
42+
43+
b.Run("NetworkError", func(b *testing.B) {
44+
netErr := &net.OpError{
45+
Op: "dial",
46+
Net: "tcp",
47+
Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257},
48+
Err: fmt.Errorf("connection refused"),
49+
}
50+
err := errors.Wrap(netErr, "network connection failed")
51+
for range b.N {
52+
errors.Is(err, context.Canceled)
53+
}
54+
})
55+
56+
b.Run("DeeplyWrappedNetworkError", func(b *testing.B) {
57+
netErr := &net.OpError{
58+
Op: "dial",
59+
Net: "tcp",
60+
Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257},
61+
Err: fmt.Errorf("connection refused"),
62+
}
63+
err := errors.WithStack(netErr)
64+
err = errors.Wrap(err, "failed to connect to database")
65+
err = errors.Wrap(err, "unable to establish connection")
66+
err = errors.WithStack(err)
67+
for range b.N {
68+
errors.Is(err, context.Canceled)
69+
}
70+
})
71+
72+
b.Run("MultipleWrappedErrors", func(b *testing.B) {
73+
baseErr := errors.New("internal error")
74+
err := errors.WithStack(baseErr)
75+
err = errors.Wrap(err, "operation failed")
76+
err = errors.WithStack(err)
77+
err = errors.Wrap(err, "transaction failed")
78+
err = errors.WithStack(err)
79+
for range b.N {
80+
errors.Is(err, context.Canceled)
81+
}
82+
})
83+
84+
b.Run("NetworkErrorWithLongAddress", func(b *testing.B) {
85+
netErr := &net.OpError{
86+
Op: "read",
87+
Net: "tcp",
88+
Addr: &net.TCPAddr{
89+
IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
90+
Port: 26257,
91+
},
92+
Err: fmt.Errorf("i/o timeout"),
93+
}
94+
err := errors.WithStack(netErr)
95+
err = errors.Wrap(err, "failed to read from connection")
96+
for range b.N {
97+
errors.Is(err, context.Canceled)
98+
}
99+
})
100+
101+
b.Run("WithMessage", func(b *testing.B) {
102+
baseErr := errors.New("test")
103+
err := errors.WithMessage(baseErr, "additional context")
104+
for range b.N {
105+
errors.Is(err, context.Canceled)
106+
}
107+
})
108+
109+
b.Run("MultipleWithMessage", func(b *testing.B) {
110+
baseErr := errors.New("internal error")
111+
err := errors.WithMessage(baseErr, "first message")
112+
err = errors.WithMessage(err, "second message")
113+
err = errors.WithMessage(err, "third message")
114+
for range b.N {
115+
errors.Is(err, context.Canceled)
116+
}
117+
})
118+
119+
b.Run("WithMessageAndStack", func(b *testing.B) {
120+
baseErr := errors.New("test")
121+
err := errors.WithStack(baseErr)
122+
err = errors.WithMessage(err, "operation context")
123+
err = errors.WithStack(err)
124+
for range b.N {
125+
errors.Is(err, context.Canceled)
126+
}
127+
})
128+
129+
b.Run("NetworkErrorWithMessage", func(b *testing.B) {
130+
netErr := &net.OpError{
131+
Op: "dial",
132+
Net: "tcp",
133+
Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257},
134+
Err: fmt.Errorf("connection refused"),
135+
}
136+
err := errors.WithMessage(netErr, "database connection failed")
137+
err = errors.WithMessage(err, "unable to reach server")
138+
for range b.N {
139+
errors.Is(err, context.Canceled)
140+
}
141+
})
142+
143+
b.Run("NetworkErrorWithEverything", func(b *testing.B) {
144+
netErr := &net.OpError{
145+
Op: "dial",
146+
Net: "tcp",
147+
Addr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 26257},
148+
Err: fmt.Errorf("connection refused"),
149+
}
150+
err := errors.WithStack(netErr)
151+
err = errors.WithMessage(err, "database connection failed")
152+
err = errors.Wrap(err, "failed to establish TCP connection")
153+
err = errors.WithStack(err)
154+
err = errors.WithMessage(err, "unable to reach CockroachDB server")
155+
err = errors.Wrap(err, "connection attempt failed")
156+
for range b.N {
157+
errors.Is(err, context.Canceled)
158+
}
159+
})
160+
161+
b.Run("DeeplyNested100Levels", func(b *testing.B) {
162+
baseErr := errors.New("base error")
163+
err := baseErr
164+
165+
// Create a 100-level deep error chain
166+
for i := 0; i < 100; i++ {
167+
switch i % 3 {
168+
case 0:
169+
err = errors.Wrap(err, fmt.Sprintf("wrap level %d", i))
170+
case 1:
171+
err = errors.WithMessage(err, fmt.Sprintf("message level %d", i))
172+
case 2:
173+
err = errors.WithStack(err)
174+
}
175+
}
176+
177+
for range b.N {
178+
errors.Is(err, context.Canceled)
179+
}
180+
})
181+
}

errbase/encode.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,30 @@ func GetTypeMark(err error) errorspb.ErrorTypeMark {
305305
return errorspb.ErrorTypeMark{FamilyName: familyName, Extension: extension}
306306
}
307307

308+
// EqualTypeMark checks whether `GetTypeMark(e1).Equals(GetTypeMark(e2))`. It
309+
// is written to be be optimized for the case where neither error has
310+
// serialized type information.
311+
func EqualTypeMark(e1, e2 error) bool {
312+
slowPath := func(err error) bool {
313+
switch err.(type) {
314+
case *opaqueLeaf:
315+
return true
316+
case *opaqueLeafCauses:
317+
return true
318+
case *opaqueWrapper:
319+
return true
320+
case TypeKeyMarker:
321+
return true
322+
}
323+
return false
324+
}
325+
if slowPath(e1) || slowPath(e2) {
326+
return GetTypeMark(e1).Equals(GetTypeMark(e2))
327+
}
328+
329+
return reflect.TypeOf(e1) == reflect.TypeOf(e2)
330+
}
331+
308332
// RegisterLeafEncoder can be used to register new leaf error types to
309333
// the library. Registered types will be encoded using their own
310334
// Go type when an error is encoded. Wrappers that have not been
@@ -385,9 +409,7 @@ func RegisterWrapperEncoder(theType TypeKey, encoder WrapperEncoder) {
385409
// Note: if the error type has been migrated from a previous location
386410
// or a different type, ensure that RegisterTypeMigration() was called
387411
// prior to RegisterWrapperEncoder().
388-
func RegisterWrapperEncoderWithMessageType(
389-
theType TypeKey, encoder WrapperEncoderWithMessageType,
390-
) {
412+
func RegisterWrapperEncoderWithMessageType(theType TypeKey, encoder WrapperEncoderWithMessageType) {
391413
if encoder == nil {
392414
delete(encoders, theType)
393415
} else {

markers/markers.go

Lines changed: 40 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -68,29 +68,42 @@ func Is(err, reference error) bool {
6868
}
6969
}
7070

71-
if err == nil {
72-
// Err is nil and reference is non-nil, so it cannot match. We
73-
// want to short-circuit the loop below in this case, otherwise
74-
// we're paying the expense of getMark() without need.
75-
return false
76-
}
77-
78-
// Not directly equal. Try harder, using error marks. We don't do
79-
// this during the loop above as it may be more expensive.
80-
//
81-
// Note: there is a more effective recursive algorithm that ensures
82-
// that any pair of string only gets compared once. Should the
83-
// following code become a performance bottleneck, that algorithm
84-
// can be considered instead.
85-
refMark := getMark(reference)
86-
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
87-
if equalMarks(getMark(c), refMark) {
71+
for errNext := err; errNext != nil; errNext = errbase.UnwrapOnce(errNext) {
72+
if isMarkEqual(errNext, reference) {
8873
return true
8974
}
9075
}
76+
9177
return false
9278
}
9379

80+
func isMarkEqual(err, reference error) bool {
81+
_, errIsMark := err.(*withMark)
82+
_, refIsMark := reference.(*withMark)
83+
if errIsMark || refIsMark {
84+
// If either error is a mark, use the more general
85+
// equalMarks() function.
86+
return equalMarks(getMark(err), getMark(reference))
87+
}
88+
89+
m1 := err
90+
m2 := reference
91+
for m1 != nil && m2 != nil {
92+
if !errbase.EqualTypeMark(m1, m2) {
93+
return false
94+
}
95+
m1 = errbase.UnwrapOnce(m1)
96+
m2 = errbase.UnwrapOnce(m2)
97+
}
98+
99+
// The two chains have different lengths, so they cannot be equal.
100+
if m1 != nil || m2 != nil {
101+
return false
102+
}
103+
104+
return safeGetErrMsg(err) == safeGetErrMsg(reference)
105+
}
106+
94107
func tryDelegateToIsMethod(err, reference error) bool {
95108
if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(reference) {
96109
return true
@@ -150,62 +163,9 @@ func If(err error, pred func(err error) (interface{}, bool)) (interface{}, bool)
150163
// package location or a different type, ensure that
151164
// RegisterTypeMigration() was called prior to IsAny().
152165
func IsAny(err error, references ...error) bool {
153-
if err == nil {
154-
for _, refErr := range references {
155-
if refErr == nil {
156-
return true
157-
}
158-
}
159-
// The mark-based comparison below will never match anything if
160-
// the error is nil, so don't bother with computing the marks in
161-
// that case. This avoids the computational expense of computing
162-
// the reference marks upfront.
163-
return false
164-
}
165-
166-
// First try using direct reference comparison.
167-
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
168-
for _, refErr := range references {
169-
if refErr == nil {
170-
continue
171-
}
172-
isComparable := reflect.TypeOf(refErr).Comparable()
173-
if isComparable && c == refErr {
174-
return true
175-
}
176-
// Compatibility with std go errors: if the error object itself
177-
// implements Is(), try to use that.
178-
if tryDelegateToIsMethod(c, refErr) {
179-
return true
180-
}
181-
}
182-
183-
// Recursively try multi-error causes, if applicable.
184-
for _, me := range errbase.UnwrapMulti(c) {
185-
if IsAny(me, references...) {
186-
return true
187-
}
188-
}
189-
}
190-
191-
// Try harder with marks.
192-
// Note: there is a more effective recursive algorithm that ensures
193-
// that any pair of string only gets compared once. Should this
194-
// become a performance bottleneck, that algorithm can be considered
195-
// instead.
196-
refMarks := make([]errorMark, 0, len(references))
197-
for _, refErr := range references {
198-
if refErr == nil {
199-
continue
200-
}
201-
refMarks = append(refMarks, getMark(refErr))
202-
}
203-
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
204-
errMark := getMark(c)
205-
for _, refMark := range refMarks {
206-
if equalMarks(errMark, refMark) {
207-
return true
208-
}
166+
for _, reference := range references {
167+
if Is(err, reference) {
168+
return true
209169
}
210170
}
211171
return false
@@ -221,6 +181,9 @@ func equalMarks(m1, m2 errorMark) bool {
221181
if m1.msg != m2.msg {
222182
return false
223183
}
184+
if len(m1.types) != len(m2.types) {
185+
return false
186+
}
224187
for i, t := range m1.types {
225188
if !t.Equals(m2.types[i]) {
226189
return false
@@ -234,7 +197,10 @@ func getMark(err error) errorMark {
234197
if m, ok := err.(*withMark); ok {
235198
return m.mark
236199
}
237-
m := errorMark{msg: safeGetErrMsg(err), types: []errorspb.ErrorTypeMark{errbase.GetTypeMark(err)}}
200+
m := errorMark{
201+
msg: safeGetErrMsg(err),
202+
types: []errorspb.ErrorTypeMark{errbase.GetTypeMark(err)},
203+
}
238204
for c := errbase.UnwrapOnce(err); c != nil; c = errbase.UnwrapOnce(c) {
239205
m.types = append(m.types, errbase.GetTypeMark(c))
240206
}

0 commit comments

Comments
 (0)