Skip to content

Commit 850f33d

Browse files
authored
optimize: remove ttstream connection write goroutine to avoid Sender OOM (#1917)
1 parent 910b399 commit 850f33d

9 files changed

Lines changed: 431 additions & 136 deletions

File tree

.typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ extend-exclude = ["go.mod", "go.sum"]
66
[default.extend-words]
77
typ = "typ" # type
88
Descritor = "Descritor" # reflect pkg typo, exported func, let it go
9+
consts = "consts" # conventional abbreviation for constants
910

1011
[default.extend-identifiers]
1112
GoAways = "GoAways" # GoAway frame plural noun

pkg/remote/trans/ttstream/frame.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,13 @@ func decodeException(buf []byte) (*gopkgthrift.ApplicationException, error) {
299299
}
300300
return ex, nil
301301
}
302+
303+
func encodeFrameAndFlush(ctx context.Context, writer bufiox.Writer, fr *Frame) (err error) {
304+
if err = EncodeFrame(ctx, writer, fr); err != nil {
305+
return err
306+
}
307+
if err = writer.Flush(); err != nil {
308+
return errTransport.newBuilder().withCause(err)
309+
}
310+
return nil
311+
}

pkg/remote/trans/ttstream/frame_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ func TestFrameCodec(t *testing.T) {
5050
test.DeepEqual(t, string(wframe.payload), string(rframe.payload))
5151
test.DeepEqual(t, wframe.header, rframe.header)
5252
}
53+
54+
for i := 0; i < 10; i++ {
55+
wframe.sid = int32(i)
56+
err = encodeFrameAndFlush(context.Background(), writer, wframe)
57+
test.Assert(t, err == nil, err)
58+
}
59+
err = writer.Flush()
60+
test.Assert(t, err == nil, err)
61+
62+
for i := 0; i < 10; i++ {
63+
rframe, err := DecodeFrame(context.Background(), reader)
64+
test.Assert(t, err == nil, err)
65+
test.DeepEqual(t, string(wframe.payload), string(rframe.payload))
66+
test.DeepEqual(t, wframe.header, rframe.header)
67+
}
5368
}
5469

5570
func TestFrameWithoutPayloadCodec(t *testing.T) {

pkg/remote/trans/ttstream/server_handler.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ func (f *svrTransHandlerFactory) NewTransHandler(opts *remote.ServerOption) (rem
8282
var (
8383
_ remote.ServerTransHandler = &svrTransHandler{}
8484
errProtocolNotMatch = errors.New("protocol not match")
85+
errNilTransport = errors.New("server transport is nil")
8586
)
8687

8788
type svrTransHandler struct {
@@ -127,24 +128,20 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.
127128
// OnRead controls the connection-level lifecycle.
128129
// only when OnRead returns can netpoll close the connection buffer
129130
func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) {
131+
trans, _ := ctx.Value(serverTransCtxKey{}).(*serverTransport)
132+
if trans == nil {
133+
return errNilTransport
134+
}
130135
var wg sync.WaitGroup
131136
defer func() {
132137
wg.Wait()
133-
trans, _ := ctx.Value(serverTransCtxKey{}).(*serverTransport)
134-
if trans != nil {
135-
trans.WaitClosed()
136-
}
138+
trans.WaitClosed()
137139
if errors.Is(err, io.EOF) {
138140
err = nil
139141
}
140142
}()
141143
// connection level goroutine
142144
for {
143-
trans, _ := ctx.Value(serverTransCtxKey{}).(*serverTransport)
144-
if trans == nil {
145-
err = fmt.Errorf("server transport is nil")
146-
return
147-
}
148145
var st *serverStream
149146
// ReadStream blocks until a stream arrives or the conn returns an error
150147
st, err = trans.ReadStream(ctx)

pkg/remote/trans/ttstream/server_handler_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ func (m *mockTracer) Start(ctx context.Context) context.Context {
101101
}
102102

103103
func (m *mockTracer) Finish(ctx context.Context) {
104-
m.finishFunc(ctx)
104+
if m.finishFunc != nil {
105+
m.finishFunc(ctx)
106+
}
105107
}
106108

107109
type mockHeaderFrameReadHandler struct {

pkg/remote/trans/ttstream/transport.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ const (
3131

3232
streamCacheSize = 32
3333
frameCacheSize = 256
34+
35+
connStateOpen = 0
36+
connStateClosed = 1
3437
)
3538

3639
func isIgnoreError(err error) bool {

pkg/remote/trans/ttstream/transport_client.go

Lines changed: 61 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package ttstream
1818

1919
import (
2020
"context"
21-
"io"
2221
"net"
2322
"sync"
2423
"sync/atomic"
@@ -29,7 +28,6 @@ import (
2928

3029
"github.com/cloudwego/kitex/pkg/gofunc"
3130
"github.com/cloudwego/kitex/pkg/klog"
32-
"github.com/cloudwego/kitex/pkg/remote/trans/ttstream/container"
3331
"github.com/cloudwego/kitex/pkg/streaming"
3432
"github.com/cloudwego/kitex/pkg/utils"
3533
)
@@ -43,13 +41,15 @@ import (
4341
var ticker = utils.NewSyncSharedTicker(5 * time.Second)
4442

4543
type clientTransport struct {
46-
conn netpoll.Connection
47-
pool transPool
48-
streams sync.Map // key=streamID val=clientStream
49-
scache []*clientStream // size is streamCacheSize
50-
spipe *container.Pipe[*clientStream] // in-coming clientStream pipe
51-
fpipe *container.Pipe[*Frame] // out-coming frame pipe
52-
closedFlag int32
44+
conn netpoll.Connection
45+
pool transPool
46+
streams sync.Map // key=streamID val=clientStream
47+
48+
mu sync.Mutex // protect state, closedErr and writer
49+
state int32
50+
closedErr error
51+
writer *writerBuffer
52+
5353
closedTrigger chan struct{}
5454
}
5555

@@ -60,10 +60,8 @@ func newClientTransport(conn netpoll.Connection, pool transPool) *clientTranspor
6060
conn: conn,
6161
pool: pool,
6262
streams: sync.Map{},
63-
spipe: container.NewPipe[*clientStream](),
64-
scache: make([]*clientStream, 0, streamCacheSize),
65-
fpipe: container.NewPipe[*Frame](),
66-
closedTrigger: make(chan struct{}, 2),
63+
writer: newWriterBuffer(conn.Writer()),
64+
closedTrigger: make(chan struct{}, 1),
6765
}
6866
addr := ""
6967
if t.Addr() != nil {
@@ -84,19 +82,6 @@ func newClientTransport(conn netpoll.Connection, pool transPool) *clientTranspor
8482
}()
8583
err = t.loopRead()
8684
}, gofunc.NewBasicInfo("", addr))
87-
gofunc.RecoverGoFuncWithInfo(context.Background(), func() {
88-
var err error
89-
defer func() {
90-
if err != nil {
91-
if !isIgnoreError(err) {
92-
klog.Warnf("clientTransport[%s] loop write err: %v", t.Addr(), err)
93-
}
94-
_ = t.Close(err)
95-
}
96-
t.closedTrigger <- struct{}{}
97-
}()
98-
err = t.loopWrite()
99-
}, gofunc.NewBasicInfo("", addr))
10085

10186
// add to stream cleanup ticker
10287
ticker.Add(t)
@@ -111,35 +96,55 @@ func (t *clientTransport) Addr() net.Addr {
11196
// Close closes the transport and destroys all resources and goroutines when transPool discards the transport.
11297
// when an exception is encountered and the transport needs to be closed,
11398
// the exception is not nil and the currently surviving streams are aware of this exception.
114-
func (t *clientTransport) Close(exception error) (err error) {
115-
if !atomic.CompareAndSwapInt32(&t.closedFlag, 0, 1) {
116-
return nil
99+
func (t *clientTransport) Close(exception error) error {
100+
t.mu.Lock()
101+
if t.state == connStateClosed {
102+
closedErr := t.closedErr
103+
t.mu.Unlock()
104+
return closedErr
117105
}
106+
t.setClosedStateLocked(exception)
107+
t.mu.Unlock()
108+
109+
t.releaseResources(exception)
110+
111+
return exception
112+
}
113+
114+
// setClosedStateLocked sets the closed state and closed reason.
115+
// Must be called with t.mu held.
116+
func (t *clientTransport) setClosedStateLocked(err error) {
117+
t.state = connStateClosed
118+
t.closedErr = err
119+
}
120+
121+
func (t *clientTransport) releaseResources(err error) {
118122
klog.Debugf("client transport[%s] is closing", t.Addr())
119123
// close streams first
120124
t.streams.Range(func(key, value any) bool {
121125
s := value.(*clientStream)
122-
s.close(exception, false, "", nil)
126+
s.close(err, false, "", nil)
123127
return true
124128
})
125-
// then close stream and frame pipes
126-
t.spipe.Close()
127-
t.fpipe.Close()
129+
130+
if cErr := t.conn.Close(); cErr != nil {
131+
klog.Infof("KITEX: ttstream clientTransport Close Connection failed, err: %v", cErr)
132+
}
128133

129134
// remove cleanup stream task from ticker to avoid goroutine leak
130135
ticker.Delete(t)
131-
132-
return err
133136
}
134137

135138
// WaitClosed waits for the recv loop to close
136139
func (t *clientTransport) WaitClosed() {
137140
<-t.closedTrigger
138-
<-t.closedTrigger
139141
}
140142

141143
func (t *clientTransport) IsActive() bool {
142-
return atomic.LoadInt32(&t.closedFlag) == 0 && t.conn.IsActive()
144+
t.mu.Lock()
145+
isClosed := t.state == connStateClosed
146+
t.mu.Unlock()
147+
return !isClosed && t.conn.IsActive()
143148
}
144149

145150
func (t *clientTransport) storeStream(s *clientStream) {
@@ -209,39 +214,29 @@ func (t *clientTransport) loopRead() error {
209214
}
210215
}
211216

212-
func (t *clientTransport) loopWrite() error {
217+
// WriteFrame is safe for concurrent use: t.mu serializes encoding and flushing on the shared writer
218+
func (t *clientTransport) WriteFrame(fr *Frame) (err error) {
219+
var needRelease bool
220+
t.mu.Lock()
213221
defer func() {
214-
// loop write should help to close connection
215-
_ = t.conn.Close()
216-
}()
217-
writer := newWriterBuffer(t.conn.Writer())
218-
fcache := make([]*Frame, frameCacheSize)
219-
// Important note:
220-
// loopWrite may cannot find stream by sid since it may send trailer and delete sid from streams
221-
for {
222-
n, err := t.fpipe.Read(context.Background(), fcache)
223-
if err != nil {
224-
return err
225-
}
226-
if n == 0 {
227-
return io.EOF
228-
}
229-
for i := 0; i < n; i++ {
230-
fr := fcache[i]
231-
if err = EncodeFrame(context.Background(), writer, fr); err != nil {
232-
return err
233-
}
234-
recycleFrame(fr)
235-
}
236-
if err = writer.Flush(); err != nil {
237-
return errTransport.newBuilder().withCause(err)
222+
t.mu.Unlock()
223+
if needRelease {
224+
t.releaseResources(err)
238225
}
226+
}()
227+
if t.state == connStateClosed {
228+
err = t.closedErr
229+
return err
239230
}
240-
}
241231

242-
// WriteFrame is concurrent safe
243-
func (t *clientTransport) WriteFrame(fr *Frame) (err error) {
244-
return t.fpipe.Write(context.Background(), fr)
232+
if err = encodeFrameAndFlush(context.Background(), t.writer, fr); err != nil {
233+
t.setClosedStateLocked(err)
234+
needRelease = true
235+
return err
236+
}
237+
recycleFrame(fr)
238+
239+
return nil
245240
}
246241

247242
func (t *clientTransport) CloseStream(sid int32) (err error) {

0 commit comments

Comments
 (0)