Skip to content

Commit 55f1ab3

Browse files
committed
Add local websocket proxy.
Add local websocket proxy fmt update websocket error logging also stop sending http.Error on upgraded connection. add rewrite ws connection to arch diagram format
1 parent 8811095 commit 55f1ab3

File tree

4 files changed

+117
-18
lines changed

4 files changed

+117
-18
lines changed

pkg/cloud/cloud.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ func New(projectName string, opts LocalCloudOptions) (*LocalCloud, error) {
321321
return nil, err
322322
}
323323

324-
localWebsites := websites.NewLocalWebsitesService(localGateway.GetApiAddress, opts.LocalCloudMode == StartMode)
324+
localWebsites := websites.NewLocalWebsitesService(localGateway.GetApiAddress, localGateway.GetWebsocketAddress, opts.LocalCloudMode == StartMode)
325325

326326
return &LocalCloud{
327327
servers: make(map[string]*server.NitricServer),

pkg/cloud/gateway/gateway.go

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ import (
4646
"github.com/nitrictech/cli/pkg/netx"
4747
"github.com/nitrictech/cli/pkg/project/localconfig"
4848
"github.com/nitrictech/cli/pkg/system"
49-
"github.com/nitrictech/cli/pkg/view/tui"
5049

5150
base_http "github.com/nitrictech/nitric/cloud/common/runtime/gateway"
5251

@@ -157,6 +156,19 @@ func (s *LocalGatewayService) GetApiAddress(apiName string) string {
157156
return ""
158157
}
159158

159+
func (s *LocalGatewayService) GetWebsocketAddress(socketName string) string {
160+
s.lock.RLock()
161+
defer s.lock.RUnlock()
162+
163+
addresses := s.GetWebsocketAddresses()
164+
165+
if address, ok := addresses[socketName]; ok {
166+
return address
167+
}
168+
169+
return ""
170+
}
171+
160172
func (s *LocalGatewayService) GetHttpWorkerAddresses() map[string]string {
161173
s.lock.RLock()
162174
defer s.lock.RUnlock()
@@ -349,14 +361,14 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
349361
SocketName: socketName,
350362
})
351363
if err != nil {
352-
tui.Error.Println(err.Error())
364+
system.Logf("Websocket error: %s", err.Error())
353365
return
354366
}
355367
}()
356368

357369
err = s.websocketPlugin.RegisterConnection(socketName, connectionId, ws)
358370
if err != nil {
359-
tui.Error.Println(err.Error())
371+
system.Logf("Websocket error: %s", err.Error())
360372
return
361373
}
362374

@@ -372,7 +384,7 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
372384
if err != nil && websocket.IsCloseError(err, 1001, 1005) {
373385
break
374386
} else if err != nil {
375-
log.Println("read:", err)
387+
system.Logf("websocket read error: %v", err)
376388
break
377389
}
378390

@@ -390,7 +402,7 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
390402
},
391403
})
392404
if err != nil {
393-
tui.Error.Println(err.Error())
405+
system.Logf("Websocket error: %s", err.Error())
394406
return
395407
}
396408
}
@@ -407,13 +419,13 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
407419
},
408420
})
409421
if err != nil {
410-
tui.Error.Println(err.Error())
422+
system.Logf("Websocket error: %s", err.Error())
411423
return
412424
}
413425
})
414426
if err != nil {
415427
if _, ok := err.(websocket.HandshakeError); ok {
416-
tui.Error.Println(err.Error())
428+
system.Logf("Websocket error: %s", err.Error())
417429
}
418430

419431
return

pkg/cloud/websites/websites.go

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ import (
3030
"sync"
3131

3232
"github.com/asaskevich/EventBus"
33+
"github.com/gorilla/websocket"
3334

3435
"github.com/nitrictech/cli/pkg/netx"
36+
"github.com/nitrictech/cli/pkg/system"
3537
deploymentspb "github.com/nitrictech/nitric/core/pkg/proto/deployments/v1"
3638
)
3739

@@ -54,11 +56,12 @@ type (
5456
)
5557

5658
type LocalWebsiteService struct {
57-
websiteRegLock sync.RWMutex
58-
state State
59-
port int
60-
getApiAddress GetApiAddress
61-
isStartCmd bool
59+
websiteRegLock sync.RWMutex
60+
state State
61+
port int
62+
getApiAddress GetApiAddress
63+
getWebsocketAddress GetApiAddress
64+
isStartCmd bool
6265

6366
bus EventBus.Bus
6467
}
@@ -172,6 +175,23 @@ func (h staticSiteHandler) ServeHTTP(res http.ResponseWriter, req *http.Request)
172175
}
173176

174177
// Start - Start the local website service
178+
func proxyWebSocketMessages(src, dst *websocket.Conn, errChan chan error) {
179+
for {
180+
messageType, message, err := src.ReadMessage()
181+
if err != nil {
182+
errChan <- err
183+
return
184+
}
185+
186+
err = dst.WriteMessage(messageType, message)
187+
if err != nil {
188+
errChan <- err
189+
return
190+
}
191+
}
192+
}
193+
194+
// Serve - Serve a website from the local filesystem
175195
func (l *LocalWebsiteService) Start(websites []Website) error {
176196
newLis, err := netx.GetNextListener(netx.MinPort(5000))
177197
if err != nil {
@@ -203,6 +223,55 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
203223
proxy.ServeHTTP(res, req)
204224
})
205225

226+
// Register the API handler
227+
mux.HandleFunc("/ws/{name}", func(w http.ResponseWriter, r *http.Request) {
228+
// Get the WebSocket API name from the request path
229+
apiName := r.PathValue("name")
230+
231+
// Get the address of the WebSocket API
232+
apiAddress := l.getWebsocketAddress(apiName)
233+
if apiAddress == "" {
234+
http.Error(w, fmt.Sprintf("WebSocket API %s not found", apiName), http.StatusNotFound)
235+
return
236+
}
237+
238+
// Dial the backend WebSocket server
239+
targetURL := fmt.Sprintf("ws://%s%s", apiAddress, r.URL.Path)
240+
if r.URL.RawQuery != "" {
241+
targetURL = fmt.Sprintf("%s?%s", targetURL, r.URL.RawQuery)
242+
}
243+
244+
targetConn, _, err := websocket.DefaultDialer.Dial(targetURL, nil)
245+
if err != nil {
246+
http.Error(w, fmt.Sprintf("Failed to connect to backend WebSocket server: %v", err), http.StatusInternalServerError)
247+
return
248+
}
249+
defer targetConn.Close()
250+
251+
// Upgrade the HTTP connection to a WebSocket connection
252+
upgrader := websocket.Upgrader{}
253+
254+
clientConn, err := upgrader.Upgrade(w, r, nil)
255+
if err != nil {
256+
http.Error(w, fmt.Sprintf("Failed to upgrade to WebSocket: %v", err), http.StatusInternalServerError)
257+
return
258+
}
259+
260+
defer clientConn.Close()
261+
262+
// Proxy messages between the client and the backend WebSocket server
263+
errChan := make(chan error, 2)
264+
go proxyWebSocketMessages(clientConn, targetConn, errChan)
265+
go proxyWebSocketMessages(targetConn, clientConn, errChan)
266+
267+
// Wait for an error to occur
268+
err = <-errChan
269+
if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
270+
// Because the error is already proxied through by the connection we can just log the error here
271+
system.Logf("received error on websocket %s: %v", apiName, err)
272+
}
273+
})
274+
206275
// Register the SPA handler for each website
207276
for i := range websites {
208277
website := &websites[i]
@@ -231,11 +300,12 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
231300
return nil
232301
}
233302

234-
func NewLocalWebsitesService(getApiAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService {
303+
func NewLocalWebsitesService(getApiAddress GetApiAddress, getWebsocketAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService {
235304
return &LocalWebsiteService{
236-
state: State{},
237-
bus: EventBus.New(),
238-
getApiAddress: getApiAddress,
239-
isStartCmd: isStartCmd,
305+
state: State{},
306+
bus: EventBus.New(),
307+
getApiAddress: getApiAddress,
308+
getWebsocketAddress: getWebsocketAddress,
309+
isStartCmd: isStartCmd,
240310
}
241311
}

pkg/dashboard/frontend/src/lib/utils/generate-architecture-data.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,23 @@ export function generateArchitectureData(data: WebSocketResponse): {
609609
label: `Rewrites to /api/${api.name}`,
610610
})
611611
})
612+
613+
data.websockets.forEach((websocket) => {
614+
edges.push({
615+
id: `e-${websocket.name}-websites`,
616+
source: websitesNode.id,
617+
target: `websocket-${websocket.name}`,
618+
animated: true,
619+
markerEnd: {
620+
type: MarkerType.ArrowClosed,
621+
},
622+
markerStart: {
623+
type: MarkerType.ArrowClosed,
624+
orient: 'auto-start-reverse',
625+
},
626+
label: `Rewrites to /ws/${websocket.name}`,
627+
})
628+
})
612629
}
613630

614631
data.services.forEach((service) => {

0 commit comments

Comments
 (0)