diff --git a/client_proxy_server_test.go b/client_proxy_server_test.go index c8e6850f..71475e6a 100644 --- a/client_proxy_server_test.go +++ b/client_proxy_server_test.go @@ -43,13 +43,13 @@ func TestHTTPProxyAndBackend(t *testing.T) { websocketTLS := false proxyTLS := false // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -80,13 +80,13 @@ func TestHTTPProxyWithNetDial(t *testing.T) { websocketTLS := false proxyTLS := false // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -125,13 +125,13 @@ func TestHTTPProxyWithNetDialContext(t *testing.T) { websocketTLS := false proxyTLS := false // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -171,13 +171,13 @@ func TestHTTPProxyWithHTTPSBackend(t *testing.T) { websocketTLS := true proxyTLS := false // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -219,13 +219,13 @@ func TestHTTPSProxyAndBackend(t *testing.T) { websocketTLS := true proxyTLS := true // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -257,13 +257,13 @@ func TestHTTPSProxyUsingNetDial(t *testing.T) { websocketTLS := true proxyTLS := true // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -303,13 +303,13 @@ func TestHTTPSProxyUsingNetDialContext(t *testing.T) { websocketTLS := true proxyTLS := true // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -349,13 +349,13 @@ func TestHTTPSProxyUsingNetDialTLSContext(t *testing.T) { websocketTLS := true proxyTLS := true // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -408,13 +408,13 @@ func TestHTTPSProxyHTTPBackend(t *testing.T) { websocketTLS := false proxyTLS := true // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -446,13 +446,13 @@ func TestHTTPSProxyUsingNetDialTLSContextWithHTTPBackend(t *testing.T) { websocketTLS := false proxyTLS := true // Start the websocket server, which echoes data back to sender. - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } // Start the proxy server. - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -486,12 +486,12 @@ func TestTLSValidationErrors(t *testing.T) { // Both websocket and proxy servers are started with TLS. websocketTLS := true proxyTLS := true - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } - proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + proxyServer, proxyServerURL, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -534,7 +534,7 @@ func TestTLSValidationErrors(t *testing.T) { } func TestProxyFnErrorIsPropagated(t *testing.T) { - websocketServer, websocketURL, err := newWebsocketServer(false) + websocketServer, websocketURL, err := newWebsocketServer(t, false) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) @@ -560,12 +560,12 @@ func TestProxyFnNilMeansNoProxy(t *testing.T) { // Both websocket and proxy servers are started. websocketTLS := false proxyTLS := false - websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + websocketServer, websocketURL, err := newWebsocketServer(t, websocketTLS) defer websocketServer.Close() if err != nil { t.Fatalf("error starting websocket server: %v", err) } - proxyServer, _, err := newProxyServer(proxyTLS) + proxyServer, _, err := newProxyServer(t, proxyTLS) defer proxyServer.Close() if err != nil { t.Fatalf("error starting proxy server: %v", err) @@ -623,51 +623,54 @@ func (ts *testServer) Close() { } } -// websocketEchoHandler upgrades the connection associated with the request, and -// echoes binary messages read off the websocket connection back to the client. -var websocketEchoHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - upgrader := Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true // Accepting all requests - }, - Subprotocols: []string{ - subprotocolV1, - subprotocolV2, - }, - } - wsConn, err := upgrader.Upgrade(w, req, nil) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - defer wsConn.Close() - for { - writer, err := wsConn.NextWriter(BinaryMessage) - if err != nil { - break +// newWebsocketEchoHandler returns a handler that upgrades the connection associated with the request, +// and echoes binary messages read off the websocket connection back to the client. +func newWebsocketEchoHandler(t *testing.T) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + upgrader := Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + Subprotocols: []string{ + subprotocolV1, + subprotocolV2, + }, } - messageType, reader, err := wsConn.NextReader() + wsConn, err := upgrader.Upgrade(w, req, nil) if err != nil { - break - } - if messageType != BinaryMessage { - http.Error(w, "websocket reader not binary message type", - http.StatusInternalServerError) + t.Logf("websocketEchoHandler Upgrade: %v, %#v", err, req) + return } - _, err = io.Copy(writer, reader) - if err != nil { - http.Error(w, "websocket server io copy error", - http.StatusInternalServerError) + defer wsConn.Close() + for { + writer, err := wsConn.NextWriter(BinaryMessage) + if err != nil { + break + } + messageType, reader, err := wsConn.NextReader() + if err != nil { + break + } + if messageType != BinaryMessage { + t.Log("websocket reader not binary message type") + break + } + _, err = io.Copy(writer, reader) + if err != nil { + t.Log("websocket server io copy error") + break + } } - } -}) + }) +} // Returns a test backend websocket server as well as the URL pointing // to the server, or an error if one occurred. Sets up a TLS endpoint // on the server if the passed "tlsServer" is true. // func newWebsocketServer(tlsServer bool) (*httptest.Server, *url.URL, error) { -func newWebsocketServer(tlsServer bool) (closer, *url.URL, error) { +func newWebsocketServer(t *testing.T, tlsServer bool) (closer, *url.URL, error) { // Start the websocket server, which echoes data back to sender. - websocketServer := httptest.NewUnstartedServer(websocketEchoHandler) + websocketServer := httptest.NewUnstartedServer(newWebsocketEchoHandler(t)) if tlsServer { websocketKeyPair, err := tls.X509KeyPair(websocketServerCert, websocketServerKey) if err != nil { @@ -695,45 +698,59 @@ func newWebsocketServer(tlsServer bool) (closer, *url.URL, error) { // proxyHandler creates a full duplex streaming connection between the client // (hijacking the http request connection), and an "upstream" dialed connection // to the "Host". Creates two goroutines to copy between connections in each direction. -var proxyHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // Validate the CONNECT method. - if req.Method != http.MethodConnect { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - // Dial upstream server. - upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer upstream.Close() - // Return 200 OK to client. - w.WriteHeader(http.StatusOK) - // Hijack client connection. - client, _, err := w.(http.Hijacker).Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer client.Close() - // Create duplex streaming between client and upstream connections. - done := make(chan struct{}, 2) - go func() { - _, _ = io.Copy(upstream, client) - done <- struct{}{} - }() - go func() { - _, _ = io.Copy(client, upstream) - done <- struct{}{} - }() - <-done -}) +func newProxyHandler(t *testing.T) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Validate the CONNECT method. + if req.Method != http.MethodConnect { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + // Dial upstream server. + upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer upstream.Close() + // Return 200 OK to client. + w.WriteHeader(http.StatusOK) + // Hijack client connection. + client, brw, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer client.Close() + + // flush any buffered reads/writes + if err := brw.Flush(); err != nil { + t.Logf("Failed to flush pending writes to client, host=%s: %v\n", req.Host, err) + return + } + if _, err := io.Copy(upstream, io.LimitReader(brw, int64(brw.Reader.Buffered()))); err != nil { + t.Logf("Failed to flush buffered reads to server, host=%s: %v\n", req.Host, err) + return + } + + // Create duplex streaming between client and upstream connections. + done := make(chan struct{}, 2) + go func() { + _, _ = io.Copy(upstream, client) + done <- struct{}{} + }() + go func() { + _, _ = io.Copy(client, upstream) + done <- struct{}{} + }() + <-done + }) +} // Returns a new test HTTP server, as well as the URL to that server, or // an error if one occurred. numProxyCalls keeps track of the number of // times the proxy handler was called with this server. -func newProxyServer(tlsServer bool) (counter, *url.URL, error) { +func newProxyServer(t *testing.T, tlsServer bool) (counter, *url.URL, error) { + proxyHandler := newProxyHandler(t) // Start the proxy server, keeping track of how many times the handler is called. ts := &testServer{} proxyServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -752,6 +769,7 @@ func newProxyServer(tlsServer bool) (counter, *url.URL, error) { } else { proxyServer.Start() } + ts.server = proxyServer proxyURL, err := url.Parse(proxyServer.URL) if err != nil { return nil, nil, err