@@ -6,6 +6,7 @@ package middleware
66import (
77 "bytes"
88 "context"
9+ "crypto/tls"
910 "errors"
1011 "fmt"
1112 "io"
@@ -20,6 +21,7 @@ import (
2021
2122 "github.com/labstack/echo/v4"
2223 "github.com/stretchr/testify/assert"
24+ "golang.org/x/net/websocket"
2325)
2426
2527// Assert expected with url.EscapedPath method to obtain the path.
@@ -810,3 +812,138 @@ func TestModifyResponseUseContext(t *testing.T) {
810812 assert .Equal (t , "OK" , rec .Body .String ())
811813 assert .Equal (t , "CUSTOM_BALANCER" , rec .Header ().Get ("FROM_BALANCER" ))
812814}
815+
816+ func TestProxyWithConfigWebSocketTCP (t * testing.T ) {
817+ /*
818+ Arrange
819+ */
820+ e := echo .New ()
821+
822+ // Create a WebSocket test server
823+ srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
824+ wsHandler := func (conn * websocket.Conn ) {
825+ defer conn .Close ()
826+ for {
827+ var msg string
828+ err := websocket .Message .Receive (conn , & msg )
829+ if err != nil {
830+ return
831+ }
832+ // message back to the client
833+ websocket .Message .Send (conn , msg )
834+ }
835+ }
836+ websocket.Server {Handler : wsHandler }.ServeHTTP (w , r )
837+ }))
838+ defer srv .Close ()
839+
840+ tgtURL , _ := url .Parse (srv .URL )
841+ balancer := NewRandomBalancer ([]* ProxyTarget {{URL : tgtURL }})
842+
843+ e .Use (ProxyWithConfig (ProxyConfig {Balancer : balancer }))
844+
845+ ts := httptest .NewServer (e )
846+ defer ts .Close ()
847+
848+ tsURL , _ := url .Parse (ts .URL )
849+ tsURL .Scheme = "ws"
850+ tsURL .Path = "/"
851+
852+ /*
853+ Act
854+ */
855+
856+ // Connect to the proxy WebSocket
857+ wsConn , err := websocket .Dial (tsURL .String (), "" , "http://localhost/" )
858+ assert .NoError (t , err )
859+ defer wsConn .Close ()
860+
861+ // Send message
862+ sendMsg := "Hello, WebSocket!"
863+ err = websocket .Message .Send (wsConn , sendMsg )
864+ assert .NoError (t , err )
865+
866+ /*
867+ Assert
868+ */
869+ // Read response
870+ var recvMsg string
871+ err = websocket .Message .Receive (wsConn , & recvMsg )
872+ assert .NoError (t , err )
873+ assert .Equal (t , sendMsg , recvMsg )
874+ }
875+
876+ func TestProxyWithConfigWebSocketTLS (t * testing.T ) {
877+ /*
878+ Arrange
879+ */
880+ e := echo .New ()
881+
882+ // Create a WebSocket test server
883+ srv := httptest .NewTLSServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
884+ wsHandler := func (conn * websocket.Conn ) {
885+ defer conn .Close ()
886+ for {
887+ var msg string
888+ err := websocket .Message .Receive (conn , & msg )
889+ if err != nil {
890+ return
891+ }
892+ // message back to the client
893+ websocket .Message .Send (conn , msg )
894+ }
895+ }
896+ websocket.Server {Handler : wsHandler }.ServeHTTP (w , r )
897+ }))
898+ defer srv .Close ()
899+
900+ // create proxy server
901+ tgtURL , _ := url .Parse (srv .URL )
902+ tgtURL .Scheme = "wss"
903+
904+ balancer := NewRandomBalancer ([]* ProxyTarget {{URL : tgtURL }})
905+
906+ defaultTransport , ok := http .DefaultTransport .(* http.Transport )
907+ if ! ok {
908+ t .Fatal ("Default transport is not of type *http.Transport" )
909+ }
910+ transport := defaultTransport .Clone ()
911+ transport .TLSClientConfig = & tls.Config {
912+ InsecureSkipVerify : true ,
913+ }
914+ e .Use (ProxyWithConfig (ProxyConfig {Balancer : balancer , Transport : transport }))
915+
916+ // Start test server
917+ ts := httptest .NewTLSServer (e )
918+ defer ts .Close ()
919+
920+ tsURL , _ := url .Parse (ts .URL )
921+ tsURL .Scheme = "wss"
922+ tsURL .Path = "/"
923+
924+ /*
925+ Act
926+ */
927+ origin , err := url .Parse (ts .URL )
928+ assert .NoError (t , err )
929+ config := & websocket.Config {
930+ Location : tsURL ,
931+ Origin : origin ,
932+ TlsConfig : & tls.Config {InsecureSkipVerify : true }, // skip verify for testing
933+ Version : websocket .ProtocolVersionHybi13 ,
934+ }
935+ wsConn , err := websocket .DialConfig (config )
936+ assert .NoError (t , err )
937+ defer wsConn .Close ()
938+
939+ // Send message
940+ sendMsg := "Hello, TLS WebSocket!"
941+ err = websocket .Message .Send (wsConn , sendMsg )
942+ assert .NoError (t , err )
943+
944+ // Read response
945+ var recvMsg string
946+ err = websocket .Message .Receive (wsConn , & recvMsg )
947+ assert .NoError (t , err )
948+ assert .Equal (t , sendMsg , recvMsg )
949+ }
0 commit comments