Skip to content

Commit e7eb8b8

Browse files
Add support for TLS WebSocket proxy
1 parent 45524e3 commit e7eb8b8

File tree

2 files changed

+166
-7
lines changed

2 files changed

+166
-7
lines changed

middleware/proxy.go

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package middleware
55

66
import (
77
"context"
8+
"crypto/tls"
89
"fmt"
910
"io"
1011
"math/rand"
@@ -130,7 +131,7 @@ var DefaultProxyConfig = ProxyConfig{
130131
ContextKey: "target",
131132
}
132133

133-
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
134+
func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
134135
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135136
in, _, err := c.Response().Hijack()
136137
if err != nil {
@@ -139,12 +140,33 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
139140
}
140141
defer in.Close()
141142

142-
out, err := net.Dial("tcp", t.URL.Host)
143-
if err != nil {
144-
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
145-
return
143+
var out net.Conn
144+
if c.IsTLS() {
145+
transport, ok := config.Transport.(*http.Transport)
146+
if !ok {
147+
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, "proxy raw, invalid transport type"))
148+
return
149+
}
150+
151+
if transport.TLSClientConfig == nil {
152+
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, "proxy raw, TLSClientConfig is not set"))
153+
return
154+
}
155+
156+
out, err = tls.Dial("tcp", t.URL.Host, transport.TLSClientConfig)
157+
if err != nil {
158+
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
159+
return
160+
}
161+
defer out.Close()
162+
} else {
163+
out, err = net.Dial("tcp", t.URL.Host)
164+
if err != nil {
165+
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
166+
return
167+
}
168+
defer out.Close()
146169
}
147-
defer out.Close()
148170

149171
// Write header
150172
err = r.Write(out)
@@ -365,7 +387,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
365387
// Proxy
366388
switch {
367389
case c.IsWebSocket():
368-
proxyRaw(tgt, c).ServeHTTP(res, req)
390+
proxyRaw(tgt, c, config).ServeHTTP(res, req)
369391
default: // even SSE requests
370392
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
371393
}

middleware/proxy_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package middleware
66
import (
77
"bytes"
88
"context"
9+
"crypto/tls"
910
"errors"
1011
"fmt"
1112
"io"
@@ -20,6 +21,7 @@ import (
2021

2122
"github.com/labstack/echo/v4"
2223
"github.com/stretchr/testify/assert"
24+
"golang.org/x/net/websocket"
2325
)
2426

2527
// Assert expected with url.EscapedPath method to obtain the path.
@@ -810,3 +812,138 @@ func TestModifyResponseUseContext(t *testing.T) {
810812
assert.Equal(t, "OK", rec.Body.String())
811813
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
812814
}
815+
816+
func TestProxyWithConfigWebSocketTCP(t *testing.T) {
817+
/*
818+
Arrange
819+
*/
820+
e := echo.New()
821+
822+
// Create a WebSocket test server
823+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
824+
wsHandler := func(conn *websocket.Conn) {
825+
defer conn.Close()
826+
for {
827+
var msg string
828+
err := websocket.Message.Receive(conn, &msg)
829+
if err != nil {
830+
return
831+
}
832+
// message back to the client
833+
websocket.Message.Send(conn, msg)
834+
}
835+
}
836+
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
837+
}))
838+
defer srv.Close()
839+
840+
tgtURL, _ := url.Parse(srv.URL)
841+
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
842+
843+
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
844+
845+
ts := httptest.NewServer(e)
846+
defer ts.Close()
847+
848+
tsURL, _ := url.Parse(ts.URL)
849+
tsURL.Scheme = "ws"
850+
tsURL.Path = "/"
851+
852+
/*
853+
Act
854+
*/
855+
856+
// Connect to the proxy WebSocket
857+
wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
858+
assert.NoError(t, err)
859+
defer wsConn.Close()
860+
861+
// Send message
862+
sendMsg := "Hello, WebSocket!"
863+
err = websocket.Message.Send(wsConn, sendMsg)
864+
assert.NoError(t, err)
865+
866+
/*
867+
Assert
868+
*/
869+
// Read response
870+
var recvMsg string
871+
err = websocket.Message.Receive(wsConn, &recvMsg)
872+
assert.NoError(t, err)
873+
assert.Equal(t, sendMsg, recvMsg)
874+
}
875+
876+
func TestProxyWithConfigWebSocketTLS(t *testing.T) {
877+
/*
878+
Arrange
879+
*/
880+
e := echo.New()
881+
882+
// Create a WebSocket test server
883+
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
884+
wsHandler := func(conn *websocket.Conn) {
885+
defer conn.Close()
886+
for {
887+
var msg string
888+
err := websocket.Message.Receive(conn, &msg)
889+
if err != nil {
890+
return
891+
}
892+
// message back to the client
893+
websocket.Message.Send(conn, msg)
894+
}
895+
}
896+
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
897+
}))
898+
defer srv.Close()
899+
900+
// create proxy server
901+
tgtURL, _ := url.Parse(srv.URL)
902+
tgtURL.Scheme = "wss"
903+
904+
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
905+
906+
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
907+
if !ok {
908+
t.Fatal("Default transport is not of type *http.Transport")
909+
}
910+
transport := defaultTransport.Clone()
911+
transport.TLSClientConfig = &tls.Config{
912+
InsecureSkipVerify: true,
913+
}
914+
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
915+
916+
// Start test server
917+
ts := httptest.NewTLSServer(e)
918+
defer ts.Close()
919+
920+
tsURL, _ := url.Parse(ts.URL)
921+
tsURL.Scheme = "wss"
922+
tsURL.Path = "/"
923+
924+
/*
925+
Act
926+
*/
927+
origin, err := url.Parse(ts.URL)
928+
assert.NoError(t, err)
929+
config := &websocket.Config{
930+
Location: tsURL,
931+
Origin: origin,
932+
TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
933+
Version: websocket.ProtocolVersionHybi13,
934+
}
935+
wsConn, err := websocket.DialConfig(config)
936+
assert.NoError(t, err)
937+
defer wsConn.Close()
938+
939+
// Send message
940+
sendMsg := "Hello, TLS WebSocket!"
941+
err = websocket.Message.Send(wsConn, sendMsg)
942+
assert.NoError(t, err)
943+
944+
// Read response
945+
var recvMsg string
946+
err = websocket.Message.Receive(wsConn, &recvMsg)
947+
assert.NoError(t, err)
948+
assert.Equal(t, sendMsg, recvMsg)
949+
}

0 commit comments

Comments
 (0)