diff --git a/client/middlewares.go b/client/middlewares.go index 4f94a549c0..6bb2317794 100644 --- a/client/middlewares.go +++ b/client/middlewares.go @@ -161,8 +161,17 @@ func newIOErrorHandleMW(errHandle func(context.Context, error) error) endpoint.M // DefaultClientErrorHandler is Default ErrorHandler for client // when no ErrorHandler is specified with Option `client.WithErrorHandler`, this ErrorHandler will be injected. -// for thrift、KitexProtobuf, >= v0.4.0 wrap protocol error to TransError, which will be more friendly. +// For thrift、KitexProtobuf >= v0.4.0, wraps protocol error to TransError, which will be more friendly. +// For thrift、KitexProtobuf >= v0.8.1, returns BizStatusError directly if it is set. func DefaultClientErrorHandler(ctx context.Context, err error) error { + rpcInfo := rpcinfo.GetRPCInfo(ctx) + // If BizStatusErr is not nil, it means that the business logic has been processed and the error has been set + // and transmitted to the client. In this case, just return the bizErr directly. + bizErr := rpcInfo.Invocation().BizStatusErr() + if bizErr != nil { + return bizErr + } + switch err.(type) { // for thrift、KitexProtobuf, actually check *remote.TransError is enough case *remote.TransError, thrift.TApplicationException, protobuf.PBError: diff --git a/client/middlewares_test.go b/client/middlewares_test.go index a328137853..df4652f90a 100644 --- a/client/middlewares_test.go +++ b/client/middlewares_test.go @@ -140,7 +140,7 @@ func TestDefaultErrorHandler(t *testing.T) { reqCtx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) // Test TApplicationException - err := DefaultClientErrorHandler(context.Background(), thrift.NewTApplicationException(100, "mock")) + err := DefaultClientErrorHandler(reqCtx, thrift.NewTApplicationException(100, "mock")) test.Assert(t, err.Error() == "remote or network error[remote]: mock", err.Error()) var te thrift.TApplicationException ok := errors.As(err, &te) @@ -154,7 +154,7 @@ func TestDefaultErrorHandler(t *testing.T) { test.Assert(t, te.TypeId() == 100) // Test PbError - err = DefaultClientErrorHandler(context.Background(), protobuf.NewPbError(100, "mock")) + err = DefaultClientErrorHandler(reqCtx, protobuf.NewPbError(100, "mock")) test.Assert(t, err.Error() == "remote or network error[remote]: mock") var pe protobuf.PBError ok = errors.As(err, &pe) @@ -168,18 +168,25 @@ func TestDefaultErrorHandler(t *testing.T) { test.Assert(t, te.TypeId() == 100) // Test status.Error - err = DefaultClientErrorHandler(context.Background(), status.Err(100, "mock")) + err = DefaultClientErrorHandler(reqCtx, status.Err(100, "mock")) test.Assert(t, err.Error() == "remote or network error: rpc error: code = 100 desc = mock", err.Error()) // Test status.Error with remote addr err = ClientErrorHandlerWithAddr(reqCtx, status.Err(100, "mock")) test.Assert(t, err.Error() == "remote or network error["+tcpAddrStr+"]: rpc error: code = 100 desc = mock", err.Error()) // Test other error - err = DefaultClientErrorHandler(context.Background(), errors.New("mock")) + err = DefaultClientErrorHandler(reqCtx, errors.New("mock")) test.Assert(t, err.Error() == "remote or network error: mock") // Test other error with remote addr err = ClientErrorHandlerWithAddr(reqCtx, errors.New("mock")) test.Assert(t, err.Error() == "remote or network error["+tcpAddrStr+"]: mock") + + // Test BizStatusError set + ri.Invocation().(rpcinfo.InvocationSetter).SetBizStatusErr(kerrors.NewBizStatusError(1024, "biz")) + err = DefaultClientErrorHandler(reqCtx, errors.New("mock")) + err, ok = kerrors.FromBizStatusError(err) + test.Assert(t, ok, "should return BizStatusError here") + test.Assert(t, err.Error() == "biz error: code=1024, msg=biz", err.Error()) } func TestNewProxyMW(t *testing.T) {