@@ -2,14 +2,17 @@ package router
22
33import (
44 "context"
5+ "encoding/json"
6+ "net/http"
57 "time"
68
9+ "emperror.dev/errors"
710 "github.com/gin-gonic/gin"
8- "github.com/goccy/go-json"
911 ws "github.com/gorilla/websocket"
10-
1112 "github.com/pelican-dev/wings/router/middleware"
1213 "github.com/pelican-dev/wings/router/websocket"
14+ "github.com/pelican-dev/wings/server"
15+ "golang.org/x/time/rate"
1316)
1417
1518var expectedCloseCodes = []int {
@@ -25,6 +28,27 @@ func getServerWebsocket(c *gin.Context) {
2528 manager := middleware .ExtractManager (c )
2629 s , _ := manager .Get (c .Param ("server" ))
2730
31+ // Limit the total number of websockets that can be opened at any one time for
32+ // a server instance. This applies across all users connected to the server, and
33+ // is not applied on a per-user basis.
34+ //
35+ // todo: it would be great to make this per-user instead, but we need to modify
36+ // how we even request this endpoint in order for that to be possible. Some type
37+ // of signed identifier in the URL that is verified on this end and set by the
38+ // panel using a shared secret is likely the easiest option. The benefit of that
39+ // is that we can both scope things to the user before authentication, and also
40+ // verify that the JWT provided by the panel is assigned to the same user.
41+ if s .Websockets ().Len () >= 30 {
42+ c .AbortWithStatusJSON (http .StatusBadRequest , gin.H {
43+ "error" : "Too many open websocket connections." ,
44+ })
45+
46+ return
47+ }
48+
49+ c .Header ("Content-Security-Policy" , "default-src 'self'" )
50+ c .Header ("X-Frame-Options" , "DENY" )
51+
2852 // Create a context that can be canceled when the user disconnects from this
2953 // socket that will also cancel listeners running in separate threads. If the
3054 // connection itself is terminated listeners using this context will also be
@@ -37,53 +61,101 @@ func getServerWebsocket(c *gin.Context) {
3761 middleware .CaptureAndAbort (c , err )
3862 return
3963 }
40- defer handler .Connection .Close ()
4164
4265 // Track this open connection on the server so that we can close them all programmatically
4366 // if the server is deleted.
4467 s .Websockets ().Push (handler .Uuid (), & cancel )
4568 handler .Logger ().Debug ("opening connection to server websocket" )
69+ defer s .Websockets ().Remove (handler .Uuid ())
4670
47- defer func () {
48- s .Websockets ().Remove (handler .Uuid ())
49- handler .Logger ().Debug ("closing connection to server websocket" )
71+ go func () {
72+ select {
73+ // When the main context is canceled (through disconnect, server deletion, or server
74+ // suspension) close the connection itself.
75+ case <- ctx .Done ():
76+ handler .Logger ().Debug ("closing connection to server websocket" )
77+ if err := handler .Connection .Close (); err != nil {
78+ handler .Logger ().WithError (err ).Error ("failed to close websocket connection" )
79+ }
80+ break
81+ }
5082 }()
5183
52- // If the server is deleted we need to send a close message to the connected client
53- // so that they disconnect since there will be no more events sent along. Listen for
54- // the request context being closed to break this loop, otherwise this routine will
55- // be left hanging in the background.
5684 go func () {
5785 select {
5886 case <- ctx .Done ():
59- break
87+ return
88+ // If the server is deleted we need to send a close message to the connected client
89+ // so that they disconnect since there will be no more events sent along. Listen for
90+ // the request context being closed to break this loop, otherwise this routine will
91+ //be left hanging in the background.
6092 case <- s .Context ().Done ():
61- _ = handler . Connection . WriteControl ( ws . CloseMessage , ws . FormatCloseMessage ( ws . CloseGoingAway , "server deleted" ), time . Now (). Add ( time . Second * 5 ) )
93+ cancel ( )
6294 break
6395 }
6496 }()
6597
66- for {
67- j := websocket.Message {}
98+ // Due to how websockets are handled we need to connect to the socket
99+ // and _then_ abort it if the server is suspended. You cannot capture
100+ // the HTTP response in the websocket client, thus we connect and then
101+ // immediately close with failure.
102+ if s .IsSuspended () {
103+ _ = handler .Connection .WriteMessage (ws .CloseMessage , ws .FormatCloseMessage (4409 , "server is suspended" ))
68104
69- _ , p , err := handler .Connection .ReadMessage ()
105+ return
106+ }
107+
108+ // There is a separate rate limiter that applies to individual message types
109+ // within the actual websocket logic handler. _This_ rate limiter just exists
110+ // to avoid enormous floods of data through the socket since we need to parse
111+ // JSON each time. This rate limit realistically should never be hit since this
112+ // would require sending 50+ messages a second over the websocket (no more than
113+ // 10 per 200ms).
114+ var throttled bool
115+ rl := rate .NewLimiter (rate .Every (time .Millisecond * 200 ), 10 )
116+
117+ for {
118+ t , p , err := handler .Connection .ReadMessage ()
70119 if err != nil {
71120 if ws .IsUnexpectedCloseError (err , expectedCloseCodes ... ) {
72121 handler .Logger ().WithField ("error" , err ).Warn ("error handling websocket message for server" )
73122 }
74123 break
75124 }
76125
126+ if ! rl .Allow () {
127+ if ! throttled {
128+ throttled = true
129+ _ = handler .Connection .WriteJSON (websocket.Message {Event : websocket .ThrottledEvent , Args : []string {"global" }})
130+ }
131+ continue
132+ }
133+
134+ throttled = false
135+
136+ // If the message isn't a format we expect, or the length of the message is far larger
137+ // than we'd ever expect, drop it. The websocket upgrader logic does enforce a maximum
138+ // _compressed_ message size of 4Kb but that could decompress to a much larger amount
139+ // of data.
140+ if t != ws .TextMessage || len (p ) > 32_768 {
141+ continue
142+ }
143+
77144 // Discard and JSON parse errors into the void and don't continue processing this
78145 // specific socket request. If we did a break here the client would get disconnected
79146 // from the socket, which is NOT what we want to do.
147+ var j websocket.Message
80148 if err := json .Unmarshal (p , & j ); err != nil {
81149 continue
82150 }
83151
84152 go func (msg websocket.Message ) {
85153 if err := handler .HandleInbound (ctx , msg ); err != nil {
86- _ = handler .SendErrorJson (msg , err )
154+ if errors .Is (err , server .ErrSuspended ) {
155+ cancel ()
156+ } else {
157+ _ = handler .SendErrorJson (msg , err )
158+ }
87159 }
88160 }(j )
89161 }
0 commit comments