Skip to content

Commit 63a29f4

Browse files
chungthuangnmldiegues
authored andcommitted
TUN-3895: Tests for socks stream handler
1 parent e20c4f8 commit 63a29f4

File tree

5 files changed

+131
-22
lines changed

5 files changed

+131
-22
lines changed

carrier/carrier_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (s *testStreamer) Write(p []byte) (int, error) {
4444
func TestStartClient(t *testing.T) {
4545
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
4646
log := zerolog.Nop()
47-
wsConn := NewWSConnection(&log, false)
47+
wsConn := NewWSConnection(&log)
4848
ts := newTestWebSocketServer()
4949
defer ts.Close()
5050

@@ -70,7 +70,7 @@ func TestStartServer(t *testing.T) {
7070
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
7171
log := zerolog.Nop()
7272
shutdownC := make(chan struct{})
73-
wsConn := NewWSConnection(&log, false)
73+
wsConn := NewWSConnection(&log)
7474
ts := newTestWebSocketServer()
7575
defer ts.Close()
7676
options := &StartOptions{

carrier/websocket.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, er
3838
}
3939

4040
// NewWSConnection returns a new connection object
41-
func NewWSConnection(log *zerolog.Logger, isSocks bool) Connection {
41+
func NewWSConnection(log *zerolog.Logger) Connection {
4242
return &Websocket{
43-
log: log,
44-
isSocks: isSocks,
43+
log: log,
4544
}
4645
}
4746

@@ -55,15 +54,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro
5554
}
5655
defer wsConn.Close()
5756

58-
if ws.isSocks {
59-
dialer := &wsdialer{conn: wsConn}
60-
requestHandler := socks.NewRequestHandler(dialer)
61-
socksServer := socks.NewConnectionHandler(requestHandler)
62-
63-
_ = socksServer.Serve(conn)
64-
} else {
65-
ingress.Stream(wsConn, conn, ws.log)
66-
}
57+
ingress.Stream(wsConn, conn, ws.log)
6758
return nil
6859
}
6960

cmd/cloudflared/access/carrier.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z
4848
}
4949

5050
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
51-
wsConn := carrier.NewWSConnection(log, false)
51+
wsConn := carrier.NewWSConnection(log)
5252

5353
log.Info().Str(LogFieldHost, validURL.Host).Msg("Start Websocket listener")
5454
return carrier.StartForwarder(wsConn, validURL.Host, shutdown, options)
@@ -100,7 +100,7 @@ func ssh(c *cli.Context) error {
100100
options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1])
101101
options.TLSClientConfig = &tls.Config{
102102
InsecureSkipVerify: true,
103-
ServerName: parts[0],
103+
ServerName: parts[0],
104104
}
105105
log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0])
106106
default:
@@ -109,15 +109,14 @@ func ssh(c *cli.Context) error {
109109
}
110110

111111
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
112-
wsConn := carrier.NewWSConnection(log, false)
112+
wsConn := carrier.NewWSConnection(log)
113113

114114
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
115115
forwarder, err := config.ValidateUrl(c, true)
116116
if err != nil {
117117
log.Err(err).Msg("Error validating origin URL")
118118
return errors.Wrap(err, "error validating origin URL")
119119
}
120-
121120
log.Info().Str(LogFieldHost, forwarder.Host).Msg("Start Websocket listener")
122121
err = carrier.StartForwarder(wsConn, forwarder.Host, shutdownC, options)
123122
if err != nil {

ingress/origin_connection_test.go

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
11
package ingress
22

33
import (
4+
"bytes"
45
"context"
56
"crypto/tls"
67
"fmt"
8+
"io/ioutil"
79
"net"
810
"net/http"
911
"net/http/httptest"
12+
"net/url"
1013
"testing"
1114
"time"
1215

1316
"github.com/cloudflare/cloudflared/logger"
17+
"github.com/cloudflare/cloudflared/socks"
1418
"github.com/gobwas/ws/wsutil"
15-
"github.com/gorilla/websocket"
19+
gorillaWS "github.com/gorilla/websocket"
1620
"github.com/stretchr/testify/assert"
1721
"github.com/stretchr/testify/require"
22+
"golang.org/x/net/proxy"
1823
"golang.org/x/sync/errgroup"
1924
)
2025

2126
const (
2227
testStreamTimeout = time.Second * 3
28+
echoHeaderName = "Test-Cloudflared-Echo"
2329
)
2430

2531
var (
@@ -61,7 +67,7 @@ func TestStreamTCPConnection(t *testing.T) {
6167
require.NoError(t, errGroup.Wait())
6268
}
6369

64-
func TestStreamWSOverTCPConnection(t *testing.T) {
70+
func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
6571
cfdConn, originConn := net.Pipe()
6672
tcpOverWSConn := tcpOverWSConnection{
6773
conn: cfdConn,
@@ -88,6 +94,100 @@ func TestStreamWSOverTCPConnection(t *testing.T) {
8894
require.NoError(t, errGroup.Wait())
8995
}
9096

97+
// TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
98+
// Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
99+
// wraps SOCKS5 traffic in websocket
100+
// Origin side runs a tcpOverWSConnection with socks.StreamHandler
101+
func TestSocksStreamWSOverTCPConnection(t *testing.T) {
102+
var (
103+
sendMessage = t.Name()
104+
echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage)
105+
echoMessage = fmt.Sprintf("echo-%s", sendMessage)
106+
echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue)
107+
)
108+
109+
statusCodes := []int{
110+
http.StatusOK,
111+
http.StatusTemporaryRedirect,
112+
http.StatusBadRequest,
113+
http.StatusInternalServerError,
114+
}
115+
for _, status := range statusCodes {
116+
handler := func(w http.ResponseWriter, r *http.Request) {
117+
body, err := ioutil.ReadAll(r.Body)
118+
require.NoError(t, err)
119+
require.Equal(t, []byte(sendMessage), body)
120+
121+
require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
122+
w.Header().Set(echoHeaderName, echoHeaderReturnValue)
123+
124+
w.WriteHeader(status)
125+
w.Write([]byte(echoMessage))
126+
}
127+
origin := httptest.NewServer(http.HandlerFunc(handler))
128+
defer origin.Close()
129+
130+
originURL, err := url.Parse(origin.URL)
131+
require.NoError(t, err)
132+
133+
originConn, err := net.Dial("tcp", originURL.Host)
134+
require.NoError(t, err)
135+
136+
tcpOverWSConn := tcpOverWSConnection{
137+
conn: originConn,
138+
streamHandler: socks.StreamHandler,
139+
}
140+
141+
wsForwarderOutConn, edgeConn := net.Pipe()
142+
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
143+
defer cancel()
144+
145+
errGroup, ctx := errgroup.WithContext(ctx)
146+
errGroup.Go(func() error {
147+
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
148+
return nil
149+
})
150+
151+
wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0")
152+
require.NoError(t, err)
153+
154+
errGroup.Go(func() error {
155+
wsForwarderInConn, err := wsForwarderListener.Accept()
156+
require.NoError(t, err)
157+
defer wsForwarderInConn.Close()
158+
159+
Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
160+
return nil
161+
})
162+
163+
eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct)
164+
require.NoError(t, err)
165+
166+
transport := &http.Transport{
167+
Dial: eyeballDialer.Dial,
168+
}
169+
170+
// Request URL doesn't matter because the transport is using eyeballDialer to connectq
171+
req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
172+
assert.NoError(t, err)
173+
req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
174+
175+
resp, err := transport.RoundTrip(req)
176+
assert.NoError(t, err)
177+
assert.Equal(t, status, resp.StatusCode)
178+
require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
179+
body, err := ioutil.ReadAll(resp.Body)
180+
require.NoError(t, err)
181+
require.Equal(t, []byte(echoMessage), body)
182+
183+
wsForwarderOutConn.Close()
184+
edgeConn.Close()
185+
tcpOverWSConn.Close()
186+
187+
require.NoError(t, errGroup.Wait())
188+
}
189+
}
190+
91191
func TestStreamWSConnection(t *testing.T) {
92192
eyeballConn, edgeConn := net.Pipe()
93193

@@ -121,6 +221,23 @@ func TestStreamWSConnection(t *testing.T) {
121221
require.NoError(t, errGroup.Wait())
122222
}
123223

224+
type wsEyeball struct {
225+
conn net.Conn
226+
}
227+
228+
func (wse *wsEyeball) Read(p []byte) (int, error) {
229+
data, err := wsutil.ReadServerBinary(wse.conn)
230+
if err != nil {
231+
return 0, err
232+
}
233+
return copy(p, data), nil
234+
}
235+
236+
func (wse *wsEyeball) Write(p []byte) (int, error) {
237+
err := wsutil.WriteClientBinary(wse.conn, p)
238+
return len(p), err
239+
}
240+
124241
func echoWSEyeball(t *testing.T, conn net.Conn) {
125242
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
126243

@@ -133,7 +250,7 @@ func echoWSEyeball(t *testing.T, conn net.Conn) {
133250
}
134251

135252
func echoWSOrigin(t *testing.T) *httptest.Server {
136-
var upgrader = websocket.Upgrader{
253+
var upgrader = gorillaWS.Upgrader{
137254
ReadBufferSize: 10,
138255
WriteBufferSize: 10,
139256
}

socks/request_handler.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,7 @@ func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn, log *zerolog.L
113113
requestHandler := NewRequestHandler(dialer)
114114
socksServer := NewConnectionHandler(requestHandler)
115115

116-
socksServer.Serve(tunnelConn)
116+
if err := socksServer.Serve(tunnelConn); err != nil {
117+
log.Debug().Err(err).Msg("Socks stream handler error")
118+
}
117119
}

0 commit comments

Comments
 (0)