Skip to content

Commit 6b86f81

Browse files
committed
TUN-3403: Unit test for origin/proxy to test serving HTTP and Websocket
1 parent a490443 commit 6b86f81

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+7630
-10
lines changed

connection/header.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,14 @@ import (
88
)
99

1010
const (
11-
responseMetaHeaderField = "cf-cloudflared-response-meta"
12-
responseSourceCloudflared = "cloudflared"
13-
responseSourceOrigin = "origin"
11+
responseMetaHeaderField = "cf-cloudflared-response-meta"
1412
)
1513

1614
var (
1715
canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField)
1816
canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(responseMetaHeaderField)
19-
responseMetaHeaderCfd = mustInitRespMetaHeader(responseSourceCloudflared)
20-
responseMetaHeaderOrigin = mustInitRespMetaHeader(responseSourceOrigin)
17+
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
18+
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
2119
)
2220

2321
type responseMetaHeader struct {

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ require (
2727
github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10
2828
github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0
2929
github.com/go-sql-driver/mysql v1.5.0
30+
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect
31+
github.com/gobwas/pool v0.2.1 // indirect
32+
github.com/gobwas/ws v1.0.4
3033
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
3134
github.com/google/go-cmp v0.5.2 // indirect
3235
github.com/google/uuid v1.1.2

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG
233233
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
234234
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
235235
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
236+
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
237+
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
238+
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
239+
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
240+
github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs=
241+
github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
236242
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
237243
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
238244
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=

hello/hello.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ import (
1818
"github.com/cloudflare/cloudflared/tlsconfig"
1919
)
2020

21+
const (
22+
UptimeRoute = "/uptime"
23+
WSRoute = "/ws"
24+
)
25+
2126
type templateData struct {
2227
ServerName string
2328
Request *http.Request
@@ -104,8 +109,8 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow
104109
}
105110

106111
muxer := http.NewServeMux()
107-
muxer.HandleFunc("/uptime", uptimeHandler(time.Now()))
108-
muxer.HandleFunc("/ws", websocketHandler(logger, upgrader))
112+
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
113+
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
109114
muxer.HandleFunc("/", rootHandler(serverName))
110115
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
111116
go func() {

origin/proxy.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package origin
22

33
import (
44
"bufio"
5+
"context"
56
"crypto/tls"
67
"io"
78
"net/http"
@@ -112,20 +113,24 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
112113

113114
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
114115
c.setHostHeader(req)
115-
116116
conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig)
117117
if err != nil {
118118
return nil, err
119119
}
120-
defer conn.Close()
120+
121+
serveCtx, cancel := context.WithCancel(req.Context())
122+
defer cancel()
123+
go func() {
124+
<-serveCtx.Done()
125+
conn.Close()
126+
}()
121127
err = w.WriteRespHeaders(resp)
122128
if err != nil {
123129
return nil, errors.Wrap(err, "Error writing response header")
124130
}
125131
// Copy to/from stream to the undelying connection. Use the underlying
126132
// connection because cloudflared doesn't operate on the message themselves
127133
websocket.Stream(conn.UnderlyingConn(), w)
128-
129134
return resp, nil
130135
}
131136

origin/proxy_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
package origin
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"crypto/tls"
7+
"crypto/x509"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"net/http/httptest"
12+
"net/url"
13+
"sync"
14+
"testing"
15+
16+
"github.com/cloudflare/cloudflared/connection"
17+
"github.com/cloudflare/cloudflared/hello"
18+
"github.com/cloudflare/cloudflared/logger"
19+
"github.com/cloudflare/cloudflared/tlsconfig"
20+
21+
"github.com/gobwas/ws/wsutil"
22+
"github.com/stretchr/testify/assert"
23+
"github.com/stretchr/testify/require"
24+
)
25+
26+
type mockHTTPRespWriter struct {
27+
*httptest.ResponseRecorder
28+
}
29+
30+
func newMockHTTPRespWriter() *mockHTTPRespWriter {
31+
return &mockHTTPRespWriter{
32+
httptest.NewRecorder(),
33+
}
34+
}
35+
36+
func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
37+
w.WriteHeader(resp.StatusCode)
38+
for header, val := range resp.Header {
39+
w.Header()[header] = val
40+
}
41+
return nil
42+
}
43+
44+
func (w *mockHTTPRespWriter) WriteErrorResponse(err error) {
45+
w.WriteHeader(http.StatusBadGateway)
46+
}
47+
48+
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
49+
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
50+
}
51+
52+
type mockWSRespWriter struct {
53+
*mockHTTPRespWriter
54+
writeNotification chan []byte
55+
reader io.Reader
56+
}
57+
58+
func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter {
59+
return &mockWSRespWriter{
60+
httpRespWriter,
61+
make(chan []byte),
62+
reader,
63+
}
64+
}
65+
66+
func (w *mockWSRespWriter) Write(data []byte) (int, error) {
67+
w.writeNotification <- data
68+
return len(data), nil
69+
}
70+
71+
func (w *mockWSRespWriter) respBody() io.ReadWriter {
72+
data := <-w.writeNotification
73+
return bytes.NewBuffer(data)
74+
}
75+
76+
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
77+
return w.reader.Read(data)
78+
}
79+
80+
func TestProxy(t *testing.T) {
81+
logger, err := logger.New()
82+
require.NoError(t, err)
83+
// let runtime pick an available port
84+
listener, err := hello.CreateTLSListener("127.0.0.1:0")
85+
require.NoError(t, err)
86+
87+
originURL := &url.URL{
88+
Scheme: "https",
89+
Host: listener.Addr().String(),
90+
}
91+
originCA := x509.NewCertPool()
92+
helloCert, err := tlsconfig.GetHelloCertificateX509()
93+
require.NoError(t, err)
94+
originCA.AddCert(helloCert)
95+
clientTLS := &tls.Config{
96+
RootCAs: originCA,
97+
}
98+
proxyConfig := &ProxyConfig{
99+
Client: &http.Transport{
100+
TLSClientConfig: clientTLS,
101+
},
102+
URL: originURL,
103+
TLSConfig: clientTLS,
104+
}
105+
106+
ctx, cancel := context.WithCancel(context.Background())
107+
108+
go func() {
109+
hello.StartHelloWorldServer(logger, listener, ctx.Done())
110+
}()
111+
112+
client := NewClient(proxyConfig, logger)
113+
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
114+
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
115+
cancel()
116+
}
117+
118+
func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
119+
return func(t *testing.T) {
120+
respWriter := newMockHTTPRespWriter()
121+
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
122+
require.NoError(t, err)
123+
124+
err = client.Proxy(respWriter, req, false)
125+
require.NoError(t, err)
126+
127+
assert.Equal(t, http.StatusOK, respWriter.Code)
128+
}
129+
}
130+
131+
func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) {
132+
return func(t *testing.T) {
133+
// WSRoute is a websocket echo handler
134+
ctx, cancel := context.WithCancel(context.Background())
135+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
136+
137+
readPipe, writePipe := io.Pipe()
138+
respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe)
139+
140+
var wg sync.WaitGroup
141+
wg.Add(1)
142+
go func() {
143+
defer wg.Done()
144+
err = client.Proxy(respWriter, req, true)
145+
require.NoError(t, err)
146+
147+
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
148+
}()
149+
150+
msg := []byte("test websocket")
151+
err = wsutil.WriteClientText(writePipe, msg)
152+
require.NoError(t, err)
153+
154+
// ReadServerText reads next data message from rw, considering that caller represents client side.
155+
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
156+
require.NoError(t, err)
157+
require.Equal(t, msg, returnedMsg)
158+
159+
err = wsutil.WriteClientBinary(writePipe, msg)
160+
require.NoError(t, err)
161+
162+
returnedMsg, err = wsutil.ReadServerBinary(respWriter.respBody())
163+
require.NoError(t, err)
164+
require.Equal(t, msg, returnedMsg)
165+
166+
cancel()
167+
wg.Wait()
168+
}
169+
}

vendor/github.com/gobwas/httphead/LICENSE

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vendor/github.com/gobwas/httphead/README.md

Lines changed: 63 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)