Skip to content

Commit ce5d2b8

Browse files
authored
feat: add transport compression (#476)
1 parent d5695ad commit ce5d2b8

File tree

7 files changed

+355
-54
lines changed

7 files changed

+355
-54
lines changed

_examples/proxy/example.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ func main() {
5151
client, err := disgo.New(token,
5252
bot.WithShardManagerConfigOpts(
5353
sharding.WithGatewayConfigOpts( // gateway intents are set in the proxy not here
54-
gateway.WithURL(gatewayURL), // set the custom gateway url
55-
gateway.WithCompress(false), // we don't want compression as that would be additional overhead
54+
gateway.WithURL(gatewayURL), // set the custom gateway url
55+
gateway.WithCompression(gateway.CompressionNone), // we don't want compression as that would be additional overhead
5656
),
5757
sharding.WithIdentifyRateLimiter(gateway.NewNoopIdentifyRateLimiter()), // disable sharding rate limiter as the proxy handles it
5858
),

_examples/sharding/example.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func main() {
3030
sharding.WithAutoScaling(true),
3131
sharding.WithGatewayConfigOpts(
3232
gateway.WithIntents(gateway.IntentGuilds, gateway.IntentGuildMessages, gateway.IntentDirectMessages),
33-
gateway.WithCompress(true),
33+
gateway.WithCompression(gateway.CompressionZstdStream),
3434
),
3535
),
3636
bot.WithEventListeners(&events.ListenerAdapter{

gateway/gateway.go

Lines changed: 26 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@ package gateway
22

33
import (
44
"bytes"
5-
"compress/zlib"
65
"context"
76
"errors"
87
"fmt"
98
"io"
109
"log/slog"
1110
"math/rand/v2"
1211
"net"
12+
"net/url"
13+
"strconv"
1314
"sync"
1415
"syscall"
1516
"time"
1617

17-
"github.com/disgoorg/json/v2"
1818
"github.com/gorilla/websocket"
1919

2020
"github.com/disgoorg/disgo/discord"
@@ -171,7 +171,7 @@ type gatewayImpl struct {
171171
closeHandlerFunc CloseHandlerFunc
172172
token string
173173

174-
conn *websocket.Conn
174+
conn transport
175175
connMu sync.Mutex
176176
heartbeatCancel context.CancelFunc
177177
status Status
@@ -211,7 +211,7 @@ func (g *gatewayImpl) Open(ctx context.Context) error {
211211
}
212212

213213
func (g *gatewayImpl) open(ctx context.Context) error {
214-
g.config.Logger.DebugContext(ctx, "opening gateway connection")
214+
g.config.Logger.DebugContext(ctx, "opening gateway connection", slog.String("compression", g.config.Compression.String()))
215215

216216
g.connMu.Lock()
217217
if g.conn != nil {
@@ -235,7 +235,17 @@ func (g *gatewayImpl) open(ctx context.Context) error {
235235
if g.config.ResumeURL != nil && g.config.EnableResumeURL {
236236
wsURL = *g.config.ResumeURL
237237
}
238-
gatewayURL := fmt.Sprintf("%s?v=%d&encoding=json", wsURL, Version)
238+
239+
values := url.Values{}
240+
values.Set("v", strconv.Itoa(Version))
241+
values.Set("encoding", "json")
242+
243+
if g.config.Compression.IsStreamCompression() {
244+
values.Set("compress", string(g.config.Compression))
245+
}
246+
247+
gatewayURL := wsURL + "?" + values.Encode()
248+
239249
g.lastHeartbeatSent = time.Now().UTC()
240250
conn, rs, err := g.config.Dialer.DialContext(ctx, gatewayURL, nil)
241251
if err != nil {
@@ -263,7 +273,8 @@ func (g *gatewayImpl) open(ctx context.Context) error {
263273
return nil
264274
})
265275

266-
g.conn = conn
276+
t := newTransport(g.config.Compression, conn, g.config.Logger)
277+
g.conn = t
267278
g.connMu.Unlock()
268279

269280
// reset rate limiter when connecting
@@ -275,7 +286,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
275286

276287
var readyOnce sync.Once
277288
readyChan := make(chan error)
278-
go g.listen(conn, func(err error) {
289+
go g.listen(t, func(err error) {
279290
readyOnce.Do(func() {
280291
readyChan <- err
281292
close(readyChan)
@@ -315,7 +326,7 @@ func (g *gatewayImpl) CloseWithCode(ctx context.Context, code int, message strin
315326
if g.conn != nil {
316327
g.config.RateLimiter.Close(ctx)
317328
g.config.Logger.DebugContext(ctx, "closing gateway connection", slog.Int("code", code), slog.String("message", message))
318-
if err := g.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, message)); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
329+
if err := g.conn.WriteClose(code, message); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
319330
g.config.Logger.DebugContext(ctx, "error writing close code", slog.Any("err", err))
320331
}
321332
_ = g.conn.Close()
@@ -350,17 +361,11 @@ func (g *gatewayImpl) Send(ctx context.Context, op Opcode, d MessageData) error
350361
}
351362

352363
func (g *gatewayImpl) sendInternal(ctx context.Context, op Opcode, d MessageData) error {
353-
data, err := json.Marshal(Message{
364+
data := Message{
354365
Op: op,
355366
D: d,
356-
})
357-
if err != nil {
358-
return err
359367
}
360-
return g.send(ctx, websocket.TextMessage, data)
361-
}
362368

363-
func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) error {
364369
g.connMu.Lock()
365370
defer g.connMu.Unlock()
366371
if g.conn == nil {
@@ -372,8 +377,7 @@ func (g *gatewayImpl) send(ctx context.Context, messageType int, data []byte) er
372377
}
373378

374379
defer g.config.RateLimiter.Unlock()
375-
g.config.Logger.DebugContext(ctx, "sending gateway command", slog.String("data", string(data)))
376-
return g.conn.WriteMessage(messageType, data)
380+
return g.conn.WriteMessage(data)
377381
}
378382

379383
func (g *gatewayImpl) Latency() time.Duration {
@@ -516,7 +520,7 @@ func (g *gatewayImpl) identify() error {
516520
Browser: g.config.Browser,
517521
Device: g.config.Device,
518522
},
519-
Compress: g.config.Compress,
523+
Compress: g.config.Compression.IsPayloadCompression(),
520524
LargeThreshold: g.config.LargeThreshold,
521525
Intents: g.config.Intents,
522526
Presence: g.config.Presence,
@@ -554,14 +558,14 @@ func (g *gatewayImpl) resume() error {
554558
return nil
555559
}
556560

557-
func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
561+
func (g *gatewayImpl) listen(conn transport, ready func(error)) {
558562
defer g.config.Logger.Debug("exiting listen goroutine")
559563

560564
// Ensure that we never leave this function without calling ready
561565
defer ready(nil)
562566

563567
for {
564-
mt, r, err := conn.NextReader()
568+
message, err := conn.ReceiveMessage()
565569
if err != nil {
566570
g.statusMu.Lock()
567571
if g.status != StatusReady {
@@ -620,10 +624,8 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
620624

621625
return
622626
}
623-
624-
message, err := g.parseMessage(mt, r)
625-
if err != nil {
626-
g.config.Logger.Error("error while parsing gateway message", slog.Any("err", err))
627+
if message == nil {
628+
// No message (probably parsing error), just continue as the transport already logged it
627629
continue
628630
}
629631

@@ -746,30 +748,3 @@ func (g *gatewayImpl) listen(conn *websocket.Conn, ready func(error)) {
746748
}
747749
}
748750
}
749-
750-
func (g *gatewayImpl) parseMessage(mt int, r io.Reader) (Message, error) {
751-
if mt == websocket.BinaryMessage {
752-
g.config.Logger.Debug("binary message received. decompressing")
753-
754-
reader, err := zlib.NewReader(r)
755-
if err != nil {
756-
return Message{}, fmt.Errorf("failed to decompress zlib: %w", err)
757-
}
758-
defer reader.Close()
759-
r = reader
760-
}
761-
762-
if g.config.Logger.Enabled(context.Background(), slog.LevelDebug) {
763-
buff := new(bytes.Buffer)
764-
tr := io.TeeReader(r, buff)
765-
data, err := io.ReadAll(tr)
766-
if err != nil {
767-
return Message{}, fmt.Errorf("failed to read message: %w", err)
768-
}
769-
g.config.Logger.Debug("received gateway message", slog.String("data", string(data)))
770-
r = buff
771-
}
772-
773-
var message Message
774-
return message, json.NewDecoder(r).Decode(&message)
775-
}

gateway/gateway_config.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ func defaultConfig() config {
1313
LargeThreshold: 50,
1414
Intents: IntentsDefault,
1515
Compress: true,
16+
Compression: CompressionZstdStream,
1617
URL: "wss://gateway.discord.gg",
1718
ShardID: 0,
1819
ShardCount: 1,
@@ -33,7 +34,10 @@ type config struct {
3334
// Intents is the Intents for the Gateway. Defaults to IntentsNone.
3435
Intents Intents
3536
// Compress is whether the Gateway should compress payloads. Defaults to true.
37+
// Deprecated: Use Compression instead
3638
Compress bool
39+
// Compression is the compression type to use for the gateway. Defaults to ZstdCompression.
40+
Compression CompressionType
3741
// URL is the URL of the Gateway. Defaults to fetch from Discord.
3842
URL string
3943
// ShardID is the shardID of the Gateway. Defaults to 0.
@@ -118,12 +122,31 @@ func WithIntents(intents ...Intents) ConfigOpt {
118122

119123
// WithCompress sets whether this Gateway supports compression.
120124
// See here for more information: https://discord.com/developers/docs/topics/gateway#encoding-and-compression
125+
// Deprecated: Use WithCompression instead
121126
func WithCompress(compress bool) ConfigOpt {
122127
return func(config *config) {
128+
if compress {
129+
config.Compression = CompressionZlibPayload
130+
} else {
131+
config.Compression = CompressionNone
132+
}
133+
134+
// Set the deprecated field too
123135
config.Compress = compress
124136
}
125137
}
126138

139+
// WithCompression sets the compression mechanism to use.
140+
// See here for more information: https://discord.com/developers/docs/topics/gateway#encoding-and-compression
141+
func WithCompression(compression CompressionType) ConfigOpt {
142+
return func(config *config) {
143+
config.Compression = compression
144+
145+
// Set the deprecated field too
146+
config.Compress = compression == CompressionZlibPayload
147+
}
148+
}
149+
127150
// WithURL sets the Gateway URL for the Gateway.
128151
func WithURL(url string) ConfigOpt {
129152
return func(config *config) {

0 commit comments

Comments
 (0)