Skip to content

Commit b088062

Browse files
committed
feat(client): prioritize BizStatusError in DefaultClientErrorHandler
1 parent ff85e39 commit b088062

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

client/middlewares.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,17 @@ func newIOErrorHandleMW(errHandle func(context.Context, error) error) endpoint.M
161161

162162
// DefaultClientErrorHandler is Default ErrorHandler for client
163163
// when no ErrorHandler is specified with Option `client.WithErrorHandler`, this ErrorHandler will be injected.
164-
// for thrift、KitexProtobuf, >= v0.4.0 wrap protocol error to TransError, which will be more friendly.
164+
// For thrift、KitexProtobuf >= v0.4.0, wraps protocol error to TransError, which will be more friendly.
165+
// For thrift、KitexProtobuf >= v0.8.1, returns BizStatusError directly if it is set.
165166
func DefaultClientErrorHandler(ctx context.Context, err error) error {
167+
rpcInfo := rpcinfo.GetRPCInfo(ctx)
168+
// If BizStatusErr is not nil, it means that the business logic has been processed and the error has been set
169+
// and transmitted to the client. In this case, just return the bizErr directly.
170+
bizErr := rpcInfo.Invocation().BizStatusErr()
171+
if bizErr != nil {
172+
return bizErr
173+
}
174+
166175
switch err.(type) {
167176
// for thrift、KitexProtobuf, actually check *remote.TransError is enough
168177
case *remote.TransError, thrift.TApplicationException, protobuf.PBError:

client/middlewares_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ func TestDefaultErrorHandler(t *testing.T) {
140140
reqCtx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
141141

142142
// Test TApplicationException
143-
err := DefaultClientErrorHandler(context.Background(), thrift.NewTApplicationException(100, "mock"))
143+
err := DefaultClientErrorHandler(reqCtx, thrift.NewTApplicationException(100, "mock"))
144144
test.Assert(t, err.Error() == "remote or network error[remote]: mock", err.Error())
145145
var te thrift.TApplicationException
146146
ok := errors.As(err, &te)
@@ -154,7 +154,7 @@ func TestDefaultErrorHandler(t *testing.T) {
154154
test.Assert(t, te.TypeId() == 100)
155155

156156
// Test PbError
157-
err = DefaultClientErrorHandler(context.Background(), protobuf.NewPbError(100, "mock"))
157+
err = DefaultClientErrorHandler(reqCtx, protobuf.NewPbError(100, "mock"))
158158
test.Assert(t, err.Error() == "remote or network error[remote]: mock")
159159
var pe protobuf.PBError
160160
ok = errors.As(err, &pe)
@@ -168,18 +168,25 @@ func TestDefaultErrorHandler(t *testing.T) {
168168
test.Assert(t, te.TypeId() == 100)
169169

170170
// Test status.Error
171-
err = DefaultClientErrorHandler(context.Background(), status.Err(100, "mock"))
171+
err = DefaultClientErrorHandler(reqCtx, status.Err(100, "mock"))
172172
test.Assert(t, err.Error() == "remote or network error: rpc error: code = 100 desc = mock", err.Error())
173173
// Test status.Error with remote addr
174174
err = ClientErrorHandlerWithAddr(reqCtx, status.Err(100, "mock"))
175175
test.Assert(t, err.Error() == "remote or network error["+tcpAddrStr+"]: rpc error: code = 100 desc = mock", err.Error())
176176

177177
// Test other error
178-
err = DefaultClientErrorHandler(context.Background(), errors.New("mock"))
178+
err = DefaultClientErrorHandler(reqCtx, errors.New("mock"))
179179
test.Assert(t, err.Error() == "remote or network error: mock")
180180
// Test other error with remote addr
181181
err = ClientErrorHandlerWithAddr(reqCtx, errors.New("mock"))
182182
test.Assert(t, err.Error() == "remote or network error["+tcpAddrStr+"]: mock")
183+
184+
// Test BizStatusError set
185+
ri.Invocation().(rpcinfo.InvocationSetter).SetBizStatusErr(kerrors.NewBizStatusError(1024, "biz"))
186+
err = DefaultClientErrorHandler(reqCtx, errors.New("mock"))
187+
err, ok = kerrors.FromBizStatusError(err)
188+
test.Assert(t, ok, "should return BizStatusError here")
189+
test.Assert(t, err.Error() == "biz error: code=1024, msg=biz", err.Error())
183190
}
184191

185192
func TestNewProxyMW(t *testing.T) {

0 commit comments

Comments
 (0)