Skip to content

Commit ab4dda5

Browse files
chungthuangnmldiegues
authored andcommitted
TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService
1 parent 5943808 commit ab4dda5

File tree

10 files changed

+563
-212
lines changed

10 files changed

+563
-212
lines changed

connection/connection.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ func (t Type) String() string {
8787
}
8888

8989
type OriginProxy interface {
90+
// If Proxy returns an error, the caller is responsible for writing the error status to ResponseWriter
9091
Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error
9192
}
9293

ingress/ingress.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ var (
2525
)
2626

2727
const (
28-
ServiceBridge = "bridge service"
2928
ServiceBastion = "bastion"
3029
ServiceWarpRouting = "warp-routing"
3130
)
@@ -98,8 +97,7 @@ type WarpRoutingService struct {
9897
}
9998

10099
func NewWarpRoutingService() *WarpRoutingService {
101-
warpRoutingService := newBridgeService(DefaultStreamHandler, ServiceWarpRouting)
102-
return &WarpRoutingService{Proxy: warpRoutingService}
100+
return &WarpRoutingService{Proxy: &rawTCPService{name: ServiceWarpRouting}}
103101
}
104102

105103
// Get a single origin service from the CLI/config.
@@ -108,7 +106,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
108106
return new(helloWorld), nil
109107
}
110108
if c.IsSet(config.BastionFlag) {
111-
return newBridgeService(nil, ServiceBastion), nil
109+
return newBastionService(), nil
112110
}
113111
if c.IsSet("url") {
114112
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
@@ -120,7 +118,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
120118
url: originURL,
121119
}, nil
122120
}
123-
return newSingleTCPService(originURL), nil
121+
return newTCPOverWSService(originURL), nil
124122
}
125123
if c.IsSet("unix-socket") {
126124
path, err := config.ValidateUnixSocket(c)
@@ -182,7 +180,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
182180
// overwrite the localService.URL field when `start` is called. So,
183181
// leave the URL field empty for now.
184182
cfg.BastionMode = true
185-
service = newBridgeService(nil, ServiceBastion)
183+
service = newBastionService()
186184
} else {
187185
// Validate URL services
188186
u, err := url.Parse(r.Service)
@@ -200,7 +198,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
200198
if isHTTPService(u) {
201199
service = &httpService{url: u}
202200
} else {
203-
service = newSingleTCPService(u)
201+
service = newTCPOverWSService(u)
204202
}
205203
}
206204

ingress/ingress_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,12 @@ ingress:
238238
want: []Rule{
239239
{
240240
Hostname: "tcp.foo.com",
241-
Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")),
241+
Service: newTCPOverWSService(MustParseURL(t, "tcp://127.0.0.1:7864")),
242242
Config: defaultConfig,
243243
},
244244
{
245245
Hostname: "tcp2.foo.com",
246-
Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")),
246+
Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:8000")),
247247
Config: defaultConfig,
248248
},
249249
{
@@ -260,7 +260,7 @@ ingress:
260260
`},
261261
want: []Rule{
262262
{
263-
Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")),
263+
Service: newTCPOverWSService(MustParseURL(t, "ssh://127.0.0.1:22")),
264264
Config: defaultConfig,
265265
},
266266
},
@@ -273,7 +273,7 @@ ingress:
273273
`},
274274
want: []Rule{
275275
{
276-
Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")),
276+
Service: newTCPOverWSService(MustParseURL(t, "rdp://127.0.0.1:3389")),
277277
Config: defaultConfig,
278278
},
279279
},
@@ -286,7 +286,7 @@ ingress:
286286
`},
287287
want: []Rule{
288288
{
289-
Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")),
289+
Service: newTCPOverWSService(MustParseURL(t, "smb://127.0.0.1:445")),
290290
Config: defaultConfig,
291291
},
292292
},
@@ -299,7 +299,7 @@ ingress:
299299
`},
300300
want: []Rule{
301301
{
302-
Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")),
302+
Service: newTCPOverWSService(MustParseURL(t, "ftp://127.0.0.1")),
303303
Config: defaultConfig,
304304
},
305305
},
@@ -316,7 +316,7 @@ ingress:
316316
want: []Rule{
317317
{
318318
Hostname: "bastion.foo.com",
319-
Service: newBridgeService(nil, ServiceBastion),
319+
Service: newBastionService(),
320320
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
321321
},
322322
{
@@ -336,7 +336,7 @@ ingress:
336336
want: []Rule{
337337
{
338338
Hostname: "bastion.foo.com",
339-
Service: newBridgeService(nil, ServiceBastion),
339+
Service: newBastionService(),
340340
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
341341
},
342342
{

ingress/origin_connection.go

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package ingress
22

33
import (
4+
"context"
5+
"crypto/tls"
46
"io"
57
"net"
68
"net/http"
79

8-
"github.com/cloudflare/cloudflared/connection"
910
"github.com/cloudflare/cloudflared/websocket"
1011
gws "github.com/gorilla/websocket"
1112
"github.com/rs/zerolog"
@@ -15,9 +16,8 @@ import (
1516
// Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
1617
type OriginConnection interface {
1718
// Stream should generally be implemented as a bidirectional io.Copy.
18-
Stream(tunnelConn io.ReadWriter, log *zerolog.Logger)
19+
Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger)
1920
Close()
20-
Type() connection.Type
2121
}
2222

2323
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
@@ -54,30 +54,38 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
5454

5555
// tcpConnection is an OriginConnection that directly streams to raw TCP.
5656
type tcpConnection struct {
57-
conn net.Conn
58-
streamHandler streamHandlerFunc
57+
conn net.Conn
5958
}
6059

61-
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) {
62-
tc.streamHandler(tunnelConn, tc.conn, log)
60+
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
61+
Stream(tunnelConn, tc.conn, log)
6362
}
6463

6564
func (tc *tcpConnection) Close() {
6665
tc.conn.Close()
6766
}
6867

69-
func (*tcpConnection) Type() connection.Type {
70-
return connection.TypeTCP
68+
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
69+
type tcpOverWSConnection struct {
70+
conn net.Conn
71+
streamHandler streamHandlerFunc
72+
}
73+
74+
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
75+
wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log)
76+
}
77+
78+
func (wc *tcpOverWSConnection) Close() {
79+
wc.conn.Close()
7180
}
7281

73-
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
74-
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
82+
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
7583
type wsConnection struct {
7684
wsConn *gws.Conn
7785
resp *http.Response
7886
}
7987

80-
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) {
88+
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
8189
Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
8290
}
8391

@@ -86,13 +94,9 @@ func (wsc *wsConnection) Close() {
8694
wsc.wsConn.Close()
8795
}
8896

89-
func (wsc *wsConnection) Type() connection.Type {
90-
return connection.TypeWebsocket
91-
}
92-
93-
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
97+
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
9498
d := &gws.Dialer{
95-
TLSClientConfig: transport.TLSClientConfig,
99+
TLSClientConfig: clientTLSConfig,
96100
}
97101
wsConn, resp, err := websocket.ClientConnect(r, d)
98102
if err != nil {

ingress/origin_connection_test.go

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
package ingress
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"fmt"
7+
"net"
8+
"net/http"
9+
"net/http/httptest"
10+
"testing"
11+
"time"
12+
13+
"github.com/cloudflare/cloudflared/logger"
14+
"github.com/gobwas/ws/wsutil"
15+
"github.com/gorilla/websocket"
16+
"github.com/stretchr/testify/assert"
17+
"github.com/stretchr/testify/require"
18+
"golang.org/x/sync/errgroup"
19+
)
20+
21+
const (
22+
testStreamTimeout = time.Second * 3
23+
)
24+
25+
var (
26+
testLogger = logger.Create(nil)
27+
testMessage = []byte("TestStreamOriginConnection")
28+
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
29+
)
30+
31+
func TestStreamTCPConnection(t *testing.T) {
32+
cfdConn, originConn := net.Pipe()
33+
tcpConn := tcpConnection{
34+
conn: cfdConn,
35+
}
36+
37+
eyeballConn, edgeConn := net.Pipe()
38+
39+
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
40+
defer cancel()
41+
42+
errGroup, ctx := errgroup.WithContext(ctx)
43+
errGroup.Go(func() error {
44+
_, err := eyeballConn.Write(testMessage)
45+
46+
readBuffer := make([]byte, len(testResponse))
47+
_, err = eyeballConn.Read(readBuffer)
48+
require.NoError(t, err)
49+
50+
require.Equal(t, testResponse, readBuffer)
51+
52+
return nil
53+
})
54+
errGroup.Go(func() error {
55+
echoTCPOrigin(t, originConn)
56+
originConn.Close()
57+
return nil
58+
})
59+
60+
tcpConn.Stream(ctx, edgeConn, testLogger)
61+
require.NoError(t, errGroup.Wait())
62+
}
63+
64+
func TestStreamWSOverTCPConnection(t *testing.T) {
65+
cfdConn, originConn := net.Pipe()
66+
tcpOverWSConn := tcpOverWSConnection{
67+
conn: cfdConn,
68+
streamHandler: DefaultStreamHandler,
69+
}
70+
71+
eyeballConn, edgeConn := net.Pipe()
72+
73+
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
74+
defer cancel()
75+
76+
errGroup, ctx := errgroup.WithContext(ctx)
77+
errGroup.Go(func() error {
78+
echoWSEyeball(t, eyeballConn)
79+
return nil
80+
})
81+
errGroup.Go(func() error {
82+
echoTCPOrigin(t, originConn)
83+
originConn.Close()
84+
return nil
85+
})
86+
87+
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
88+
require.NoError(t, errGroup.Wait())
89+
}
90+
91+
func TestStreamWSConnection(t *testing.T) {
92+
eyeballConn, edgeConn := net.Pipe()
93+
94+
origin := echoWSOrigin(t)
95+
defer origin.Close()
96+
97+
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
98+
require.NoError(t, err)
99+
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
100+
101+
clientTLSConfig := &tls.Config{
102+
InsecureSkipVerify: true,
103+
}
104+
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
105+
require.NoError(t, err)
106+
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
107+
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
108+
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
109+
require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
110+
111+
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
112+
defer cancel()
113+
114+
errGroup, ctx := errgroup.WithContext(ctx)
115+
errGroup.Go(func() error {
116+
echoWSEyeball(t, eyeballConn)
117+
return nil
118+
})
119+
120+
wsConn.Stream(ctx, edgeConn, testLogger)
121+
require.NoError(t, errGroup.Wait())
122+
}
123+
124+
func echoWSEyeball(t *testing.T, conn net.Conn) {
125+
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
126+
127+
readMsg, err := wsutil.ReadServerBinary(conn)
128+
require.NoError(t, err)
129+
130+
require.Equal(t, testResponse, readMsg)
131+
132+
require.NoError(t, conn.Close())
133+
}
134+
135+
func echoWSOrigin(t *testing.T) *httptest.Server {
136+
var upgrader = websocket.Upgrader{
137+
ReadBufferSize: 10,
138+
WriteBufferSize: 10,
139+
}
140+
141+
ws := func(w http.ResponseWriter, r *http.Request) {
142+
header := make(http.Header)
143+
for k, vs := range r.Header {
144+
if k == "Test-Cloudflared-Echo" {
145+
header[k] = vs
146+
}
147+
}
148+
conn, err := upgrader.Upgrade(w, r, header)
149+
require.NoError(t, err)
150+
defer conn.Close()
151+
152+
for {
153+
messageType, p, err := conn.ReadMessage()
154+
if err != nil {
155+
return
156+
}
157+
require.Equal(t, testMessage, p)
158+
if err := conn.WriteMessage(messageType, testResponse); err != nil {
159+
return
160+
}
161+
}
162+
}
163+
164+
// NewTLSServer starts the server in another thread
165+
return httptest.NewTLSServer(http.HandlerFunc(ws))
166+
}
167+
168+
func echoTCPOrigin(t *testing.T, conn net.Conn) {
169+
readBuffer := make([]byte, len(testMessage))
170+
_, err := conn.Read(readBuffer)
171+
assert.NoError(t, err)
172+
173+
assert.Equal(t, testMessage, readBuffer)
174+
175+
_, err = conn.Write(testResponse)
176+
assert.NoError(t, err)
177+
}

0 commit comments

Comments
 (0)