Skip to content

Commit b6b56b7

Browse files
committed
Both modes seem to work :)
1 parent 0f115ed commit b6b56b7

File tree

9 files changed

+196
-142
lines changed

9 files changed

+196
-142
lines changed

accept.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
111111
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
112112

113113
return newConn(connConfig{
114-
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
115-
rwc: netConn,
116-
client: false,
117-
copts: copts,
118-
br: brw.Reader,
119-
bw: brw.Writer,
114+
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
115+
rwc: netConn,
116+
client: false,
117+
copts: copts,
118+
flateThreshold: opts.CompressionOptions.Threshold,
119+
120+
br: brw.Reader,
121+
bw: brw.Writer,
120122
}), nil
121123
}
122124

assert_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"strings"
77
"testing"
88

9+
"cdr.dev/slog"
910
"cdr.dev/slog/sloggers/slogtest/assert"
1011

1112
"nhooyr.io/websocket"
@@ -33,7 +34,7 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int)
3334
}
3435

3536
func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) {
36-
t.Helper()
37+
slog.Helper()
3738

3839
var act interface{}
3940
err := wsjson.Read(ctx, c, &act)

compress.go

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
148148

149149
var flateReaderPool sync.Pool
150150

151-
func getFlateReader(r io.Reader) io.Reader {
151+
func getFlateReader(r io.Reader, dict []byte) io.Reader {
152152
fr, ok := flateReaderPool.Get().(io.Reader)
153153
if !ok {
154-
return flate.NewReader(r)
154+
return flate.NewReaderDict(r, dict)
155155
}
156-
fr.(flate.Resetter).Reset(r, nil)
156+
fr.(flate.Resetter).Reset(r, dict)
157157
return fr
158158
}
159159

@@ -163,10 +163,10 @@ func putFlateReader(fr io.Reader) {
163163

164164
var flateWriterPool sync.Pool
165165

166-
func getFlateWriter(w io.Writer, dict []byte) *flate.Writer {
166+
func getFlateWriter(w io.Writer) *flate.Writer {
167167
fw, ok := flateWriterPool.Get().(*flate.Writer)
168168
if !ok {
169-
fw, _ = flate.NewWriterDict(w, flate.BestSpeed, dict)
169+
fw, _ = flate.NewWriter(w, flate.BestSpeed)
170170
return fw
171171
}
172172
fw.Reset(w)
@@ -177,40 +177,32 @@ func putFlateWriter(w *flate.Writer) {
177177
flateWriterPool.Put(w)
178178
}
179179

180-
type slidingWindowReader struct {
181-
window []byte
182-
183-
r io.Reader
180+
type slidingWindow struct {
181+
r io.Reader
182+
buf []byte
184183
}
185184

186-
func (r slidingWindowReader) Read(p []byte) (int, error) {
187-
n, err := r.r.Read(p)
188-
p = p[:n]
189-
190-
r.append(p)
191-
192-
return n, err
185+
func newSlidingWindow(n int) *slidingWindow {
186+
return &slidingWindow{
187+
buf: make([]byte, 0, n),
188+
}
193189
}
194190

195-
func (r slidingWindowReader) append(p []byte) {
196-
if len(r.window) <= cap(r.window) {
197-
r.window = append(r.window, p...)
191+
func (w *slidingWindow) write(p []byte) {
192+
if len(p) >= cap(w.buf) {
193+
w.buf = w.buf[:cap(w.buf)]
194+
p = p[len(p)-cap(w.buf):]
195+
copy(w.buf, p)
196+
return
198197
}
199198

200-
if len(p) > cap(r.window) {
201-
p = p[len(p)-cap(r.window):]
199+
left := cap(w.buf) - len(w.buf)
200+
if left < len(p) {
201+
// We need to shift spaceNeeded bytes from the end to make room for p at the end.
202+
spaceNeeded := len(p) - left
203+
copy(w.buf, w.buf[spaceNeeded:])
204+
w.buf = w.buf[:len(w.buf)-spaceNeeded]
202205
}
203206

204-
// p now contains at max the last window bytes
205-
// so we need to be able to append all of it to r.window.
206-
// Shift as many bytes from r.window as needed.
207-
208-
// Maximum window size minus current window minus extra gives
209-
// us the number of bytes that need to be shifted.
210-
off := len(r.window) + len(p) - cap(r.window)
211-
212-
r.window = append(r.window[:0], r.window[off:]...)
213-
copy(r.window, r.window[off:])
214-
copy(r.window[len(r.window)-len(p):], p)
215-
return
207+
w.buf = append(w.buf, p...)
216208
}

compress_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package websocket
2+
3+
import (
4+
"crypto/rand"
5+
"encoding/base64"
6+
"math/big"
7+
"strings"
8+
"testing"
9+
10+
"cdr.dev/slog/sloggers/slogtest/assert"
11+
)
12+
13+
func Test_slidingWindow(t *testing.T) {
14+
t.Parallel()
15+
16+
const testCount = 99
17+
const maxWindow = 99999
18+
for i := 0; i < testCount; i++ {
19+
input := randStr(t, maxWindow)
20+
windowLength := randInt(t, maxWindow)
21+
r := newSlidingWindow(windowLength)
22+
r.write([]byte(input))
23+
24+
if cap(r.buf) != windowLength {
25+
t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength)
26+
}
27+
assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf)))
28+
}
29+
}
30+
31+
func randStr(t *testing.T, max int) string {
32+
n := randInt(t, max)
33+
34+
b := make([]byte, n)
35+
_, err := rand.Read(b)
36+
assert.Success(t, "rand.Read", err)
37+
38+
return base64.StdEncoding.EncodeToString(b)
39+
}
40+
41+
func randInt(t *testing.T, max int) int {
42+
x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
43+
assert.Success(t, "rand.Int", err)
44+
return int(x.Int64())
45+
}

conn.go

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,13 @@ const (
3838
// On any error from any method, the connection is closed
3939
// with an appropriate reason.
4040
type Conn struct {
41-
subprotocol string
42-
rwc io.ReadWriteCloser
43-
client bool
44-
copts *compressionOptions
45-
br *bufio.Reader
46-
bw *bufio.Writer
41+
subprotocol string
42+
rwc io.ReadWriteCloser
43+
client bool
44+
copts *compressionOptions
45+
flateThreshold int
46+
br *bufio.Reader
47+
bw *bufio.Writer
4748

4849
readTimeout chan context.Context
4950
writeTimeout chan context.Context
@@ -71,21 +72,23 @@ type Conn struct {
7172
}
7273

7374
type connConfig struct {
74-
subprotocol string
75-
rwc io.ReadWriteCloser
76-
client bool
77-
copts *compressionOptions
75+
subprotocol string
76+
rwc io.ReadWriteCloser
77+
client bool
78+
copts *compressionOptions
79+
flateThreshold int
7880

7981
br *bufio.Reader
8082
bw *bufio.Writer
8183
}
8284

8385
func newConn(cfg connConfig) *Conn {
8486
c := &Conn{
85-
subprotocol: cfg.subprotocol,
86-
rwc: cfg.rwc,
87-
client: cfg.client,
88-
copts: cfg.copts,
87+
subprotocol: cfg.subprotocol,
88+
rwc: cfg.rwc,
89+
client: cfg.client,
90+
copts: cfg.copts,
91+
flateThreshold: cfg.flateThreshold,
8992

9093
br: cfg.br,
9194
bw: cfg.bw,
@@ -96,6 +99,12 @@ func newConn(cfg connConfig) *Conn {
9699
closed: make(chan struct{}),
97100
activePings: make(map[string]chan<- struct{}),
98101
}
102+
if c.flateThreshold == 0 {
103+
c.flateThreshold = 256
104+
if c.writeNoContextTakeOver() {
105+
c.flateThreshold = 512
106+
}
107+
}
99108

100109
c.readMu = newMu(c)
101110
c.writeFrameMu = newMu(c)
@@ -145,12 +154,10 @@ func (c *Conn) close(err error) {
145154
}
146155
c.msgWriter.close()
147156

157+
c.msgReader.close()
148158
if c.client {
149-
c.readMu.Lock(context.Background())
150159
putBufioReader(c.br)
151-
c.readMu.Unlock()
152160
}
153-
c.msgReader.close()
154161
}()
155162
}
156163

conn_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@ func TestConn(t *testing.T) {
2727
Subprotocols: []string{"echo"},
2828
InsecureSkipVerify: true,
2929
CompressionOptions: websocket.CompressionOptions{
30-
Mode: websocket.CompressionNoContextTakeover,
30+
Mode: websocket.CompressionContextTakeover,
31+
Threshold: 1,
3132
},
3233
})
3334
assert.Success(t, "accept", err)
3435
defer c.Close(websocket.StatusInternalError, "")
3536

3637
err = echoLoop(r.Context(), c)
38+
t.Logf("server: %v", err)
3739
assertCloseStatus(t, websocket.StatusNormalClosure, err)
3840
}, false)
3941
defer closeFn()
@@ -46,7 +48,8 @@ func TestConn(t *testing.T) {
4648
opts := &websocket.DialOptions{
4749
Subprotocols: []string{"echo"},
4850
CompressionOptions: websocket.CompressionOptions{
49-
Mode: websocket.CompressionNoContextTakeover,
51+
Mode: websocket.CompressionContextTakeover,
52+
Threshold: 1,
5053
},
5154
}
5255
opts.HTTPClient = s.Client()

dial.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
9999
}
100100

101101
return newConn(connConfig{
102-
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
103-
rwc: rwc,
104-
client: true,
105-
copts: copts,
106-
br: getBufioReader(rwc),
107-
bw: getBufioWriter(rwc),
102+
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
103+
rwc: rwc,
104+
client: true,
105+
copts: copts,
106+
flateThreshold: opts.CompressionOptions.Threshold,
107+
br: getBufioReader(rwc),
108+
bw: getBufioWriter(rwc),
108109
}), resp, nil
109110
}
110111

0 commit comments

Comments
 (0)