@@ -68,29 +68,42 @@ func Is(err, reference error) bool {
6868 }
6969 }
7070
71- if err == nil {
72- // Err is nil and reference is non-nil, so it cannot match. We
73- // want to short-circuit the loop below in this case, otherwise
74- // we're paying the expense of getMark() without need.
75- return false
76- }
77-
78- // Not directly equal. Try harder, using error marks. We don't do
79- // this during the loop above as it may be more expensive.
80- //
81- // Note: there is a more effective recursive algorithm that ensures
82- // that any pair of string only gets compared once. Should the
83- // following code become a performance bottleneck, that algorithm
84- // can be considered instead.
85- refMark := getMark (reference )
86- for c := err ; c != nil ; c = errbase .UnwrapOnce (c ) {
87- if equalMarks (getMark (c ), refMark ) {
71+ for errNext := err ; errNext != nil ; errNext = errbase .UnwrapOnce (errNext ) {
72+ if isMarkEqual (errNext , reference ) {
8873 return true
8974 }
9075 }
76+
9177 return false
9278}
9379
80+ func isMarkEqual (err , reference error ) bool {
81+ _ , errIsMark := err .(* withMark )
82+ _ , refIsMark := reference .(* withMark )
83+ if errIsMark || refIsMark {
84+ // If either error is a mark, use the more general
85+ // equalMarks() function.
86+ return equalMarks (getMark (err ), getMark (reference ))
87+ }
88+
89+ m1 := err
90+ m2 := reference
91+ for m1 != nil && m2 != nil {
92+ if ! errbase .EqualTypeMark (m1 , m2 ) {
93+ return false
94+ }
95+ m1 = errbase .UnwrapOnce (m1 )
96+ m2 = errbase .UnwrapOnce (m2 )
97+ }
98+
99+ // The two chains have different lengths, so they cannot be equal.
100+ if m1 != nil || m2 != nil {
101+ return false
102+ }
103+
104+ return safeGetErrMsg (err ) == safeGetErrMsg (reference )
105+ }
106+
94107func tryDelegateToIsMethod (err , reference error ) bool {
95108 if x , ok := err .(interface { Is (error ) bool }); ok && x .Is (reference ) {
96109 return true
@@ -150,62 +163,9 @@ func If(err error, pred func(err error) (interface{}, bool)) (interface{}, bool)
150163// package location or a different type, ensure that
151164// RegisterTypeMigration() was called prior to IsAny().
152165func IsAny (err error , references ... error ) bool {
153- if err == nil {
154- for _ , refErr := range references {
155- if refErr == nil {
156- return true
157- }
158- }
159- // The mark-based comparison below will never match anything if
160- // the error is nil, so don't bother with computing the marks in
161- // that case. This avoids the computational expense of computing
162- // the reference marks upfront.
163- return false
164- }
165-
166- // First try using direct reference comparison.
167- for c := err ; c != nil ; c = errbase .UnwrapOnce (c ) {
168- for _ , refErr := range references {
169- if refErr == nil {
170- continue
171- }
172- isComparable := reflect .TypeOf (refErr ).Comparable ()
173- if isComparable && c == refErr {
174- return true
175- }
176- // Compatibility with std go errors: if the error object itself
177- // implements Is(), try to use that.
178- if tryDelegateToIsMethod (c , refErr ) {
179- return true
180- }
181- }
182-
183- // Recursively try multi-error causes, if applicable.
184- for _ , me := range errbase .UnwrapMulti (c ) {
185- if IsAny (me , references ... ) {
186- return true
187- }
188- }
189- }
190-
191- // Try harder with marks.
192- // Note: there is a more effective recursive algorithm that ensures
193- // that any pair of string only gets compared once. Should this
194- // become a performance bottleneck, that algorithm can be considered
195- // instead.
196- refMarks := make ([]errorMark , 0 , len (references ))
197- for _ , refErr := range references {
198- if refErr == nil {
199- continue
200- }
201- refMarks = append (refMarks , getMark (refErr ))
202- }
203- for c := err ; c != nil ; c = errbase .UnwrapOnce (c ) {
204- errMark := getMark (c )
205- for _ , refMark := range refMarks {
206- if equalMarks (errMark , refMark ) {
207- return true
208- }
166+ for _ , reference := range references {
167+ if Is (err , reference ) {
168+ return true
209169 }
210170 }
211171 return false
@@ -221,6 +181,9 @@ func equalMarks(m1, m2 errorMark) bool {
221181 if m1 .msg != m2 .msg {
222182 return false
223183 }
184+ if len (m1 .types ) != len (m2 .types ) {
185+ return false
186+ }
224187 for i , t := range m1 .types {
225188 if ! t .Equals (m2 .types [i ]) {
226189 return false
@@ -234,7 +197,10 @@ func getMark(err error) errorMark {
234197 if m , ok := err .(* withMark ); ok {
235198 return m .mark
236199 }
237- m := errorMark {msg : safeGetErrMsg (err ), types : []errorspb.ErrorTypeMark {errbase .GetTypeMark (err )}}
200+ m := errorMark {
201+ msg : safeGetErrMsg (err ),
202+ types : []errorspb.ErrorTypeMark {errbase .GetTypeMark (err )},
203+ }
238204 for c := errbase .UnwrapOnce (err ); c != nil ; c = errbase .UnwrapOnce (c ) {
239205 m .types = append (m .types , errbase .GetTypeMark (c ))
240206 }
0 commit comments