11package ingress
22
33import (
4+ "bytes"
45 "context"
56 "crypto/tls"
67 "fmt"
8+ "io/ioutil"
79 "net"
810 "net/http"
911 "net/http/httptest"
12+ "net/url"
1013 "testing"
1114 "time"
1215
1316 "github.com/cloudflare/cloudflared/logger"
17+ "github.com/cloudflare/cloudflared/socks"
1418 "github.com/gobwas/ws/wsutil"
15- "github.com/gorilla/websocket"
19+ gorillaWS "github.com/gorilla/websocket"
1620 "github.com/stretchr/testify/assert"
1721 "github.com/stretchr/testify/require"
22+ "golang.org/x/net/proxy"
1823 "golang.org/x/sync/errgroup"
1924)
2025
2126const (
2227 testStreamTimeout = time .Second * 3
28+ echoHeaderName = "Test-Cloudflared-Echo"
2329)
2430
2531var (
@@ -61,7 +67,7 @@ func TestStreamTCPConnection(t *testing.T) {
6167 require .NoError (t , errGroup .Wait ())
6268}
6369
64- func TestStreamWSOverTCPConnection (t * testing.T ) {
70+ func TestDefaultStreamWSOverTCPConnection (t * testing.T ) {
6571 cfdConn , originConn := net .Pipe ()
6672 tcpOverWSConn := tcpOverWSConnection {
6773 conn : cfdConn ,
@@ -88,6 +94,100 @@ func TestStreamWSOverTCPConnection(t *testing.T) {
8894 require .NoError (t , errGroup .Wait ())
8995}
9096
97+ // TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
98+ // Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
99+ // wraps SOCKS5 traffic in websocket
100+ // Origin side runs a tcpOverWSConnection with socks.StreamHandler
101+ func TestSocksStreamWSOverTCPConnection (t * testing.T ) {
102+ var (
103+ sendMessage = t .Name ()
104+ echoHeaderIncomingValue = fmt .Sprintf ("header-%s" , sendMessage )
105+ echoMessage = fmt .Sprintf ("echo-%s" , sendMessage )
106+ echoHeaderReturnValue = fmt .Sprintf ("echo-%s" , echoHeaderIncomingValue )
107+ )
108+
109+ statusCodes := []int {
110+ http .StatusOK ,
111+ http .StatusTemporaryRedirect ,
112+ http .StatusBadRequest ,
113+ http .StatusInternalServerError ,
114+ }
115+ for _ , status := range statusCodes {
116+ handler := func (w http.ResponseWriter , r * http.Request ) {
117+ body , err := ioutil .ReadAll (r .Body )
118+ require .NoError (t , err )
119+ require .Equal (t , []byte (sendMessage ), body )
120+
121+ require .Equal (t , echoHeaderIncomingValue , r .Header .Get (echoHeaderName ))
122+ w .Header ().Set (echoHeaderName , echoHeaderReturnValue )
123+
124+ w .WriteHeader (status )
125+ w .Write ([]byte (echoMessage ))
126+ }
127+ origin := httptest .NewServer (http .HandlerFunc (handler ))
128+ defer origin .Close ()
129+
130+ originURL , err := url .Parse (origin .URL )
131+ require .NoError (t , err )
132+
133+ originConn , err := net .Dial ("tcp" , originURL .Host )
134+ require .NoError (t , err )
135+
136+ tcpOverWSConn := tcpOverWSConnection {
137+ conn : originConn ,
138+ streamHandler : socks .StreamHandler ,
139+ }
140+
141+ wsForwarderOutConn , edgeConn := net .Pipe ()
142+ ctx , cancel := context .WithTimeout (context .Background (), testStreamTimeout )
143+ defer cancel ()
144+
145+ errGroup , ctx := errgroup .WithContext (ctx )
146+ errGroup .Go (func () error {
147+ tcpOverWSConn .Stream (ctx , edgeConn , testLogger )
148+ return nil
149+ })
150+
151+ wsForwarderListener , err := net .Listen ("tcp" , "127.0.0.1:0" )
152+ require .NoError (t , err )
153+
154+ errGroup .Go (func () error {
155+ wsForwarderInConn , err := wsForwarderListener .Accept ()
156+ require .NoError (t , err )
157+ defer wsForwarderInConn .Close ()
158+
159+ Stream (wsForwarderInConn , & wsEyeball {wsForwarderOutConn }, testLogger )
160+ return nil
161+ })
162+
163+ eyeballDialer , err := proxy .SOCKS5 ("tcp" , wsForwarderListener .Addr ().String (), nil , proxy .Direct )
164+ require .NoError (t , err )
165+
166+ transport := & http.Transport {
167+ Dial : eyeballDialer .Dial ,
168+ }
169+
170+ // Request URL doesn't matter because the transport is using eyeballDialer to connectq
171+ req , err := http .NewRequestWithContext (ctx , "GET" , "http://test-socks-stream.com" , bytes .NewBuffer ([]byte (sendMessage )))
172+ assert .NoError (t , err )
173+ req .Header .Set (echoHeaderName , echoHeaderIncomingValue )
174+
175+ resp , err := transport .RoundTrip (req )
176+ assert .NoError (t , err )
177+ assert .Equal (t , status , resp .StatusCode )
178+ require .Equal (t , echoHeaderReturnValue , resp .Header .Get (echoHeaderName ))
179+ body , err := ioutil .ReadAll (resp .Body )
180+ require .NoError (t , err )
181+ require .Equal (t , []byte (echoMessage ), body )
182+
183+ wsForwarderOutConn .Close ()
184+ edgeConn .Close ()
185+ tcpOverWSConn .Close ()
186+
187+ require .NoError (t , errGroup .Wait ())
188+ }
189+ }
190+
91191func TestStreamWSConnection (t * testing.T ) {
92192 eyeballConn , edgeConn := net .Pipe ()
93193
@@ -121,6 +221,23 @@ func TestStreamWSConnection(t *testing.T) {
121221 require .NoError (t , errGroup .Wait ())
122222}
123223
224+ type wsEyeball struct {
225+ conn net.Conn
226+ }
227+
228+ func (wse * wsEyeball ) Read (p []byte ) (int , error ) {
229+ data , err := wsutil .ReadServerBinary (wse .conn )
230+ if err != nil {
231+ return 0 , err
232+ }
233+ return copy (p , data ), nil
234+ }
235+
236+ func (wse * wsEyeball ) Write (p []byte ) (int , error ) {
237+ err := wsutil .WriteClientBinary (wse .conn , p )
238+ return len (p ), err
239+ }
240+
124241func echoWSEyeball (t * testing.T , conn net.Conn ) {
125242 require .NoError (t , wsutil .WriteClientBinary (conn , testMessage ))
126243
@@ -133,7 +250,7 @@ func echoWSEyeball(t *testing.T, conn net.Conn) {
133250}
134251
135252func echoWSOrigin (t * testing.T ) * httptest.Server {
136- var upgrader = websocket .Upgrader {
253+ var upgrader = gorillaWS .Upgrader {
137254 ReadBufferSize : 10 ,
138255 WriteBufferSize : 10 ,
139256 }
0 commit comments