Skip to content

Commit c6921ab

Browse files
committed
rebase websockets proxy
1 parent ba328a3 commit c6921ab

File tree

1 file changed

+84
-9
lines changed

1 file changed

+84
-9
lines changed

pkg/cloud/websites/websites.go

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ import (
3131
"sync"
3232

3333
"github.com/asaskevich/EventBus"
34+
"github.com/gorilla/websocket"
3435

3536
"github.com/nitrictech/cli/pkg/netx"
37+
"github.com/nitrictech/cli/pkg/system"
3638
deploymentspb "github.com/nitrictech/nitric/core/pkg/proto/deployments/v1"
3739
)
3840

@@ -55,10 +57,11 @@ type (
5557
)
5658

5759
type LocalWebsiteService struct {
58-
websiteRegLock sync.RWMutex
59-
state State
60-
getApiAddress GetApiAddress
61-
isStartCmd bool
60+
websiteRegLock sync.RWMutex
61+
state State
62+
getApiAddress GetApiAddress
63+
getWebsocketAddress GetApiAddress
64+
isStartCmd bool
6265

6366
bus EventBus.Bus
6467
}
@@ -74,6 +77,22 @@ func (l *LocalWebsiteService) SubscribeToState(fn func(State)) {
7477
_ = l.bus.Subscribe(localWebsitesTopic, fn)
7578
}
7679

80+
func proxyWebSocketMessages(src, dst *websocket.Conn, errChan chan error) {
81+
for {
82+
messageType, message, err := src.ReadMessage()
83+
if err != nil {
84+
errChan <- err
85+
return
86+
}
87+
88+
err = dst.WriteMessage(messageType, message)
89+
if err != nil {
90+
errChan <- err
91+
return
92+
}
93+
}
94+
}
95+
7796
// register - Register a new website
7897
func (l *LocalWebsiteService) register(website Website, port int) {
7998
l.websiteRegLock.Lock()
@@ -182,6 +201,55 @@ func (h staticSiteHandler) ServeHTTP(res http.ResponseWriter, req *http.Request)
182201
h.serveStatic(res, req)
183202
}
184203

204+
// createWebsocketPathHandler creates a handler for WebSocket proxy requests
205+
func (l *LocalWebsiteService) createWebsocketPathHandler(w http.ResponseWriter, r *http.Request) {
206+
// Get the WebSocket API name from the request path
207+
apiName := r.PathValue("name")
208+
209+
// Get the address of the WebSocket API
210+
apiAddress := l.getWebsocketAddress(apiName)
211+
if apiAddress == "" {
212+
http.Error(w, fmt.Sprintf("WebSocket API %s not found", apiName), http.StatusNotFound)
213+
return
214+
}
215+
216+
// Dial the backend WebSocket server
217+
targetURL := fmt.Sprintf("ws://%s%s", apiAddress, r.URL.Path)
218+
if r.URL.RawQuery != "" {
219+
targetURL = fmt.Sprintf("%s?%s", targetURL, r.URL.RawQuery)
220+
}
221+
222+
targetConn, _, err := websocket.DefaultDialer.Dial(targetURL, nil)
223+
if err != nil {
224+
http.Error(w, fmt.Sprintf("Failed to connect to backend WebSocket server: %v", err), http.StatusInternalServerError)
225+
return
226+
}
227+
defer targetConn.Close()
228+
229+
// Upgrade the HTTP connection to a WebSocket connection
230+
upgrader := websocket.Upgrader{}
231+
232+
clientConn, err := upgrader.Upgrade(w, r, nil)
233+
if err != nil {
234+
http.Error(w, fmt.Sprintf("Failed to upgrade to WebSocket: %v", err), http.StatusInternalServerError)
235+
return
236+
}
237+
238+
defer clientConn.Close()
239+
240+
// Proxy messages between the client and the backend WebSocket server
241+
errChan := make(chan error, 2)
242+
go proxyWebSocketMessages(clientConn, targetConn, errChan)
243+
go proxyWebSocketMessages(targetConn, clientConn, errChan)
244+
245+
// Wait for an error to occur
246+
err = <-errChan
247+
if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
248+
// Because the error is already proxied through by the connection we can just log the error here
249+
system.Logf("received error on websocket %s: %v", apiName, err)
250+
}
251+
}
252+
185253
// createAPIPathHandler creates a handler for API proxy requests
186254
func (l *LocalWebsiteService) createAPIPathHandler() http.HandlerFunc {
187255
return func(res http.ResponseWriter, req *http.Request) {
@@ -252,6 +320,9 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
252320
// Register the API proxy handler for this website
253321
mux.HandleFunc("/api/{name}/", l.createAPIPathHandler())
254322

323+
// Register the WebSocket proxy handler for this website
324+
mux.HandleFunc("/ws/{name}", l.createWebsocketPathHandler)
325+
255326
// Create the SPA handler for this website
256327
spa := staticSiteHandler{
257328
website: website,
@@ -289,6 +360,9 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
289360
// Register the API proxy handler
290361
mux.HandleFunc("/api/{name}/", l.createAPIPathHandler())
291362

363+
// Register the WebSocket proxy handler for this website
364+
mux.HandleFunc("/ws/{name}", l.createWebsocketPathHandler)
365+
292366
// Register the SPA handler for each website
293367
for i := range websites {
294368
website := &websites[i]
@@ -325,11 +399,12 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
325399
return nil
326400
}
327401

328-
func NewLocalWebsitesService(getApiAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService {
402+
func NewLocalWebsitesService(getApiAddress GetApiAddress, getWebsocketAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService {
329403
return &LocalWebsiteService{
330-
state: State{},
331-
bus: EventBus.New(),
332-
getApiAddress: getApiAddress,
333-
isStartCmd: isStartCmd,
404+
state: State{},
405+
bus: EventBus.New(),
406+
getApiAddress: getApiAddress,
407+
getWebsocketAddress: getWebsocketAddress,
408+
isStartCmd: isStartCmd,
334409
}
335410
}

0 commit comments

Comments
 (0)