@@ -31,8 +31,10 @@ import (
3131 "sync"
3232
3333 "github.com/asaskevich/EventBus"
34+ "github.com/gorilla/websocket"
3435
3536 "github.com/nitrictech/cli/pkg/netx"
37+ "github.com/nitrictech/cli/pkg/system"
3638 deploymentspb "github.com/nitrictech/nitric/core/pkg/proto/deployments/v1"
3739)
3840
@@ -55,10 +57,11 @@ type (
5557)
5658
5759type LocalWebsiteService struct {
58- websiteRegLock sync.RWMutex
59- state State
60- getApiAddress GetApiAddress
61- isStartCmd bool
60+ websiteRegLock sync.RWMutex
61+ state State
62+ getApiAddress GetApiAddress
63+ getWebsocketAddress GetApiAddress
64+ isStartCmd bool
6265
6366 bus EventBus.Bus
6467}
@@ -74,6 +77,22 @@ func (l *LocalWebsiteService) SubscribeToState(fn func(State)) {
7477 _ = l .bus .Subscribe (localWebsitesTopic , fn )
7578}
7679
80+ func proxyWebSocketMessages (src , dst * websocket.Conn , errChan chan error ) {
81+ for {
82+ messageType , message , err := src .ReadMessage ()
83+ if err != nil {
84+ errChan <- err
85+ return
86+ }
87+
88+ err = dst .WriteMessage (messageType , message )
89+ if err != nil {
90+ errChan <- err
91+ return
92+ }
93+ }
94+ }
95+
7796// register - Register a new website
7897func (l * LocalWebsiteService ) register (website Website , port int ) {
7998 l .websiteRegLock .Lock ()
@@ -182,6 +201,55 @@ func (h staticSiteHandler) ServeHTTP(res http.ResponseWriter, req *http.Request)
182201 h .serveStatic (res , req )
183202}
184203
204+ // createWebsocketPathHandler creates a handler for WebSocket proxy requests
205+ func (l * LocalWebsiteService ) createWebsocketPathHandler (w http.ResponseWriter , r * http.Request ) {
206+ // Get the WebSocket API name from the request path
207+ apiName := r .PathValue ("name" )
208+
209+ // Get the address of the WebSocket API
210+ apiAddress := l .getWebsocketAddress (apiName )
211+ if apiAddress == "" {
212+ http .Error (w , fmt .Sprintf ("WebSocket API %s not found" , apiName ), http .StatusNotFound )
213+ return
214+ }
215+
216+ // Dial the backend WebSocket server
217+ targetURL := fmt .Sprintf ("ws://%s%s" , apiAddress , r .URL .Path )
218+ if r .URL .RawQuery != "" {
219+ targetURL = fmt .Sprintf ("%s?%s" , targetURL , r .URL .RawQuery )
220+ }
221+
222+ targetConn , _ , err := websocket .DefaultDialer .Dial (targetURL , nil )
223+ if err != nil {
224+ http .Error (w , fmt .Sprintf ("Failed to connect to backend WebSocket server: %v" , err ), http .StatusInternalServerError )
225+ return
226+ }
227+ defer targetConn .Close ()
228+
229+ // Upgrade the HTTP connection to a WebSocket connection
230+ upgrader := websocket.Upgrader {}
231+
232+ clientConn , err := upgrader .Upgrade (w , r , nil )
233+ if err != nil {
234+ http .Error (w , fmt .Sprintf ("Failed to upgrade to WebSocket: %v" , err ), http .StatusInternalServerError )
235+ return
236+ }
237+
238+ defer clientConn .Close ()
239+
240+ // Proxy messages between the client and the backend WebSocket server
241+ errChan := make (chan error , 2 )
242+ go proxyWebSocketMessages (clientConn , targetConn , errChan )
243+ go proxyWebSocketMessages (targetConn , clientConn , errChan )
244+
245+ // Wait for an error to occur
246+ err = <- errChan
247+ if err != nil && ! errors .Is (err , websocket .ErrCloseSent ) {
248+ // Because the error is already proxied through by the connection we can just log the error here
249+ system .Logf ("received error on websocket %s: %v" , apiName , err )
250+ }
251+ }
252+
185253// createAPIPathHandler creates a handler for API proxy requests
186254func (l * LocalWebsiteService ) createAPIPathHandler () http.HandlerFunc {
187255 return func (res http.ResponseWriter , req * http.Request ) {
@@ -252,6 +320,9 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
252320 // Register the API proxy handler for this website
253321 mux .HandleFunc ("/api/{name}/" , l .createAPIPathHandler ())
254322
323+ // Register the WebSocket proxy handler for this website
324+ mux .HandleFunc ("/ws/{name}" , l .createWebsocketPathHandler )
325+
255326 // Create the SPA handler for this website
256327 spa := staticSiteHandler {
257328 website : website ,
@@ -289,6 +360,9 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
289360 // Register the API proxy handler
290361 mux .HandleFunc ("/api/{name}/" , l .createAPIPathHandler ())
291362
363+ // Register the WebSocket proxy handler for this website
364+ mux .HandleFunc ("/ws/{name}" , l .createWebsocketPathHandler )
365+
292366 // Register the SPA handler for each website
293367 for i := range websites {
294368 website := & websites [i ]
@@ -325,11 +399,12 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
325399 return nil
326400}
327401
328- func NewLocalWebsitesService (getApiAddress GetApiAddress , isStartCmd bool ) * LocalWebsiteService {
402+ func NewLocalWebsitesService (getApiAddress GetApiAddress , getWebsocketAddress GetApiAddress , isStartCmd bool ) * LocalWebsiteService {
329403 return & LocalWebsiteService {
330- state : State {},
331- bus : EventBus .New (),
332- getApiAddress : getApiAddress ,
333- isStartCmd : isStartCmd ,
404+ state : State {},
405+ bus : EventBus .New (),
406+ getApiAddress : getApiAddress ,
407+ getWebsocketAddress : getWebsocketAddress ,
408+ isStartCmd : isStartCmd ,
334409 }
335410}
0 commit comments