Skip to content

Commit 4045b31

Browse files
committed
[fix] ws and tcp channel stability
1 parent 2027384 commit 4045b31

File tree

4 files changed

+77
-25
lines changed

4 files changed

+77
-25
lines changed

internal/client/transport/tcp.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ connectLoop:
102102
for {
103103
select {
104104
case <-c.ctx.Done():
105-
go c.closeControlChannel("context cancellation")
106105
return
107106
default:
108107
tunnelTCPConn, err := c.tcpDialer(c.config.RemoteAddr, c.config.Nodelay)

internal/client/transport/ws.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ func (c *WsTransport) Restart() {
7878
c.cancel()
7979
}
8080

81+
go c.closeControlChannel("restarting client")
82+
8183
time.Sleep(2 * time.Second)
8284

8385
ctx, cancel := context.WithCancel(c.parentctx)
@@ -93,6 +95,14 @@ func (c *WsTransport) Restart() {
9395

9496
}
9597

98+
func (c *WsTransport) closeControlChannel(reason string) {
99+
if c.controlChannel != nil {
100+
_ = c.controlChannel.WriteMessage(websocket.TextMessage, []byte("closed"))
101+
c.controlChannel.Close()
102+
c.logger.Debugf("control channel closed due to %s", reason)
103+
}
104+
}
105+
96106
func (c *WsTransport) ChannelDialer() {
97107
// for webui
98108
if c.config.WebPort > 0 {
@@ -101,6 +111,7 @@ func (c *WsTransport) ChannelDialer() {
101111

102112
c.config.TunnelStatus = "Disconnected (Websocket)"
103113

114+
connectLoop:
104115
for {
105116
select {
106117
case <-c.ctx.Done():
@@ -121,9 +132,12 @@ func (c *WsTransport) ChannelDialer() {
121132

122133
go c.channelListener()
123134

124-
return
135+
break connectLoop
125136
}
126137
}
138+
139+
<-c.ctx.Done()
140+
go c.closeControlChannel("context cancellation")
127141
}
128142

129143
func (c *WsTransport) channelListener() {

internal/server/transport/tcp.go

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func NewTCPServer(parentCtx context.Context, config *TcpConfig, logger *logrus.L
6161
tunnelChannel: make(chan net.Conn, config.ChannelSize),
6262
getNewConnChan: make(chan struct{}, config.ChannelSize),
6363
controlChannel: nil, // will be set when a control connection is established
64-
timeout: 3 * time.Second, // Default timeout for waiting for a tunnel connection
64+
timeout: 5 * time.Second, // Default timeout for waiting for a tunnel connection
6565
heartbeatDuration: time.Duration(config.Heartbeat) * time.Second, // Heartbeat duration
6666
heartbeatSig: "0", // Default heartbeat signal
6767
chanSignal: "1", // Default channel signal
@@ -498,32 +498,12 @@ func (s *TcpTransport) handleTCPSession(remotePort int, acceptChan chan net.Conn
498498
return
499499

500500
case <-s.ctx.Done():
501-
for {
502-
select {
503-
case conn := <-acceptChan:
504-
if conn != nil {
505-
conn.Close()
506-
s.logger.Trace("existing local connections have been closed.")
507-
}
508-
default:
509-
return
510-
}
511-
}
501+
return
512502

513503
}
514504
}
515505
case <-s.ctx.Done():
516-
for {
517-
select {
518-
case conn := <-acceptChan:
519-
if conn != nil {
520-
conn.Close()
521-
s.logger.Trace("existing local connections have been closed.")
522-
}
523-
default:
524-
return
525-
}
526-
}
506+
return
527507

528508
}
529509

internal/server/transport/ws.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ func (s *WsTransport) Restart() {
9595
s.cancel()
9696
}
9797

98+
// Close any open connections in the tunnel channel.
99+
go s.cleanupConnections()
100+
98101
time.Sleep(2 * time.Second)
99102

100103
ctx, cancel := context.WithCancel(s.parentctx)
@@ -111,6 +114,61 @@ func (s *WsTransport) Restart() {
111114
go s.TunnelListener()
112115

113116
}
117+
118+
// cleanupConnections closes all active connections in the tunnel channel.
119+
func (s *WsTransport) cleanupConnections() {
120+
if s.controlChannel != nil {
121+
s.logger.Debug("control channel have been closed.")
122+
s.controlChannel.Close()
123+
}
124+
for {
125+
select {
126+
case conn := <-s.tunnelChannel:
127+
if conn.conn != nil {
128+
conn.conn.Close()
129+
s.logger.Trace("existing tunnel connections have been closed.")
130+
}
131+
default:
132+
return
133+
}
134+
}
135+
}
136+
137+
func (s *WsTransport) getClosedSignal() {
138+
for {
139+
// Channel to receive the message or error
140+
resultChan := make(chan struct {
141+
message []byte
142+
err error
143+
})
144+
go func() {
145+
_, message, err := s.controlChannel.ReadMessage()
146+
resultChan <- struct {
147+
message []byte
148+
err error
149+
}{message, err}
150+
}()
151+
152+
select {
153+
case <-s.ctx.Done():
154+
return
155+
156+
case result := <-resultChan:
157+
if result.err != nil {
158+
s.logger.Errorf("failed to receive message from tunnel connection: %v", result.err)
159+
go s.Restart()
160+
return
161+
}
162+
if string(result.message) == "closed" {
163+
s.logger.Info("control channel has been closed by the client")
164+
go s.Restart()
165+
return
166+
}
167+
}
168+
}
169+
170+
}
171+
114172
func (s *WsTransport) portConfigReader() {
115173
// port mapping for listening on each local port
116174
for _, portMapping := range s.config.Ports {
@@ -260,6 +318,7 @@ func (s *WsTransport) TunnelListener() {
260318
go s.heartbeat()
261319
go s.poolChecker()
262320
go s.portConfigReader()
321+
go s.getClosedSignal()
263322

264323
s.config.TunnelStatus = "Connected (Websocket)"
265324

0 commit comments

Comments
 (0)