@@ -30,8 +30,10 @@ import (
3030 "sync"
3131
3232 "github.com/asaskevich/EventBus"
33+ "github.com/gorilla/websocket"
3334
3435 "github.com/nitrictech/cli/pkg/netx"
36+ "github.com/nitrictech/cli/pkg/system"
3537 deploymentspb "github.com/nitrictech/nitric/core/pkg/proto/deployments/v1"
3638)
3739
@@ -54,11 +56,12 @@ type (
5456)
5557
5658type LocalWebsiteService struct {
57- websiteRegLock sync.RWMutex
58- state State
59- port int
60- getApiAddress GetApiAddress
61- isStartCmd bool
59+ websiteRegLock sync.RWMutex
60+ state State
61+ port int
62+ getApiAddress GetApiAddress
63+ getWebsocketAddress GetApiAddress
64+ isStartCmd bool
6265
6366 bus EventBus.Bus
6467}
@@ -172,6 +175,23 @@ func (h staticSiteHandler) ServeHTTP(res http.ResponseWriter, req *http.Request)
172175}
173176
174177// Start - Start the local website service
178+ func proxyWebSocketMessages (src , dst * websocket.Conn , errChan chan error ) {
179+ for {
180+ messageType , message , err := src .ReadMessage ()
181+ if err != nil {
182+ errChan <- err
183+ return
184+ }
185+
186+ err = dst .WriteMessage (messageType , message )
187+ if err != nil {
188+ errChan <- err
189+ return
190+ }
191+ }
192+ }
193+
194+ // Serve - Serve a website from the local filesystem
175195func (l * LocalWebsiteService ) Start (websites []Website ) error {
176196 newLis , err := netx .GetNextListener (netx .MinPort (5000 ))
177197 if err != nil {
@@ -203,6 +223,55 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
203223 proxy .ServeHTTP (res , req )
204224 })
205225
226+ // Register the API handler
227+ mux .HandleFunc ("/ws/{name}" , func (w http.ResponseWriter , r * http.Request ) {
228+ // Get the WebSocket API name from the request path
229+ apiName := r .PathValue ("name" )
230+
231+ // Get the address of the WebSocket API
232+ apiAddress := l .getWebsocketAddress (apiName )
233+ if apiAddress == "" {
234+ http .Error (w , fmt .Sprintf ("WebSocket API %s not found" , apiName ), http .StatusNotFound )
235+ return
236+ }
237+
238+ // Dial the backend WebSocket server
239+ targetURL := fmt .Sprintf ("ws://%s%s" , apiAddress , r .URL .Path )
240+ if r .URL .RawQuery != "" {
241+ targetURL = fmt .Sprintf ("%s?%s" , targetURL , r .URL .RawQuery )
242+ }
243+
244+ targetConn , _ , err := websocket .DefaultDialer .Dial (targetURL , nil )
245+ if err != nil {
246+ http .Error (w , fmt .Sprintf ("Failed to connect to backend WebSocket server: %v" , err ), http .StatusInternalServerError )
247+ return
248+ }
249+ defer targetConn .Close ()
250+
251+ // Upgrade the HTTP connection to a WebSocket connection
252+ upgrader := websocket.Upgrader {}
253+
254+ clientConn , err := upgrader .Upgrade (w , r , nil )
255+ if err != nil {
256+ http .Error (w , fmt .Sprintf ("Failed to upgrade to WebSocket: %v" , err ), http .StatusInternalServerError )
257+ return
258+ }
259+
260+ defer clientConn .Close ()
261+
262+ // Proxy messages between the client and the backend WebSocket server
263+ errChan := make (chan error , 2 )
264+ go proxyWebSocketMessages (clientConn , targetConn , errChan )
265+ go proxyWebSocketMessages (targetConn , clientConn , errChan )
266+
267+ // Wait for an error to occur
268+ err = <- errChan
269+ if err != nil && ! errors .Is (err , websocket .ErrCloseSent ) {
270+ // Because the error is already proxied through by the connection we can just log the error here
271+ system .Logf ("received error on websocket %s: %v" , apiName , err )
272+ }
273+ })
274+
206275 // Register the SPA handler for each website
207276 for i := range websites {
208277 website := & websites [i ]
@@ -231,11 +300,12 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
231300 return nil
232301}
233302
234- func NewLocalWebsitesService (getApiAddress GetApiAddress , isStartCmd bool ) * LocalWebsiteService {
303+ func NewLocalWebsitesService (getApiAddress GetApiAddress , getWebsocketAddress GetApiAddress , isStartCmd bool ) * LocalWebsiteService {
235304 return & LocalWebsiteService {
236- state : State {},
237- bus : EventBus .New (),
238- getApiAddress : getApiAddress ,
239- isStartCmd : isStartCmd ,
305+ state : State {},
306+ bus : EventBus .New (),
307+ getApiAddress : getApiAddress ,
308+ getWebsocketAddress : getWebsocketAddress ,
309+ isStartCmd : isStartCmd ,
240310 }
241311}
0 commit comments