Skip to content

Commit 8124ee3

Browse files
authored
optimize(ttstream): log the error thrown by invoking handler (#1780)
1 parent c5d41fa commit 8124ee3

File tree

3 files changed

+221
-11
lines changed

3 files changed

+221
-11
lines changed

internal/mocks/remote/servicesearcher.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ func NewDefaultSvcSearcher() *MockSvcSearcher {
4040
mocks.MockExceptionMethod: svcInfo,
4141
mocks.MockErrorMethod: svcInfo,
4242
mocks.MockOnewayMethod: svcInfo,
43+
mocks.MockStreamingMethod: svcInfo,
4344
}
4445
return &MockSvcSearcher{svcMap: s, methodSvcMap: m}
4546
}

pkg/remote/trans/ttstream/server_handler.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727

2828
"github.com/cloudwego/kitex/pkg/endpoint"
2929
"github.com/cloudwego/kitex/pkg/gofunc"
30+
"github.com/cloudwego/kitex/pkg/kerrors"
3031
"github.com/cloudwego/kitex/pkg/klog"
3132
"github.com/cloudwego/kitex/pkg/remote"
3233
"github.com/cloudwego/kitex/pkg/rpcinfo"
@@ -131,7 +132,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
131132
defer wg.Done()
132133
err := t.OnStream(nctx, conn, ss)
133134
if err != nil && !errors.Is(err, io.EOF) {
134-
klog.CtxErrorf(nctx, "KITEX: stream ReadStream failed: err=%v", err)
135+
t.OnError(nctx, err, conn)
135136
}
136137
})
137138
}
@@ -175,9 +176,9 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream
175176
panicErr := recover()
176177
if panicErr != nil {
177178
if conn != nil {
178-
klog.CtxErrorf(ctx, "KITEX: streamx panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack()))
179+
klog.CtxErrorf(ctx, "KITEX: ttstream panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack()))
179180
} else {
180-
klog.CtxErrorf(ctx, "KITEX: streamx panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack()))
181+
klog.CtxErrorf(ctx, "KITEX: ttstream panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack()))
181182
}
182183
}
183184
t.finishTracer(ctx, ri, err, panicErr)
@@ -186,16 +187,18 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, ss stream
186187
args := &streaming.Args{
187188
ServerStream: ss,
188189
}
189-
serr := t.inkHdlFunc(ctx, args, nil)
190-
if serr == nil {
191-
if bizErr := ri.Invocation().BizStatusErr(); bizErr != nil {
192-
serr = bizErr
193-
}
190+
if err = t.inkHdlFunc(ctx, args, nil); err != nil {
191+
// treat err thrown by invoking handler as the final err, ignore the err returned by OnStreamFinish
192+
t.provider.OnStreamFinish(ctx, ss, err)
193+
return
194194
}
195-
ctx, err = t.provider.OnStreamFinish(ctx, ss, serr)
196-
if err == nil && serr != nil {
197-
err = serr
195+
if bizErr := ri.Invocation().BizStatusErr(); bizErr != nil {
196+
// when biz err thrown, treat the err returned by OnStreamFinish as the final err
197+
ctx, err = t.provider.OnStreamFinish(ctx, ss, bizErr)
198+
return
198199
}
200+
// there is no invoking handler err or biz err, treat the err returned by OnStreamFinish as the final err
201+
ctx, err = t.provider.OnStreamFinish(ctx, ss, nil)
199202
return err
200203
}
201204

@@ -211,6 +214,12 @@ func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
211214
}
212215

213216
func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
217+
var de *kerrors.DetailedError
218+
if ok := errors.As(err, &de); ok && de.Stack() != "" {
219+
klog.CtxErrorf(ctx, "KITEX: processing ttstream request error, remoteAddr=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), de.Stack())
220+
} else {
221+
klog.CtxErrorf(ctx, "KITEX: processing ttstream request error, remoteAddr=%s, error=%s", conn.RemoteAddr(), err.Error())
222+
}
214223
}
215224

216225
func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
//go:build !windows
2+
3+
/*
4+
* Copyright 2025 CloudWeGo Authors
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package ttstream
20+
21+
import (
22+
"context"
23+
"errors"
24+
"fmt"
25+
"net"
26+
"runtime/debug"
27+
"testing"
28+
"time"
29+
30+
"github.com/cloudwego/netpoll"
31+
32+
"github.com/cloudwego/kitex/internal/mocks"
33+
mock_remote "github.com/cloudwego/kitex/internal/mocks/remote"
34+
"github.com/cloudwego/kitex/internal/test"
35+
"github.com/cloudwego/kitex/pkg/kerrors"
36+
"github.com/cloudwego/kitex/pkg/remote"
37+
"github.com/cloudwego/kitex/pkg/rpcinfo"
38+
)
39+
40+
type mockNetpollConn struct {
41+
mocks.Conn
42+
reader netpoll.Reader
43+
writer netpoll.Writer
44+
}
45+
46+
func (m *mockNetpollConn) Reader() netpoll.Reader {
47+
return m.reader
48+
}
49+
50+
func (m *mockNetpollConn) Writer() netpoll.Writer {
51+
return m.writer
52+
}
53+
54+
func (m *mockNetpollConn) IsActive() bool {
55+
panic("implement me")
56+
}
57+
58+
func (m *mockNetpollConn) SetReadTimeout(timeout time.Duration) error {
59+
return nil
60+
}
61+
62+
func (m *mockNetpollConn) SetWriteTimeout(timeout time.Duration) error {
63+
return nil
64+
}
65+
66+
func (m *mockNetpollConn) SetIdleTimeout(timeout time.Duration) error {
67+
return nil
68+
}
69+
70+
func (m *mockNetpollConn) SetOnRequest(on netpoll.OnRequest) error {
71+
return nil
72+
}
73+
74+
func (m *mockNetpollConn) AddCloseCallback(callback netpoll.CloseCallback) error {
75+
return nil
76+
}
77+
78+
func (m *mockNetpollConn) WriteFrame(hdr, data []byte) (n int, err error) {
79+
return
80+
}
81+
82+
func (m *mockNetpollConn) ReadFrame() (hdr, data []byte, err error) {
83+
return
84+
}
85+
86+
func (m *mockNetpollConn) SetOnDisconnect(onDisconnect netpoll.OnDisconnect) error {
87+
return nil
88+
}
89+
90+
func TestOnStream(t *testing.T) {
91+
factory := NewSvrTransHandlerFactory()
92+
rawTransHdl, err := factory.NewTransHandler(&remote.ServerOption{
93+
SvcSearcher: mock_remote.NewDefaultSvcSearcher(),
94+
InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
95+
return rpcinfo.NewRPCInfo(nil, nil,
96+
rpcinfo.NewInvocation(mocks.MockServiceName, mocks.MockStreamingMethod),
97+
rpcinfo.NewRPCConfig(),
98+
rpcinfo.NewRPCStats())
99+
},
100+
TracerCtl: &rpcinfo.TraceController{},
101+
})
102+
test.Assert(t, err == nil, err)
103+
transHdl := rawTransHdl.(*svrTransHandler)
104+
105+
rfd, wfd := netpoll.GetSysFdPairs()
106+
rconn, err := netpoll.NewFDConnection(rfd)
107+
test.Assert(t, err == nil, err)
108+
wconn, err := netpoll.NewFDConnection(wfd)
109+
test.Assert(t, err == nil, err)
110+
wbuf := newWriterBuffer(wconn.Writer())
111+
112+
mockConn := &mockNetpollConn{
113+
Conn: mocks.Conn{},
114+
reader: rconn.Reader(),
115+
writer: rconn.Writer(),
116+
}
117+
118+
ctx, aerr := transHdl.OnActive(context.Background(), mockConn)
119+
test.Assert(t, aerr == nil, aerr)
120+
defer func() {
121+
wconn.Close()
122+
_, aerr = transHdl.provider.OnInactive(ctx, mockConn)
123+
test.Assert(t, aerr == nil, aerr)
124+
}()
125+
126+
t.Run("invoking handler successfully", func(t *testing.T) {
127+
transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
128+
return nil
129+
})
130+
err = EncodeFrame(context.Background(), wbuf, &Frame{
131+
streamFrame: streamFrame{
132+
sid: 1,
133+
method: mocks.MockStreamingMethod,
134+
},
135+
typ: headerFrameType,
136+
})
137+
test.Assert(t, err == nil, err)
138+
err = wbuf.Flush()
139+
test.Assert(t, err == nil, err)
140+
nctx, ss, err := transHdl.provider.OnStream(ctx, mockConn)
141+
test.Assert(t, err == nil, err)
142+
err = transHdl.OnStream(nctx, mockConn, ss)
143+
test.Assert(t, err == nil, err)
144+
})
145+
146+
t.Run("invoking handler panic", func(t *testing.T) {
147+
transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
148+
defer func() {
149+
if handlerErr := recover(); handlerErr != nil {
150+
err = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[panic] %s", handlerErr), string(debug.Stack()))
151+
}
152+
}()
153+
panic("test")
154+
})
155+
err = EncodeFrame(context.Background(), wbuf, &Frame{
156+
streamFrame: streamFrame{
157+
sid: 1,
158+
method: mocks.MockStreamingMethod,
159+
},
160+
typ: headerFrameType,
161+
})
162+
test.Assert(t, err == nil, err)
163+
err = wbuf.Flush()
164+
test.Assert(t, err == nil, err)
165+
nctx, ss, err := transHdl.provider.OnStream(ctx, mockConn)
166+
test.Assert(t, err == nil, err)
167+
err = transHdl.OnStream(nctx, mockConn, ss)
168+
test.Assert(t, errors.Is(err, kerrors.ErrPanic), err)
169+
transHdl.OnError(ctx, err, mockConn)
170+
})
171+
172+
t.Run("invoking handler throws biz error", func(t *testing.T) {
173+
transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
174+
ri := rpcinfo.GetRPCInfo(ctx)
175+
defer func() {
176+
if bizErr, ok := kerrors.FromBizStatusError(err); ok {
177+
if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok {
178+
setter.SetBizStatusErr(bizErr)
179+
err = nil
180+
}
181+
}
182+
}()
183+
return kerrors.NewBizStatusError(10000, "biz-error test")
184+
})
185+
err = EncodeFrame(context.Background(), wbuf, &Frame{
186+
streamFrame: streamFrame{
187+
sid: 1,
188+
method: mocks.MockStreamingMethod,
189+
},
190+
typ: headerFrameType,
191+
})
192+
test.Assert(t, err == nil, err)
193+
err = wbuf.Flush()
194+
test.Assert(t, err == nil, err)
195+
nctx, ss, err := transHdl.provider.OnStream(ctx, mockConn)
196+
test.Assert(t, err == nil, err)
197+
err = transHdl.OnStream(nctx, mockConn, ss)
198+
test.Assert(t, err == nil, err)
199+
})
200+
}

0 commit comments

Comments
 (0)