Skip to content

Commit 9f5dcb1

Browse files
authored
MUX: Prevent goroutine leak (#5110)
1 parent ce5c51d commit 9f5dcb1

File tree

7 files changed

+109
-33
lines changed

7 files changed

+109
-33
lines changed

app/reverse/bridge.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/xtls/xray-core/common/mux"
1010
"github.com/xtls/xray-core/common/net"
1111
"github.com/xtls/xray-core/common/session"
12+
"github.com/xtls/xray-core/common/signal"
1213
"github.com/xtls/xray-core/common/task"
1314
"github.com/xtls/xray-core/features/routing"
1415
"github.com/xtls/xray-core/transport"
@@ -53,6 +54,9 @@ func (b *Bridge) cleanup() {
5354
if w.IsActive() {
5455
activeWorkers = append(activeWorkers, w)
5556
}
57+
if w.Closed() {
58+
w.Timer.SetTimeout(0)
59+
}
5660
}
5761

5862
if len(activeWorkers) != len(b.workers) {
@@ -98,6 +102,7 @@ type BridgeWorker struct {
98102
Worker *mux.ServerWorker
99103
Dispatcher routing.Dispatcher
100104
State Control_State
105+
Timer *signal.ActivityTimer
101106
}
102107

103108
func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
@@ -125,6 +130,10 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo
125130
}
126131
w.Worker = worker
127132

133+
terminate := func() {
134+
worker.Close()
135+
}
136+
w.Timer = signal.CancelAfterInactivity(ctx, terminate, 60*time.Second)
128137
return w, nil
129138
}
130139

@@ -144,6 +153,10 @@ func (w *BridgeWorker) IsActive() bool {
144153
return w.State == Control_ACTIVE && !w.Worker.Closed()
145154
}
146155

156+
func (w *BridgeWorker) Closed() bool {
157+
return w.Worker.Closed()
158+
}
159+
147160
func (w *BridgeWorker) Connections() uint32 {
148161
return w.Worker.ActiveConnections()
149162
}
@@ -153,13 +166,20 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
153166
for {
154167
mb, err := reader.ReadMultiBuffer()
155168
if err != nil {
156-
break
169+
if w.Closed() {
170+
w.Timer.SetTimeout(0)
171+
} else {
172+
w.Timer.SetTimeout(24 * time.Hour)
173+
}
174+
return
157175
}
176+
w.Timer.Update()
158177
for _, b := range mb {
159178
var ctl Control
160179
if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil {
161180
errors.LogInfoInner(context.Background(), err, "failed to parse proto message")
162-
break
181+
w.Timer.SetTimeout(0)
182+
return
163183
}
164184
if ctl.State != w.State {
165185
w.State = ctl.State

app/reverse/portal.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/xtls/xray-core/common/net"
1313
"github.com/xtls/xray-core/common/serial"
1414
"github.com/xtls/xray-core/common/session"
15+
"github.com/xtls/xray-core/common/signal"
1516
"github.com/xtls/xray-core/common/task"
1617
"github.com/xtls/xray-core/features/outbound"
1718
"github.com/xtls/xray-core/transport"
@@ -159,6 +160,8 @@ func (p *StaticMuxPicker) cleanup() error {
159160
for _, w := range p.workers {
160161
if !w.Closed() {
161162
activeWorkers = append(activeWorkers, w)
163+
} else {
164+
w.timer.SetTimeout(0)
162165
}
163166
}
164167

@@ -225,6 +228,7 @@ type PortalWorker struct {
225228
reader buf.Reader
226229
draining bool
227230
counter uint32
231+
timer *signal.ActivityTimer
228232
}
229233

230234
func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
@@ -244,10 +248,14 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
244248
if !f {
245249
return nil, errors.New("unable to dispatch control connection")
246250
}
251+
terminate := func() {
252+
client.Close()
253+
}
247254
w := &PortalWorker{
248255
client: client,
249256
reader: downlinkReader,
250257
writer: uplinkWriter,
258+
timer: signal.CancelAfterInactivity(ctx, terminate, 24*time.Hour), // // prevent leak
251259
}
252260
w.control = &task.Periodic{
253261
Execute: w.heartbeat,
@@ -274,7 +282,6 @@ func (w *PortalWorker) heartbeat() error {
274282
msg.State = Control_DRAIN
275283

276284
defer func() {
277-
w.client.GetTimer().Reset(time.Second * 16)
278285
common.Close(w.writer)
279286
common.Interrupt(w.reader)
280287
w.writer = nil
@@ -286,6 +293,7 @@ func (w *PortalWorker) heartbeat() error {
286293
b, err := proto.Marshal(msg)
287294
common.Must(err)
288295
mb := buf.MergeBytes(nil, b)
296+
w.timer.Update()
289297
return w.writer.WriteMultiBuffer(mb)
290298
}
291299
return nil

common/mux/client.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,23 +219,24 @@ func (m *ClientWorker) WaitClosed() <-chan struct{} {
219219
return m.done.Wait()
220220
}
221221

222-
func (m *ClientWorker) GetTimer() *time.Ticker {
223-
return m.timer
222+
func (m *ClientWorker) Close() error {
223+
return m.done.Close()
224224
}
225225

226226
func (m *ClientWorker) monitor() {
227227
defer m.timer.Stop()
228228

229229
for {
230+
checkSize := m.sessionManager.Size()
231+
checkCount := m.sessionManager.Count()
230232
select {
231233
case <-m.done.Wait():
232234
m.sessionManager.Close()
233235
common.Interrupt(m.link.Writer)
234236
common.Interrupt(m.link.Reader)
235237
return
236238
case <-m.timer.C:
237-
size := m.sessionManager.Size()
238-
if size == 0 && m.sessionManager.CloseIfNoSession() {
239+
if m.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
239240
common.Must(m.done.Close())
240241
}
241242
}
@@ -255,7 +256,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
255256
return nil
256257
}
257258

258-
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
259+
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
259260
outbounds := session.OutboundsFromContext(ctx)
260261
ob := outbounds[len(outbounds)-1]
261262
transferType := protocol.TransferTypeStream
@@ -266,7 +267,6 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.
266267
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
267268
defer s.Close(false)
268269
defer writer.Close()
269-
defer timer.Reset(time.Second * 16)
270270

271271
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
272272
if err := writeFirstPayload(s.input, writer); err != nil {
@@ -316,10 +316,12 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
316316
}
317317
s.input = link.Reader
318318
s.output = link.Writer
319-
if _, ok := link.Reader.(*pipe.Reader); ok {
320-
go fetchInput(ctx, s, m.link.Writer, m.timer)
321-
} else {
322-
fetchInput(ctx, s, m.link.Writer, m.timer)
319+
go fetchInput(ctx, s, m.link.Writer)
320+
if _, ok := link.Reader.(*pipe.Reader); !ok {
321+
select {
322+
case <-ctx.Done():
323+
case <-s.done.Wait():
324+
}
323325
}
324326
return true
325327
}

common/mux/server.go

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mux
33
import (
44
"context"
55
"io"
6+
"time"
67

78
"github.com/xtls/xray-core/app/dispatcher"
89
"github.com/xtls/xray-core/common"
@@ -12,6 +13,7 @@ import (
1213
"github.com/xtls/xray-core/common/net"
1314
"github.com/xtls/xray-core/common/protocol"
1415
"github.com/xtls/xray-core/common/session"
16+
"github.com/xtls/xray-core/common/signal/done"
1517
"github.com/xtls/xray-core/core"
1618
"github.com/xtls/xray-core/features/routing"
1719
"github.com/xtls/xray-core/transport"
@@ -63,8 +65,15 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
6365
return s.dispatcher.DispatchLink(ctx, dest, link)
6466
}
6567
link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
66-
_, err := NewServerWorker(ctx, s.dispatcher, link)
67-
return err
68+
worker, err := NewServerWorker(ctx, s.dispatcher, link)
69+
if err != nil {
70+
return err
71+
}
72+
select {
73+
case <-ctx.Done():
74+
case <-worker.done.Wait():
75+
}
76+
return nil
6877
}
6978

7079
// Start implements common.Runnable.
@@ -81,22 +90,23 @@ type ServerWorker struct {
8190
dispatcher routing.Dispatcher
8291
link *transport.Link
8392
sessionManager *SessionManager
93+
done *done.Instance
94+
timer *time.Ticker
8495
}
8596

8697
func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
8798
worker := &ServerWorker{
8899
dispatcher: d,
89100
link: link,
90101
sessionManager: NewSessionManager(),
102+
done: done.New(),
103+
timer: time.NewTicker(60 * time.Second),
91104
}
92105
if inbound := session.InboundFromContext(ctx); inbound != nil {
93106
inbound.CanSpliceCopy = 3
94107
}
95-
if _, ok := link.Reader.(*pipe.Reader); ok {
96-
go worker.run(ctx)
97-
} else {
98-
worker.run(ctx)
99-
}
108+
go worker.run(ctx)
109+
go worker.monitor()
100110
return worker, nil
101111
}
102112

@@ -111,12 +121,40 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
111121
s.Close(false)
112122
}
113123

124+
func (w *ServerWorker) monitor() {
125+
defer w.timer.Stop()
126+
127+
for {
128+
checkSize := w.sessionManager.Size()
129+
checkCount := w.sessionManager.Count()
130+
select {
131+
case <-w.done.Wait():
132+
w.sessionManager.Close()
133+
common.Interrupt(w.link.Writer)
134+
common.Interrupt(w.link.Reader)
135+
return
136+
case <-w.timer.C:
137+
if w.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
138+
common.Must(w.done.Close())
139+
}
140+
}
141+
}
142+
}
143+
114144
func (w *ServerWorker) ActiveConnections() uint32 {
115145
return uint32(w.sessionManager.Size())
116146
}
117147

118148
func (w *ServerWorker) Closed() bool {
119-
return w.sessionManager.Closed()
149+
return w.done.Done()
150+
}
151+
152+
func (w *ServerWorker) WaitClosed() <-chan struct{} {
153+
return w.done.Wait()
154+
}
155+
156+
func (w *ServerWorker) Close() error {
157+
return w.done.Close()
120158
}
121159

122160
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
@@ -317,11 +355,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
317355
}
318356

319357
func (w *ServerWorker) run(ctx context.Context) {
320-
reader := &buf.BufferedReader{Reader: w.link.Reader}
358+
defer func() {
359+
common.Must(w.done.Close())
360+
}()
321361

322-
defer w.sessionManager.Close()
323-
defer common.Interrupt(w.link.Reader)
324-
defer common.Interrupt(w.link.Writer)
362+
reader := &buf.BufferedReader{Reader: w.link.Reader}
325363

326364
for {
327365
select {

common/mux/session.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/xtls/xray-core/common/errors"
1313
"github.com/xtls/xray-core/common/net"
1414
"github.com/xtls/xray-core/common/protocol"
15+
"github.com/xtls/xray-core/common/signal/done"
1516
"github.com/xtls/xray-core/transport/pipe"
1617
)
1718

@@ -53,7 +54,7 @@ func (m *SessionManager) Count() int {
5354
func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
5455
m.Lock()
5556
defer m.Unlock()
56-
57+
5758
MaxConcurrency := int(Strategy.MaxConcurrency)
5859
MaxConnection := uint16(Strategy.MaxConnection)
5960

@@ -65,6 +66,7 @@ func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
6566
s := &Session{
6667
ID: m.count,
6768
parent: m,
69+
done: done.New(),
6870
}
6971
m.sessions[s.ID] = s
7072
return s
@@ -115,19 +117,21 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) {
115117
return s, found
116118
}
117119

118-
func (m *SessionManager) CloseIfNoSession() bool {
120+
func (m *SessionManager) CloseIfNoSessionAndIdle(checkSize int, checkCount int) bool {
119121
m.Lock()
120122
defer m.Unlock()
121123

122124
if m.closed {
123125
return true
124126
}
125127

126-
if len(m.sessions) != 0 {
128+
if len(m.sessions) != 0 || checkSize != 0 || checkCount != int(m.count) {
127129
return false
128130
}
129131

130132
m.closed = true
133+
134+
m.sessions = nil
131135
return true
132136
}
133137

@@ -157,6 +161,7 @@ type Session struct {
157161
ID uint16
158162
transferType protocol.TransferType
159163
closed bool
164+
done *done.Instance
160165
XUDP *XUDP
161166
}
162167

@@ -171,6 +176,9 @@ func (s *Session) Close(locked bool) error {
171176
return nil
172177
}
173178
s.closed = true
179+
if s.done != nil {
180+
s.done.Close()
181+
}
174182
if s.XUDP == nil {
175183
common.Interrupt(s.input)
176184
common.Close(s.output)

common/mux/session_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ func TestSessionManagerClose(t *testing.T) {
4141
m := NewSessionManager()
4242
s := m.Allocate(&ClientStrategy{})
4343

44-
if m.CloseIfNoSession() {
44+
if m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
4545
t.Error("able to close")
4646
}
4747
m.Remove(false, s.ID)
48-
if !m.CloseIfNoSession() {
48+
if !m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
4949
t.Error("not able to close")
5050
}
5151
}

0 commit comments

Comments
 (0)