Skip to content

Commit 67b40fa

Browse files
committed
fix: merge upstream fixes
1 parent ccefaf4 commit 67b40fa

File tree

12 files changed

+250
-45
lines changed

12 files changed

+250
-45
lines changed

src/router/router_server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ func destroyServerBackupRepositories(ctx context.Context, s *server.Server, clie
323323
// Adds any of the JTIs passed through in the body to the deny list for the websocket
324324
// preventing any JWT generated before the current time from being used to connect to
325325
// the socket or send along commands.
326-
// @deprecated superceded by /api/revoke
326+
//
327+
// deprecated: prefer /api/deauthorize-user
327328
func postServerDenyWSTokens(c *gin.Context) {
328329
var data struct {
329330
JTIs []string `json:"jtis"`

src/router/router_server_ws.go

Lines changed: 89 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@ package router
22

33
import (
44
"context"
5+
"emperror.dev/errors"
56
"encoding/json"
7+
8+
"net/http"
69
"time"
710

811
"github.com/gin-gonic/gin"
912
ws "github.com/gorilla/websocket"
1013

1114
"github.com/pyrohost/elytra/src/router/middleware"
1215
"github.com/pyrohost/elytra/src/router/websocket"
16+
17+
"github.com/pyrohost/elytra/src/server"
18+
"golang.org/x/time/rate"
1319
)
1420

1521
var expectedCloseCodes = []int{
@@ -25,6 +31,27 @@ func getServerWebsocket(c *gin.Context) {
2531
manager := middleware.ExtractManager(c)
2632
s, _ := manager.Get(c.Param("server"))
2733

34+
// Limit the total number of websockets that can be opened at any one time for
35+
// a server instance. This applies across all users connected to the server, and
36+
// is not applied on a per-user basis.
37+
//
38+
// todo: it would be great to make this per-user instead, but we need to modify
39+
// how we even request this endpoint in order for that to be possible. Some type
40+
// of signed identifier in the URL that is verified on this end and set by the
41+
// panel using a shared secret is likely the easiest option. The benefit of that
42+
// is that we can both scope things to the user before authentication, and also
43+
// verify that the JWT provided by the panel is assigned to the same user.
44+
if s.Websockets().Len() >= 30 {
45+
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
46+
"error": "Too many open websocket connections.",
47+
})
48+
49+
return
50+
}
51+
52+
c.Header("Content-Security-Policy", "default-src 'self'")
53+
c.Header("X-Frame-Options", "DENY")
54+
2855
// Create a context that can be canceled when the user disconnects from this
2956
// socket that will also cancel listeners running in separate threads. If the
3057
// connection itself is terminated listeners using this context will also be
@@ -37,53 +64,101 @@ func getServerWebsocket(c *gin.Context) {
3764
middleware.CaptureAndAbort(c, err)
3865
return
3966
}
40-
defer handler.Connection.Close()
4167

4268
// Track this open connection on the server so that we can close them all programmatically
4369
// if the server is deleted.
4470
s.Websockets().Push(handler.Uuid(), &cancel)
4571
handler.Logger().Debug("opening connection to server websocket")
72+
defer s.Websockets().Remove(handler.Uuid())
4673

47-
defer func() {
48-
s.Websockets().Remove(handler.Uuid())
49-
handler.Logger().Debug("closing connection to server websocket")
74+
go func() {
75+
select {
76+
// When the main context is canceled (through disconnect, server deletion, or server
77+
// suspension) close the connection itself.
78+
case <-ctx.Done():
79+
handler.Logger().Debug("closing connection to server websocket")
80+
if err := handler.Connection.Close(); err != nil {
81+
handler.Logger().WithError(err).Error("failed to close websocket connection")
82+
}
83+
break
84+
}
5085
}()
5186

52-
// If the server is deleted we need to send a close message to the connected client
53-
// so that they disconnect since there will be no more events sent along. Listen for
54-
// the request context being closed to break this loop, otherwise this routine will
55-
// be left hanging in the background.
5687
go func() {
5788
select {
5889
case <-ctx.Done():
59-
break
90+
return
91+
// If the server is deleted we need to send a close message to the connected client
92+
// so that they disconnect since there will be no more events sent along. Listen for
93+
// the request context being closed to break this loop, otherwise this routine will
94+
// be left hanging in the background.
6095
case <-s.Context().Done():
61-
_ = handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5))
96+
cancel()
6297
break
6398
}
6499
}()
65100

66-
for {
67-
j := websocket.Message{}
101+
// Due to how websockets are handled we need to connect to the socket
102+
// and _then_ abort it if the server is suspended. You cannot capture
103+
// the HTTP response in the websocket client, thus we connect and then
104+
// immediately close with failure.
105+
if s.IsSuspended() {
106+
_ = handler.Connection.WriteMessage(ws.CloseMessage, ws.FormatCloseMessage(4409, "server is suspended"))
107+
108+
return
109+
}
68110

69-
_, p, err := handler.Connection.ReadMessage()
111+
// There is a separate rate limiter that applies to individual message types
112+
// within the actual websocket logic handler. _This_ rate limiter just exists
113+
// to avoid enormous floods of data through the socket since we need to parse
114+
// JSON each time. This rate limit realistically should never be hit since this
115+
// would require sending 50+ messages a second over the websocket (no more than
116+
// 10 per 200ms).
117+
var throttled bool
118+
rl := rate.NewLimiter(rate.Every(time.Millisecond*200), 10)
119+
120+
for {
121+
t, p, err := handler.Connection.ReadMessage()
70122
if err != nil {
71123
if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
72124
handler.Logger().WithField("error", err).Warn("error handling websocket message for server")
73125
}
74126
break
75127
}
76128

129+
if !rl.Allow() {
130+
if !throttled {
131+
throttled = true
132+
_ = handler.Connection.WriteJSON(websocket.Message{Event: websocket.ThrottledEvent, Args: []string{"global"}})
133+
}
134+
continue
135+
}
136+
137+
throttled = false
138+
139+
// If the message isn't a format we expect, or the length of the message is far larger
140+
// than we'd ever expect, drop it. The websocket upgrader logic does enforce a maximum
141+
// _compressed_ message size of 4Kb but that could decompress to a much larger amount
142+
// of data.
143+
if t != ws.TextMessage || len(p) > 32_768 {
144+
continue
145+
}
146+
77147
// Discard and JSON parse errors into the void and don't continue processing this
78148
// specific socket request. If we did a break here the client would get disconnected
79149
// from the socket, which is NOT what we want to do.
150+
var j websocket.Message
80151
if err := json.Unmarshal(p, &j); err != nil {
81152
continue
82153
}
83154

84155
go func(msg websocket.Message) {
85156
if err := handler.HandleInbound(ctx, msg); err != nil {
86-
_ = handler.SendErrorJson(msg, err)
157+
if errors.Is(err, server.ErrSuspended) {
158+
cancel()
159+
} else {
160+
_ = handler.SendErrorJson(msg, err)
161+
}
87162
}
88163
}(j)
89164
}

src/router/router_system.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/pyrohost/elytra/src/config"
1313
"github.com/pyrohost/elytra/src/router/middleware"
14+
"github.com/pyrohost/elytra/src/router/tokens"
1415
"github.com/pyrohost/elytra/src/server"
1516
"github.com/pyrohost/elytra/src/server/installer"
1617
"github.com/pyrohost/elytra/src/system"
@@ -164,20 +165,24 @@ func postDeauthorizeUser(c *gin.Context) {
164165
User string `json:"user"`
165166
Servers []string `json:"servers"`
166167
}
168+
167169
if err := c.BindJSON(&data); err != nil {
168170
return
169171
}
172+
170173
// todo: disconnect websockets more gracefully
171174
m := middleware.ExtractManager(c)
172175
if len(data.Servers) > 0 {
173176
for _, uuid := range data.Servers {
174177
if s, ok := m.Get(uuid); ok {
178+
tokens.DenyForServer(s.ID(), data.User)
175179
s.Websockets().CancelAll()
176180
s.Sftp().Cancel(data.User)
177181
}
178182
}
179183
} else {
180184
for _, s := range m.All() {
185+
tokens.DenyForServer(s.ID(), data.User)
181186
s.Websockets().CancelAll()
182187
s.Sftp().Cancel(data.User)
183188
}

src/router/tokens/websocket.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,20 @@ var elytraBootTime = time.Now()
2424
// This is used to allow the Panel to revoke tokens en-masse for a given user & server
2525
// combination since the JTI for tokens is just MD5(user.id + server.uuid). When a server
2626
// is booted this listing is fetched from the panel and the Websocket is dynamically updated.
27+
//
28+
// deprecated: prefer use of userDenylist
2729
var denylist sync.Map
2830

31+
var userDenylist sync.Map
32+
33+
// DenyForServer adds a user UUID to the denylist marking any existing JWTs issued
34+
// to the user as being invalid. This is associated with the user.
35+
func DenyForServer(s string, u string) {
36+
log.WithField("user_uuid", u).WithField("server_uuid", s).Debugf("denying all JWTs created at or before current time for user \"%s\"", u)
37+
38+
userDenylist.Store(strings.Join([]string{s, u}, ":"), time.Now())
39+
}
40+
2941
// Adds a JTI to the denylist by marking any JWTs generated before the current time as
3042
// being invalid if they use the same JTI.
3143
func DenyJTI(jti string) {
@@ -62,7 +74,7 @@ func (p *WebsocketPayload) GetServerUuid() string {
6274
}
6375

6476
// Check if the JWT has been marked as denied by the instance due to either being issued
65-
// before Elytra was booted, or because we have denied all tokens with the same JTI
77+
// before Wings was booted, or because we have denied all tokens with the same JTI
6678
// occurring before a set time.
6779
func (p *WebsocketPayload) Denylisted() bool {
6880
// If there is no IssuedAt present for the token, we cannot validate the token so
@@ -71,20 +83,29 @@ func (p *WebsocketPayload) Denylisted() bool {
7183
return true
7284
}
7385

74-
// If the time that the token was issued is before the time at which Elytra was booted
86+
// If the time that the token was issued is before the time at which Wings was booted
7587
// then the token is invalid for our purposes, even if the token "has permission".
7688
if p.IssuedAt.Time.Before(elytraBootTime) {
7789
return true
7890
}
7991

8092
// Finally, if the token was issued before a time that is currently denied for this
8193
// token instance, ignore the permissions response.
94+
//
95+
// This list is deprecated, but we maintain the check here so that custom instances
96+
// are able to continue working. We'll remove it in a future release.
8297
if t, ok := denylist.Load(p.JWTID); ok {
8398
if p.IssuedAt.Time.Before(t.(time.Time)) {
8499
return true
85100
}
86101
}
87102

103+
if t, ok := userDenylist.Load(strings.Join([]string{p.ServerUUID, p.UserUUID}, ":")); ok {
104+
if p.IssuedAt.Time.Before(t.(time.Time)) {
105+
return true
106+
}
107+
}
108+
88109
return false
89110
}
90111

src/router/websocket/limiter.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package websocket
2+
3+
import (
4+
"sync"
5+
"time"
6+
7+
"golang.org/x/time/rate"
8+
)
9+
10+
type LimiterBucket struct {
11+
mu sync.RWMutex
12+
limits map[Event]*rate.Limiter
13+
throttles map[Event]bool
14+
}
15+
16+
func (h *Handler) IsThrottled(e Event) bool {
17+
l := h.limiter.For(e)
18+
19+
h.limiter.mu.Lock()
20+
defer h.limiter.mu.Unlock()
21+
22+
if l.Allow() {
23+
h.limiter.throttles[e] = false
24+
25+
return false
26+
}
27+
28+
// If not allowed, track the throttling and send an event over the wire
29+
// if one wasn't already sent in the same throttling period.
30+
if v, ok := h.limiter.throttles[e]; !v || !ok {
31+
h.limiter.throttles[e] = true
32+
h.Logger().WithField("event", e).Debug("throttling websocket due to event volume")
33+
34+
_ = h.unsafeSendJson(&Message{Event: ThrottledEvent, Args: []string{string(e)}})
35+
}
36+
37+
return true
38+
}
39+
40+
func NewLimiter() *LimiterBucket {
41+
return &LimiterBucket{
42+
limits: make(map[Event]*rate.Limiter, 4),
43+
throttles: make(map[Event]bool, 4),
44+
}
45+
}
46+
47+
// For returns the internal rate limiter for the given event type. In most
48+
// cases this is a shared rate limiter for events, but certain "heavy" or low-frequency
49+
// events implement their own limiters.
50+
func (l *LimiterBucket) For(e Event) *rate.Limiter {
51+
name := limiterName(e)
52+
53+
l.mu.RLock()
54+
if v, ok := l.limits[name]; ok {
55+
l.mu.RUnlock()
56+
return v
57+
}
58+
59+
l.mu.RUnlock()
60+
l.mu.Lock()
61+
defer l.mu.Unlock()
62+
63+
limit, burst := limitValuesFor(e)
64+
l.limits[name] = rate.NewLimiter(limit, burst)
65+
66+
return l.limits[name]
67+
}
68+
69+
// limitValuesFor returns the underlying limit and burst value for the given event.
70+
func limitValuesFor(e Event) (rate.Limit, int) {
71+
// Twice every five seconds.
72+
if e == AuthenticationEvent || e == SendServerLogsEvent {
73+
return rate.Every(time.Second * 5), 2
74+
}
75+
76+
// 10 per second.
77+
if e == SendCommandEvent {
78+
return rate.Every(time.Second), 10
79+
}
80+
81+
// 4 per second.
82+
return rate.Every(time.Second), 4
83+
}
84+
85+
func limiterName(e Event) Event {
86+
if e == AuthenticationEvent || e == SendServerLogsEvent || e == SendCommandEvent {
87+
return e
88+
}
89+
90+
return "_default"
91+
}

src/router/websocket/listeners.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ func (h *Handler) listenForServerEvents(ctx context.Context) error {
132132
continue
133133
}
134134
var sendErr error
135-
message := Message{Event: e.Topic}
135+
message := Message{Event: Event(e.Topic)}
136136
if str, ok := e.Data.(string); ok {
137137
message.Args = []string{str}
138138
} else if b, ok := e.Data.([]byte); ok {
@@ -150,7 +150,7 @@ func (h *Handler) listenForServerEvents(ctx context.Context) error {
150150
continue
151151
}
152152
}
153-
onError(message.Event, sendErr)
153+
onError(string(message.Event), sendErr)
154154
}
155155
break
156156
}

0 commit comments

Comments
 (0)