Skip to content

Commit d576951

Browse files
committed
TUN-3489: Add unit tests to cover proxy logic in connection package of cloudflared
1 parent 5974fb4 commit d576951

File tree

9 files changed

+754
-92
lines changed

9 files changed

+754
-92
lines changed

connection/connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ type OriginClient interface {
4444

4545
type ResponseWriter interface {
4646
WriteRespHeaders(*http.Response) error
47-
WriteErrorResponse(error)
47+
WriteErrorResponse()
4848
io.ReadWriter
4949
}
5050

connection/connection_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package connection
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"net/url"
8+
"time"
9+
10+
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
11+
"github.com/cloudflare/cloudflared/logger"
12+
"github.com/gobwas/ws/wsutil"
13+
)
14+
15+
const (
16+
largeFileSize = 2 * 1024 * 1024
17+
)
18+
19+
var (
20+
testConfig = &Config{
21+
OriginClient: &mockOriginClient{},
22+
GracePeriod: time.Millisecond * 100,
23+
}
24+
testLogger, _ = logger.New()
25+
testOriginURL = &url.URL{
26+
Scheme: "https",
27+
Host: "connectiontest.argotunnel.com",
28+
}
29+
testTunnelEventChan = make(chan ui.TunnelEvent)
30+
testObserver = &Observer{
31+
testLogger,
32+
m,
33+
testTunnelEventChan,
34+
}
35+
testLargeResp = make([]byte, largeFileSize)
36+
)
37+
38+
type testRequest struct {
39+
name string
40+
endpoint string
41+
expectedStatus int
42+
expectedBody []byte
43+
isProxyError bool
44+
}
45+
46+
type mockOriginClient struct {
47+
}
48+
49+
func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
50+
if isWebsocket {
51+
return wsEndpoint(w, r)
52+
}
53+
switch r.URL.Path {
54+
case "/ok":
55+
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
56+
case "/large_file":
57+
originRespEndpoint(w, http.StatusOK, testLargeResp)
58+
case "/400":
59+
originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest)))
60+
case "/500":
61+
originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError)))
62+
case "/error":
63+
return fmt.Errorf("Failed to proxy to origin")
64+
default:
65+
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
66+
}
67+
return nil
68+
}
69+
70+
type nowriter struct {
71+
io.Reader
72+
}
73+
74+
func (nowriter) Write(p []byte) (int, error) {
75+
return 0, fmt.Errorf("Writer not implemented")
76+
}
77+
78+
func wsEndpoint(w ResponseWriter, r *http.Request) error {
79+
resp := &http.Response{
80+
StatusCode: http.StatusSwitchingProtocols,
81+
}
82+
w.WriteRespHeaders(resp)
83+
clientReader := nowriter{r.Body}
84+
go func() {
85+
for {
86+
data, err := wsutil.ReadClientText(clientReader)
87+
if err != nil {
88+
return
89+
}
90+
if err := wsutil.WriteServerText(w, data); err != nil {
91+
return
92+
}
93+
}
94+
}()
95+
<-r.Context().Done()
96+
return nil
97+
}
98+
99+
func originRespEndpoint(w ResponseWriter, status int, data []byte) {
100+
resp := &http.Response{
101+
StatusCode: status,
102+
}
103+
w.WriteRespHeaders(resp)
104+
w.Write(data)
105+
}
106+
107+
type mockConnectedFuse struct{}
108+
109+
func (mcf mockConnectedFuse) Connected() {}
110+
111+
func (mcf mockConnectedFuse) IsConnected() bool {
112+
return true
113+
}

connection/h2mux.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam
8888
return err
8989
}
9090
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer)
91-
defer rpcClient.close()
91+
defer rpcClient.Close()
9292

93-
if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
93+
if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
9494
return err
9595
}
9696
connectedFuse.Connected()
@@ -177,11 +177,16 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
177177

178178
req, reqErr := h.newRequest(stream)
179179
if reqErr != nil {
180-
respWriter.WriteErrorResponse(reqErr)
180+
respWriter.WriteErrorResponse()
181181
return reqErr
182182
}
183183

184-
return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
184+
err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
185+
if err != nil {
186+
respWriter.WriteErrorResponse()
187+
return err
188+
}
189+
return nil
185190
}
186191

187192
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
@@ -206,7 +211,7 @@ func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
206211
return rp.WriteHeaders(headers)
207212
}
208213

209-
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
214+
func (rp *h2muxRespWriter) WriteErrorResponse() {
210215
rp.WriteHeaders([]h2mux.Header{
211216
{Name: ":status", Value: "502"},
212217
{Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},

0 commit comments

Comments
 (0)