Skip to content

Commit 1b7e2bb

Browse files
committed
[feat] new transport: wsmux, wssmux
1 parent 6379c7e commit 1b7e2bb

File tree

9 files changed

+713
-31
lines changed

9 files changed

+713
-31
lines changed

cmd/defaults.go

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
)
88

99
const ( // Default values
10-
defaultTransport = config.TCP
1110
defaultToken = "musix"
1211
defaultChannelSize = 2048
1312
defaultRetryInterval = 1 // only for client
@@ -25,25 +24,6 @@ const ( // Default values
2524
)
2625

2726
func applyDefaults(cfg *config.Config) {
28-
// Transport
29-
switch cfg.Server.Transport {
30-
case config.TCP, config.TCPMUX, config.WS, config.WSS: // valid values
31-
case "":
32-
cfg.Server.Transport = defaultTransport
33-
default:
34-
logger.Warnf("invalid transport value '%s' for server, defaulting to '%s'", cfg.Server.Transport, defaultTransport)
35-
cfg.Server.Transport = defaultTransport
36-
}
37-
38-
switch cfg.Client.Transport {
39-
case config.TCP, config.TCPMUX, config.WS, config.WSS: //valid values
40-
case "":
41-
cfg.Client.Transport = defaultTransport
42-
default:
43-
logger.Warnf("invalid transport value '%s' for client, defaulting to '%s'", cfg.Client.Transport, defaultTransport)
44-
cfg.Client.Transport = defaultTransport
45-
}
46-
4727
// Token
4828
if cfg.Server.Token == "" {
4929
cfg.Server.Token = defaultToken

internal/client/client.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,29 @@ func (c *Client) Start() {
9696
}
9797
WsClient := transport.NewWSClient(c.ctx, WsConfig, c.logger)
9898
go WsClient.ChannelDialer()
99+
} else if c.config.Transport == config.WSMUX || c.config.Transport == config.WSSMUX {
100+
wsMuxConfig := &transport.WsMuxConfig{
101+
RemoteAddr: c.config.RemoteAddr,
102+
Nodelay: c.config.Nodelay,
103+
KeepAlive: time.Duration(c.config.Keepalive) * time.Second,
104+
RetryInterval: time.Duration(c.config.RetryInterval) * time.Second,
105+
Token: c.config.Token,
106+
MuxSession: c.config.MuxSession,
107+
MuxVersion: c.config.MuxVersion,
108+
MaxFrameSize: c.config.MaxFrameSize,
109+
MaxReceiveBuffer: c.config.MaxReceiveBuffer,
110+
MaxStreamBuffer: c.config.MaxStreamBuffer,
111+
Forwarder: c.forwarderReader(c.config.Forwarder),
112+
Sniffer: c.config.Sniffer,
113+
WebPort: c.config.WebPort,
114+
SnifferLog: c.config.SnifferLog,
115+
Mode: c.config.Transport,
116+
}
117+
wsMuxClient := transport.NewWSMuxClient(c.ctx, wsMuxConfig, c.logger)
118+
go wsMuxClient.MuxDialer()
119+
120+
} else {
121+
c.logger.Fatal("invalid transport type: ", c.config.Transport)
99122
}
100123

101124
<-c.ctx.Done()

internal/client/transport/tcpmux.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func NewMuxClient(parentCtx context.Context, config *TcpMuxConfig, logger *logru
5454
cancel: cancel,
5555
logger: logger,
5656
smuxSession: make([]*smux.Session, config.MuxSession),
57-
timeout: 5 * time.Second, // Default timeout
57+
timeout: 10 * time.Second, // Default timeout
5858
usageMonitor: web.NewDataStore(fmt.Sprintf(":%v", config.WebPort), ctx, config.SnifferLog, config.Sniffer, &config.TunnelStatus, logger),
5959
}
6060

internal/client/transport/wsmux.go

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
package transport
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"fmt"
7+
"net"
8+
"net/http"
9+
"sync"
10+
"time"
11+
12+
"github.com/gorilla/websocket"
13+
"github.com/musix/backhaul/internal/config"
14+
"github.com/musix/backhaul/internal/utils"
15+
"github.com/musix/backhaul/internal/web"
16+
17+
"github.com/sirupsen/logrus"
18+
"github.com/xtaci/smux"
19+
)
20+
21+
type WsMuxTransport struct {
22+
config *WsMuxConfig
23+
ctx context.Context
24+
cancel context.CancelFunc
25+
logger *logrus.Logger
26+
smuxSession []*smux.Session
27+
restartMutex sync.Mutex
28+
timeout time.Duration
29+
usageMonitor *web.Usage
30+
}
31+
32+
type WsMuxConfig struct {
33+
RemoteAddr string
34+
Nodelay bool
35+
KeepAlive time.Duration
36+
RetryInterval time.Duration
37+
Token string
38+
MuxSession int
39+
Forwarder map[int]string
40+
MuxVersion int
41+
MaxFrameSize int
42+
MaxReceiveBuffer int
43+
MaxStreamBuffer int
44+
Sniffer bool
45+
WebPort int
46+
SnifferLog string
47+
TunnelStatus string
48+
Mode config.TransportType
49+
}
50+
51+
func NewWSMuxClient(parentCtx context.Context, config *WsMuxConfig, logger *logrus.Logger) *WsMuxTransport {
52+
// Create a derived context from the parent context
53+
ctx, cancel := context.WithCancel(parentCtx)
54+
55+
// Initialize the TcpTransport struct
56+
client := &WsMuxTransport{
57+
config: config,
58+
ctx: ctx,
59+
cancel: cancel,
60+
logger: logger,
61+
smuxSession: make([]*smux.Session, config.MuxSession),
62+
timeout: 10 * time.Second, // Default timeout
63+
usageMonitor: web.NewDataStore(fmt.Sprintf(":%v", config.WebPort), ctx, config.SnifferLog, config.Sniffer, &config.TunnelStatus, logger),
64+
}
65+
66+
return client
67+
}
68+
69+
func (c *WsMuxTransport) Restart() {
70+
if !c.restartMutex.TryLock() {
71+
c.logger.Warn("client is already restarting")
72+
return
73+
}
74+
defer c.restartMutex.Unlock()
75+
76+
c.logger.Info("restarting client...")
77+
if c.cancel != nil {
78+
c.cancel()
79+
}
80+
81+
time.Sleep(2 * time.Second)
82+
83+
ctx, cancel := context.WithCancel(context.Background())
84+
c.ctx = ctx
85+
c.cancel = cancel
86+
87+
// Re-initialize variables
88+
c.smuxSession = make([]*smux.Session, c.config.MuxSession)
89+
c.usageMonitor = web.NewDataStore(fmt.Sprintf(":%v", c.config.WebPort), ctx, c.config.SnifferLog, c.config.Sniffer, &c.config.TunnelStatus, c.logger)
90+
c.config.TunnelStatus = ""
91+
92+
go c.MuxDialer()
93+
94+
}
95+
96+
func (c *WsMuxTransport) MuxDialer() {
97+
// for webui
98+
if c.config.WebPort > 0 {
99+
go c.usageMonitor.Monitor()
100+
}
101+
102+
c.config.TunnelStatus = "Disconnected (WSMux)"
103+
104+
for id := 0; id < c.config.MuxSession; id++ {
105+
innerloop:
106+
for {
107+
select {
108+
case <-c.ctx.Done():
109+
return
110+
default:
111+
c.logger.Debugf("initiating new mux session to address %s (session ID: %d)", c.config.RemoteAddr, id)
112+
// Dial to the tunnel server
113+
tunnelTCPConn, err := c.wsDialer(c.config.RemoteAddr, "/channel")
114+
if err != nil {
115+
c.logger.Errorf("failed to dial tunnel server at %s: %v", c.config.RemoteAddr, err)
116+
time.Sleep(c.config.RetryInterval)
117+
continue
118+
}
119+
120+
// config fot smux
121+
config := smux.Config{
122+
Version: c.config.MuxVersion, // Smux protocol version
123+
KeepAliveInterval: 10 * time.Second, // Shorter keep-alive interval to quickly detect dead peers
124+
KeepAliveTimeout: 30 * time.Second, // Aggressive timeout to handle unresponsive connections
125+
MaxFrameSize: c.config.MaxFrameSize,
126+
MaxReceiveBuffer: c.config.MaxReceiveBuffer,
127+
MaxStreamBuffer: c.config.MaxStreamBuffer,
128+
}
129+
130+
// SMUX server
131+
session, err := smux.Server(tunnelTCPConn.UnderlyingConn(), &config)
132+
if err != nil {
133+
c.logger.Errorf("failed to create mux session: %v", err)
134+
continue
135+
}
136+
137+
c.smuxSession[id] = session
138+
c.logger.Infof("mux session established successfully (session ID: %d)", id)
139+
go c.handleMUXStreams(id)
140+
break innerloop
141+
}
142+
}
143+
}
144+
c.config.TunnelStatus = "Connected (WSMux)"
145+
}
146+
147+
func (c *WsMuxTransport) handleMUXStreams(id int) {
148+
for {
149+
select {
150+
case <-c.ctx.Done():
151+
return
152+
default:
153+
stream, err := c.smuxSession[id].AcceptStream()
154+
if err != nil {
155+
c.logger.Errorf("failed to accept mux stream for session ID %d: %v", id, err)
156+
c.logger.Info("attempting to restart client...")
157+
go c.Restart()
158+
return
159+
160+
}
161+
go c.handleTCPSession(stream)
162+
}
163+
}
164+
}
165+
166+
func (c *WsMuxTransport) handleTCPSession(tcpsession net.Conn) {
167+
select {
168+
case <-c.ctx.Done():
169+
return
170+
default:
171+
port, err := utils.ReceiveBinaryInt(tcpsession)
172+
173+
if err != nil {
174+
c.logger.Tracef("unable to get the port from the %s connection: %v", tcpsession.RemoteAddr().String(), err)
175+
tcpsession.Close()
176+
return
177+
}
178+
go c.localDialer(tcpsession, port)
179+
180+
}
181+
}
182+
183+
func (c *WsMuxTransport) localDialer(tunnelConnection net.Conn, port uint16) {
184+
select {
185+
case <-c.ctx.Done():
186+
return
187+
default:
188+
localAddress, ok := c.config.Forwarder[int(port)]
189+
if !ok {
190+
localAddress = fmt.Sprintf("127.0.0.1:%d", port)
191+
}
192+
193+
localConnection, err := c.tcpDialer(localAddress, c.config.Nodelay)
194+
if err != nil {
195+
c.logger.Errorf("failed to connect to local address %s: %v", localAddress, err)
196+
tunnelConnection.Close()
197+
return
198+
}
199+
c.logger.Debugf("connected to local address %s successfully", localAddress)
200+
go utils.ConnectionHandler(localConnection, tunnelConnection, c.logger, c.usageMonitor, int(port), c.config.Sniffer)
201+
}
202+
}
203+
204+
func (c *WsMuxTransport) tcpDialer(address string, tcpnodelay bool) (*net.TCPConn, error) {
205+
// Resolve the address to a TCP address
206+
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
207+
if err != nil {
208+
return nil, err
209+
}
210+
211+
// options
212+
dialer := &net.Dialer{
213+
Timeout: c.timeout, // Set the connection timeout
214+
KeepAlive: c.config.KeepAlive, // Set the keep-alive duration
215+
}
216+
217+
// Dial the TCP connection with a timeout
218+
conn, err := dialer.Dial("tcp", tcpAddr.String())
219+
if err != nil {
220+
return nil, err
221+
}
222+
223+
// Type assert the net.Conn to *net.TCPConn
224+
tcpConn, ok := conn.(*net.TCPConn)
225+
if !ok {
226+
conn.Close()
227+
return nil, fmt.Errorf("failed to convert net.Conn to *net.TCPConn")
228+
}
229+
230+
if tcpnodelay {
231+
// Enable TCP_NODELAY
232+
err = tcpConn.SetNoDelay(true)
233+
if err != nil {
234+
tcpConn.Close()
235+
return nil, err
236+
}
237+
}
238+
239+
return tcpConn, nil
240+
}
241+
242+
func (c *WsMuxTransport) wsDialer(addr string, path string) (*websocket.Conn, error) {
243+
// Create a TLS configuration that allows insecure connections
244+
tlsConfig := &tls.Config{
245+
InsecureSkipVerify: true, // Skip server certificate verification
246+
}
247+
248+
// Setup headers with authorization
249+
headers := http.Header{}
250+
headers.Add("Authorization", fmt.Sprintf("Bearer %v", c.config.Token))
251+
252+
var wsURL string
253+
dialer := websocket.Dialer{}
254+
if c.config.Mode == config.WSMUX {
255+
wsURL = fmt.Sprintf("ws://%s%s", addr, path)
256+
dialer = websocket.Dialer{
257+
HandshakeTimeout: c.timeout, // Set handshake timeout
258+
NetDial: func(_, addr string) (net.Conn, error) {
259+
conn, err := net.DialTimeout("tcp", addr, c.timeout)
260+
if err != nil {
261+
return nil, err
262+
}
263+
tcpConn := conn.(*net.TCPConn)
264+
tcpConn.SetKeepAlive(true) // Enable TCP keepalive
265+
tcpConn.SetKeepAlivePeriod(c.config.KeepAlive) // Set keepalive period
266+
return tcpConn, nil
267+
},
268+
}
269+
} else {
270+
wsURL = fmt.Sprintf("wss://%s%s", addr, path)
271+
dialer = websocket.Dialer{
272+
TLSClientConfig: tlsConfig, // Pass the insecure TLS config here
273+
HandshakeTimeout: c.timeout, // Set handshake timeout
274+
NetDial: func(_, addr string) (net.Conn, error) {
275+
conn, err := net.DialTimeout("tcp", addr, c.timeout)
276+
if err != nil {
277+
return nil, err
278+
}
279+
tcpConn := conn.(*net.TCPConn)
280+
tcpConn.SetKeepAlive(true) // Enable TCP keepalive
281+
tcpConn.SetKeepAlivePeriod(c.config.KeepAlive) // Set keepalive period
282+
return tcpConn, nil
283+
},
284+
}
285+
}
286+
287+
// Dial to the WebSocket server
288+
tunnelWSConn, _, err := dialer.Dial(wsURL, headers)
289+
if err != nil {
290+
return nil, err
291+
}
292+
293+
return tunnelWSConn, nil
294+
}

internal/config/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ const (
88
TCPMUX TransportType = "tcpmux"
99
WS TransportType = "ws"
1010
WSS TransportType = "wss"
11+
WSMUX TransportType = "wsmux"
12+
WSSMUX TransportType = "wssmux"
1113
)
1214

1315
// ServerConfig represents the configuration for the server.

0 commit comments

Comments
 (0)