Skip to content

Commit ed57ee6

Browse files
sudarshan-reddynmldiegues
authored andcommitted
TUN-3853: Respond with ws headers from the origin service rather than generating our own
1 parent 9c298e4 commit ed57ee6

File tree

4 files changed

+97
-34
lines changed

4 files changed

+97
-34
lines changed

ingress/origin_connection.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,16 @@ func (wsc *wsConnection) Type() connection.Type {
9090
return connection.TypeWebsocket
9191
}
9292

93-
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
93+
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
9494
d := &gws.Dialer{
9595
TLSClientConfig: transport.TLSClientConfig,
9696
}
9797
wsConn, resp, err := websocket.ClientConnect(r, d)
9898
if err != nil {
99-
return nil, err
99+
return nil, nil, err
100100
}
101101
return &wsConnection{
102102
wsConn,
103103
resp,
104-
}, nil
104+
}, resp, nil
105105
}

ingress/origin_proxy.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ type HTTPOriginProxy interface {
2121

2222
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
2323
type StreamBasedOriginProxy interface {
24-
EstablishConnection(r *http.Request) (OriginConnection, error)
24+
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
2525
}
2626

2727
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
2828
return o.transport.RoundTrip(req)
2929
}
3030

3131
// TODO: TUN-3636: establish connection to origins over UDS
32-
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) {
33-
return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
32+
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
33+
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
3434
}
3535

3636
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -40,7 +40,7 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
4040
return o.transport.RoundTrip(req)
4141
}
4242

43-
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) {
43+
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
4444
req.URL.Host = o.url.Host
4545
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
4646
return newWSConnection(o.transport, req)
@@ -53,7 +53,7 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
5353
return o.transport.RoundTrip(req)
5454
}
5555

56-
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) {
56+
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
5757
req.URL.Host = o.server.Addr().String()
5858
req.URL.Scheme = "wss"
5959
return newWSConnection(o.transport, req)
@@ -63,12 +63,13 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
6363
return o.resp, nil
6464
}
6565

66-
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) {
66+
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
6767
dest, err := o.destination(r)
6868
if err != nil {
69-
return nil, err
69+
return nil, nil, err
7070
}
71-
return o.client.connect(r, dest)
71+
conn, err := o.client.connect(r, dest)
72+
return conn, nil, err
7273
}
7374

7475
// getRequestHost returns the host of the http.Request.
@@ -102,8 +103,10 @@ func removePath(dest string) string {
102103
return strings.SplitN(dest, "/", 2)[0]
103104
}
104105

105-
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) {
106-
return o.client.connect(r, o.dest)
106+
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
107+
conn, err := o.client.connect(r, o.dest)
108+
return conn, nil, err
109+
107110
}
108111

109112
type tcpClient struct {

origin/proxy.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,20 +166,22 @@ func (p *proxy) proxyConnection(
166166
sourceConnectionType connection.Type,
167167
connectionProxy ingress.StreamBasedOriginProxy,
168168
) (*http.Response, error) {
169-
originConn, err := connectionProxy.EstablishConnection(req)
169+
originConn, connectionResp, err := connectionProxy.EstablishConnection(req)
170170
if err != nil {
171171
return nil, err
172172
}
173173

174174
var eyeballConn io.ReadWriter = w
175175
respHeader := http.Header{}
176+
if connectionResp != nil {
177+
respHeader = connectionResp.Header
178+
}
176179
if sourceConnectionType == connection.TypeWebsocket {
177180
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
178181
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
179182
if originConn.Type() != sourceConnectionType {
180183
eyeballConn = wsReadWriter
181184
}
182-
respHeader = websocket.NewResponseHeader(req)
183185
}
184186
status := http.StatusSwitchingProtocols
185187
resp := &http.Response{

origin/proxy_test.go

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -411,15 +411,26 @@ func TestConnections(t *testing.T) {
411411
originService func(*testing.T, net.Listener)
412412
eyeballService connection.ResponseWriter
413413
connectionType connection.Type
414+
requestHeaders http.Header
414415
wantMessage []byte
416+
wantHeaders http.Header
415417
}{
416418
{
417419
name: "ws-ws proxy",
418420
ingressServicePrefix: "ws://",
419421
originService: runEchoWSService,
420422
eyeballService: newWSRespWriter([]byte("test1"), replayer),
421423
connectionType: connection.TypeWebsocket,
422-
wantMessage: []byte("test1"),
424+
requestHeaders: map[string][]string{
425+
"Test-Cloudflared-Echo": []string{"Echo"},
426+
},
427+
wantMessage: []byte("echo-test1"),
428+
wantHeaders: map[string][]string{
429+
"Connection": []string{"Upgrade"},
430+
"Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
431+
"Upgrade": []string{"websocket"},
432+
"Test-Cloudflared-Echo": []string{"Echo"},
433+
},
423434
},
424435
{
425436
name: "tcp-tcp proxy",
@@ -430,15 +441,25 @@ func TestConnections(t *testing.T) {
430441
replayer,
431442
),
432443
connectionType: connection.TypeTCP,
433-
wantMessage: []byte("echo-test2"),
444+
requestHeaders: map[string][]string{
445+
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
446+
},
447+
wantMessage: []byte("echo-test2"),
448+
wantHeaders: http.Header{},
434449
},
435450
{
436451
name: "tcp-ws proxy",
437452
ingressServicePrefix: "ws://",
438453
originService: runEchoWSService,
439454
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
440-
connectionType: connection.TypeTCP,
441-
wantMessage: []byte("test3"),
455+
requestHeaders: map[string][]string{
456+
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
457+
},
458+
connectionType: connection.TypeTCP,
459+
wantMessage: []byte("echo-test3"),
460+
// We expect no headers here because they are sent back via
461+
// the stream.
462+
wantHeaders: http.Header{},
442463
},
443464
{
444465
name: "ws-tcp proxy",
@@ -447,14 +468,12 @@ func TestConnections(t *testing.T) {
447468
eyeballService: newWSRespWriter([]byte("test4"), replayer),
448469
connectionType: connection.TypeWebsocket,
449470
wantMessage: []byte("echo-test4"),
471+
wantHeaders: http.Header{},
450472
},
451473
}
452474

453475
for _, test := range tests {
454476
t.Run(test.name, func(t *testing.T) {
455-
if test.skip {
456-
t.Skip("todo: skipping a failing test. THis should be fixed before merge")
457-
}
458477
ctx, cancel := context.WithCancel(context.Background())
459478
ln, err := net.Listen("tcp", "127.0.0.1:0")
460479
require.NoError(t, err)
@@ -466,29 +485,41 @@ func TestConnections(t *testing.T) {
466485
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
467486
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
468487
require.NoError(t, err)
469-
req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value")
488+
reqHeaders := make(http.Header)
489+
for k, vs := range test.requestHeaders {
490+
reqHeaders[k] = vs
491+
}
492+
req.Header = reqHeaders
470493

471494
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
472495
go func() {
473496
resp := pipedWS.roundtrip(test.ingressServicePrefix + ln.Addr().String())
474497
replayer.Write(resp)
475498
}()
476499
}
500+
477501
err = proxy.Proxy(test.eyeballService, req, test.connectionType)
478502
require.NoError(t, err)
479503

480504
cancel()
481505
assert.Equal(t, test.wantMessage, replayer.Bytes())
506+
respPrinter := test.eyeballService.(responsePrinter)
507+
assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders())
482508
replayer.rw.Reset()
483509
})
484510
}
485511
}
486512

513+
type responsePrinter interface {
514+
printRespHeaders() http.Header
515+
}
516+
487517
type pipedWSWriter struct {
488518
dialer gorillaWS.Dialer
489519
wsConn net.Conn
490520
pipedConn net.Conn
491521
respWriter connection.ResponseWriter
522+
respHeaders http.Header
492523
messageToWrite []byte
493524
}
494525

@@ -547,14 +578,21 @@ func (p *pipedWSWriter) WriteErrorResponse() {
547578
}
548579

549580
func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error {
581+
p.respHeaders = header
550582
return nil
551583
}
552584

585+
// printRespHeaders is a test function to read respHeaders
586+
func (p *pipedWSWriter) printRespHeaders() http.Header {
587+
return p.respHeaders
588+
}
589+
553590
type wsRespWriter struct {
554-
w io.Writer
555-
pr *io.PipeReader
556-
pw *io.PipeWriter
557-
code int
591+
w io.Writer
592+
pr *io.PipeReader
593+
pw *io.PipeWriter
594+
respHeaders http.Header
595+
code int
558596
}
559597

560598
// newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
@@ -589,13 +627,19 @@ func (w *wsRespWriter) Write(p []byte) (int, error) {
589627
}
590628

591629
func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
630+
w.respHeaders = header
592631
w.code = status
593632
return nil
594633
}
595634

596635
func (w *wsRespWriter) WriteErrorResponse() {
597636
}
598637

638+
// printRespHeaders is a test function to read respHeaders
639+
func (w *wsRespWriter) printRespHeaders() http.Header {
640+
return w.respHeaders
641+
}
642+
599643
func runEchoTCPService(t *testing.T, l net.Listener) {
600644
go func() {
601645
for {
@@ -628,7 +672,13 @@ func runEchoWSService(t *testing.T, l net.Listener) {
628672
}
629673

630674
var ws = func(w http.ResponseWriter, r *http.Request) {
631-
conn, err := upgrader.Upgrade(w, r, nil)
675+
header := make(http.Header)
676+
for k, vs := range r.Header {
677+
if k == "Test-Cloudflared-Echo" {
678+
header[k] = vs
679+
}
680+
}
681+
conn, err := upgrader.Upgrade(w, r, header)
632682
require.NoError(t, err)
633683
defer conn.Close()
634684

@@ -637,8 +687,9 @@ func runEchoWSService(t *testing.T, l net.Listener) {
637687
if err != nil {
638688
return
639689
}
640-
641-
if err := conn.WriteMessage(messageType, p); err != nil {
690+
data := []byte("echo-")
691+
data = append(data, p...)
692+
if err := conn.WriteMessage(messageType, data); err != nil {
642693
return
643694
}
644695
}
@@ -672,10 +723,11 @@ type tcpWrappedWs struct {
672723
}
673724

674725
type mockTCPRespWriter struct {
675-
w io.Writer
676-
pr io.Reader
677-
pw *io.PipeWriter
678-
code int
726+
w io.Writer
727+
pr io.Reader
728+
pw *io.PipeWriter
729+
respHeaders http.Header
730+
code int
679731
}
680732

681733
func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter {
@@ -701,6 +753,12 @@ func (m *mockTCPRespWriter) WriteErrorResponse() {
701753
}
702754

703755
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
756+
m.respHeaders = header
704757
m.code = status
705758
return nil
706759
}
760+
761+
// printRespHeaders is a test function to read respHeaders
762+
func (m *mockTCPRespWriter) printRespHeaders() http.Header {
763+
return m.respHeaders
764+
}

0 commit comments

Comments
 (0)