@@ -95,6 +95,9 @@ func (s *WsTransport) Restart() {
9595 s .cancel ()
9696 }
9797
98+ // Close any open connections in the tunnel channel.
99+ go s .cleanupConnections ()
100+
98101 time .Sleep (2 * time .Second )
99102
100103 ctx , cancel := context .WithCancel (s .parentctx )
@@ -111,6 +114,61 @@ func (s *WsTransport) Restart() {
111114 go s .TunnelListener ()
112115
113116}
117+
118+ // cleanupConnections closes all active connections in the tunnel channel.
119+ func (s * WsTransport ) cleanupConnections () {
120+ if s .controlChannel != nil {
121+ s .logger .Debug ("control channel have been closed." )
122+ s .controlChannel .Close ()
123+ }
124+ for {
125+ select {
126+ case conn := <- s .tunnelChannel :
127+ if conn .conn != nil {
128+ conn .conn .Close ()
129+ s .logger .Trace ("existing tunnel connections have been closed." )
130+ }
131+ default :
132+ return
133+ }
134+ }
135+ }
136+
137+ func (s * WsTransport ) getClosedSignal () {
138+ for {
139+ // Channel to receive the message or error
140+ resultChan := make (chan struct {
141+ message []byte
142+ err error
143+ })
144+ go func () {
145+ _ , message , err := s .controlChannel .ReadMessage ()
146+ resultChan <- struct {
147+ message []byte
148+ err error
149+ }{message , err }
150+ }()
151+
152+ select {
153+ case <- s .ctx .Done ():
154+ return
155+
156+ case result := <- resultChan :
157+ if result .err != nil {
158+ s .logger .Errorf ("failed to receive message from tunnel connection: %v" , result .err )
159+ go s .Restart ()
160+ return
161+ }
162+ if string (result .message ) == "closed" {
163+ s .logger .Info ("control channel has been closed by the client" )
164+ go s .Restart ()
165+ return
166+ }
167+ }
168+ }
169+
170+ }
171+
114172func (s * WsTransport ) portConfigReader () {
115173 // port mapping for listening on each local port
116174 for _ , portMapping := range s .config .Ports {
@@ -260,6 +318,7 @@ func (s *WsTransport) TunnelListener() {
260318 go s .heartbeat ()
261319 go s .poolChecker ()
262320 go s .portConfigReader ()
321+ go s .getClosedSignal ()
263322
264323 s .config .TunnelStatus = "Connected (Websocket)"
265324
0 commit comments