@@ -56,6 +56,7 @@ import (
5656 fakeruntime "github.com/linuxsuren/go-fake-runtime"
5757 "github.com/linuxsuren/oauth-hub"
5858
59+ "github.com/gorilla/websocket"
5960 "github.com/prometheus/client_golang/prometheus"
6061 "github.com/prometheus/client_golang/prometheus/collectors"
6162 "github.com/prometheus/client_golang/prometheus/promhttp"
@@ -379,6 +380,7 @@ func (o *serverOption) runE(cmd *cobra.Command, args []string) (err error) {
379380 ctx = context .WithValue (ctx , k , v )
380381 }
381382
383+ endpoint := pathParams ["endpoint" ]
382384 resp , err := extServer .GetPageOfServer (ctx , & server.SimpleName {Name : pathParams ["extension" ]})
383385 if err != nil {
384386 fmt .Println (err )
@@ -389,8 +391,18 @@ func (o *serverOption) runE(cmd *cobra.Command, args []string) (err error) {
389391 return
390392 }
391393
392- fmt .Println ("redirect to" , resp .Message , "method" , r .Method )
393- req , err := http .NewRequestWithContext (ctx , r .Method , resp .Message , r .Body )
394+ api := resp .Message + "/" + endpoint
395+
396+ // Check if this is a WebSocket request
397+ if isWebSocketRequest (r ) {
398+ api = strings .ReplaceAll (api , "http://" , "ws://" )
399+ fmt .Println ("WebSocket request detected" , api )
400+ handleWebSocketProxy (w , r , api )
401+ return
402+ }
403+
404+ fmt .Println ("redirect to" , api , "method" , r .Method )
405+ req , err := http .NewRequestWithContext (ctx , r .Method , api , r .Body )
394406 if err != nil {
395407 fmt .Println (err )
396408 return
@@ -427,9 +439,10 @@ func (o *serverOption) runE(cmd *cobra.Command, args []string) (err error) {
427439 flusher .Flush ()
428440 }
429441 }
430- mux .HandlePath (http .MethodPost , "/extensionProxy/{extension}" , proxyHandler )
431- mux .HandlePath (http .MethodGet , "/extensionProxy/{extension}" , proxyHandler )
432- mux .HandlePath (http .MethodDelete , "/extensionProxy/{extension}" , proxyHandler )
442+ mux .HandlePath (http .MethodPost , "/extensionProxy/{extension}/{endpoint}" , proxyHandler )
443+ mux .HandlePath (http .MethodGet , "/extensionProxy/{extension}/{endpoint}" , proxyHandler )
444+ mux .HandlePath (http .MethodDelete , "/extensionProxy/{extension}/{endpoint}" , proxyHandler )
445+ mux .HandlePath (http .MethodPost , "/extensionProxy/{extension}/{endpoint}" , proxyHandler )
433446 mux .HandlePath (http .MethodGet , "/get" , o .getAtestBinary )
434447 mux .HandlePath (http .MethodPost , "/runner/{suite}/{case}" , service .WebRunnerHandler )
435448 mux .HandlePath (http .MethodGet , "/api/v1/sbom" , service .SBomHandler )
@@ -660,6 +673,84 @@ func (s *fakeGRPCServer) RegisterService(desc *grpc.ServiceDesc, impl interface{
660673 // Do nothing due to this is a fake method
661674}
662675
676+ var upgrader = websocket.Upgrader {
677+ CheckOrigin : func (r * http.Request ) bool {
678+ return true // Allow connections from any origin
679+ },
680+ }
681+
682+ func isWebSocketRequest (r * http.Request ) bool {
683+ return strings .ToLower (r .Header .Get ("Connection" )) == "upgrade" && strings .ToLower (r .Header .Get ("Upgrade" )) == "websocket"
684+ }
685+
686+ func handleWebSocketProxy (w http.ResponseWriter , r * http.Request , targetURL string ) {
687+ // Upgrade the connection
688+ clientConn , err := upgrader .Upgrade (w , r , nil )
689+ if err != nil {
690+ fmt .Println ("Failed to upgrade connection:" , err )
691+ return
692+ }
693+ defer clientConn .Close ()
694+
695+ // Create a WebSocket connection to the target server
696+ // Clone headers to avoid duplicate headers issue
697+ headers := make (http.Header )
698+ for k , v := range r .Header {
699+ // Skip headers that will be set by the WebSocket dialer
700+ if k == "Upgrade" || k == "Connection" || k == "Sec-Websocket-Key" ||
701+ k == "Sec-Websocket-Version" || k == "Sec-Websocket-Extensions" ||
702+ k == "Sec-Websocket-Protocol" {
703+ continue
704+ }
705+ headers [k ] = v
706+ }
707+
708+ targetConn , _ , err := websocket .DefaultDialer .Dial (targetURL , headers )
709+ if err != nil {
710+ fmt .Println ("Failed to connect to target:" , err )
711+ return
712+ }
713+ defer targetConn .Close ()
714+
715+ // Proxy messages between client and target
716+ errChan := make (chan error , 2 )
717+
718+ // Client to target
719+ go func () {
720+ for {
721+ messageType , message , err := clientConn .ReadMessage ()
722+ if err != nil {
723+ errChan <- err
724+ return
725+ }
726+
727+ if err := targetConn .WriteMessage (messageType , message ); err != nil {
728+ errChan <- err
729+ return
730+ }
731+ }
732+ }()
733+
734+ // Target to client
735+ go func () {
736+ for {
737+ messageType , message , err := targetConn .ReadMessage ()
738+ if err != nil {
739+ errChan <- err
740+ return
741+ }
742+
743+ if err := clientConn .WriteMessage (messageType , message ); err != nil {
744+ errChan <- err
745+ return
746+ }
747+ }
748+ }()
749+
750+ // Wait for an error to occur
751+ <- errChan
752+ }
753+
663754//go:embed data/index.js
664755var uiResourceJS []byte
665756
0 commit comments