Skip to content

Commit d092686

Browse files
committed
Autobahn tests fully pass :)
1 parent 78da35e commit d092686

File tree

6 files changed

+127
-180
lines changed

6 files changed

+127
-180
lines changed

assert_test.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"crypto/rand"
66
"fmt"
77
"net/http"
8-
"net/http/httptest"
98
"strings"
109
"testing"
1110
"time"
@@ -108,20 +107,6 @@ func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts
108107
return c
109108
}
110109

111-
func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) {
112-
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
113-
defer cancel()
114-
115-
if opts == nil {
116-
opts = &websocket.DialOptions{}
117-
}
118-
opts.HTTPClient = s.Client()
119-
120-
c, resp, err := websocket.Dial(ctx, wsURL(s), opts)
121-
assert.Success(t, "websocket.Dial", err)
122-
return c, resp
123-
}
124-
125110
func slogType(v interface{}) slog.Field {
126111
return slog.F("type", fmt.Sprintf("%T", v))
127112
}

autobahn_test.go

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@ import (
88
"fmt"
99
"io/ioutil"
1010
"net"
11-
"net/http"
12-
"net/http/httptest"
13-
"os"
1411
"os/exec"
1512
"strconv"
1613
"strings"
@@ -32,69 +29,14 @@ var excludedAutobahnCases = []string{
3229
// We skip the tests related to requestMaxWindowBits as that is unimplemented due
3330
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
3431
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
35-
36-
"12.*",
37-
"13.*",
3832
}
3933

4034
var autobahnCases = []string{"*"}
4135

42-
// https://github.com/crossbario/autobahn-python/tree/master/wstest
4336
func TestAutobahn(t *testing.T) {
4437
t.Parallel()
4538

46-
if os.Getenv("AUTOBAHN") == "" {
47-
t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite")
48-
}
49-
50-
t.Run("server", testServerAutobahn)
51-
t.Run("client", testClientAutobahn)
52-
}
53-
54-
func testServerAutobahn(t *testing.T) {
55-
t.Parallel()
56-
57-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58-
c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{
59-
Subprotocols: []string{"echo"},
60-
})
61-
err := echoLoop(r.Context(), c)
62-
assertCloseStatus(t, websocket.StatusNormalClosure, err)
63-
}))
64-
closeFn := wsgrace(s.Config)
65-
defer func() {
66-
err := closeFn()
67-
assert.Success(t, "closeFn", err)
68-
}()
69-
70-
specFile, err := tempJSONFile(map[string]interface{}{
71-
"outdir": "ci/out/wstestServerReports",
72-
"servers": []interface{}{
73-
map[string]interface{}{
74-
"agent": "main",
75-
"url": strings.Replace(s.URL, "http", "ws", 1),
76-
},
77-
},
78-
"cases": autobahnCases,
79-
"exclude-cases": excludedAutobahnCases,
80-
})
81-
assert.Success(t, "tempJSONFile", err)
82-
83-
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
84-
defer cancel()
85-
86-
args := []string{"--mode", "fuzzingclient", "--spec", specFile}
87-
wstest := exec.CommandContext(ctx, "wstest", args...)
88-
_, err = wstest.CombinedOutput()
89-
assert.Success(t, "wstest", err)
90-
91-
checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
92-
}
93-
94-
func testClientAutobahn(t *testing.T) {
95-
t.Parallel()
96-
97-
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
39+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
9840
defer cancel()
9941

10042
wstestURL, closeFn, err := wstestClientServer(ctx)
@@ -108,27 +50,17 @@ func testClientAutobahn(t *testing.T) {
10850
assert.Success(t, "wstestCaseCount", err)
10951

11052
t.Run("cases", func(t *testing.T) {
111-
// Max 8 cases running at a time.
112-
mu := make(chan struct{}, 8)
113-
11453
for i := 1; i <= cases; i++ {
11554
i := i
11655
t.Run("", func(t *testing.T) {
117-
t.Parallel()
118-
119-
mu <- struct{}{}
120-
defer func() {
121-
<-mu
122-
}()
123-
124-
ctx, cancel := context.WithTimeout(ctx, time.Second*45)
56+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
12557
defer cancel()
12658

12759
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil)
12860
assert.Success(t, "autobahn dial", err)
12961

13062
err = echoLoop(ctx, c)
131-
t.Logf("echoLoop: %+v", err)
63+
t.Logf("echoLoop: %v", err)
13264
})
13365
}
13466
})
@@ -174,7 +106,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er
174106
return "", nil, xerrors.Errorf("failed to write spec: %w", err)
175107
}
176108

177-
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
109+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
178110
defer func() {
179111
if err != nil {
180112
cancel()

conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func newConn(cfg connConfig) *Conn {
9999
closed: make(chan struct{}),
100100
activePings: make(map[string]chan<- struct{}),
101101
}
102-
if c.flateThreshold == 0 {
102+
if c.flate() && c.flateThreshold == 0 {
103103
c.flateThreshold = 256
104104
if c.writeNoContextTakeOver() {
105105
c.flateThreshold = 512

conn_test.go

Lines changed: 100 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,99 +3,70 @@
33
package websocket_test
44

55
import (
6+
"bufio"
67
"context"
78
"crypto/rand"
89
"io"
910
"math/big"
11+
"net"
1012
"net/http"
1113
"net/http/httptest"
12-
"strings"
13-
"sync/atomic"
1414
"testing"
1515
"time"
1616

1717
"cdr.dev/slog/sloggers/slogtest/assert"
18-
"golang.org/x/xerrors"
1918

2019
"nhooyr.io/websocket"
2120
)
2221

22+
func goFn(fn func()) func() {
23+
done := make(chan struct{})
24+
go func() {
25+
defer close(done)
26+
fn()
27+
}()
28+
29+
return func() {
30+
<-done
31+
}
32+
}
33+
2334
func TestConn(t *testing.T) {
2435
t.Parallel()
2536

2637
t.Run("json", func(t *testing.T) {
2738
t.Parallel()
2839

29-
s, closeFn := testEchoLoop(t)
30-
defer closeFn()
40+
for i := 0; i < 1; i++ {
41+
t.Run("", func(t *testing.T) {
42+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
43+
defer cancel()
3144

32-
c, _ := dialWebSocket(t, s, nil)
33-
defer c.Close(websocket.StatusInternalError, "")
34-
35-
c.SetReadLimit(1 << 30)
36-
37-
for i := 0; i < 10; i++ {
38-
n := randInt(t, 1_048_576)
39-
echoJSON(t, c, n)
40-
}
45+
c1, c2 := websocketPipe(t)
4146

42-
c.Close(websocket.StatusNormalClosure, "")
43-
})
44-
}
45-
46-
func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) {
47-
h := http.HandlerFunc(fn)
48-
if randInt(tb, 2) == 1 {
49-
s = httptest.NewTLSServer(h)
50-
} else {
51-
s = httptest.NewServer(h)
52-
}
53-
closeFn2 := wsgrace(s.Config)
54-
return s, func() {
55-
err := closeFn2()
56-
assert.Success(tb, "closeFn", err)
57-
}
58-
}
47+
wait := goFn(func() {
48+
err := echoLoop(ctx, c1)
49+
assertCloseStatus(t, websocket.StatusNormalClosure, err)
50+
})
51+
defer wait()
5952

60-
// grace wraps s.Handler to gracefully shutdown WebSocket connections.
61-
// The returned function must be used to close the server instead of s.Close.
62-
func wsgrace(s *http.Server) (closeFn func() error) {
63-
h := s.Handler
64-
var conns int64
65-
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
66-
atomic.AddInt64(&conns, 1)
67-
defer atomic.AddInt64(&conns, -1)
53+
c2.SetReadLimit(1 << 30)
6854

69-
ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
70-
defer cancel()
71-
72-
r = r.WithContext(ctx)
55+
for i := 0; i < 10; i++ {
56+
n := randInt(t, 131_072)
57+
echoJSON(t, c2, n)
58+
}
7359

74-
h.ServeHTTP(w, r)
60+
c2.Close(websocket.StatusNormalClosure, "")
61+
})
62+
}
7563
})
64+
}
7665

77-
return func() error {
78-
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
79-
defer cancel()
80-
81-
err := s.Shutdown(ctx)
82-
if err != nil {
83-
return xerrors.Errorf("server shutdown failed: %v", err)
84-
}
66+
type writerFunc func(p []byte) (int, error)
8567

86-
t := time.NewTicker(time.Millisecond * 10)
87-
defer t.Stop()
88-
for {
89-
select {
90-
case <-t.C:
91-
if atomic.LoadInt64(&conns) == 0 {
92-
return nil
93-
}
94-
case <-ctx.Done():
95-
return xerrors.Errorf("failed to wait for WebSocket connections: %v", ctx.Err())
96-
}
97-
}
98-
}
68+
func (f writerFunc) Write(p []byte) (int, error) {
69+
return f(p)
9970
}
10071

10172
// echoLoop echos every msg received from c until an error
@@ -133,22 +104,74 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error {
133104
}
134105
}
135106

136-
func wsURL(s *httptest.Server) string {
137-
return strings.Replace(s.URL, "http", "ws", 1)
138-
}
139-
140-
func testEchoLoop(t testing.TB) (*httptest.Server, func()) {
141-
return testServer(t, func(w http.ResponseWriter, r *http.Request) {
142-
c := acceptWebSocket(t, r, w, nil)
143-
defer c.Close(websocket.StatusInternalError, "")
144-
145-
err := echoLoop(r.Context(), c)
146-
assertCloseStatus(t, websocket.StatusNormalClosure, err)
147-
})
107+
func randBool(t testing.TB) bool {
108+
return randInt(t, 2) == 1
148109
}
149110

150111
func randInt(t testing.TB, max int) int {
151112
x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
152113
assert.Success(t, "rand.Int", err)
153114
return int(x.Int64())
154115
}
116+
117+
type testHijacker struct {
118+
*httptest.ResponseRecorder
119+
serverConn net.Conn
120+
hijacked chan struct{}
121+
}
122+
123+
var _ http.Hijacker = testHijacker{}
124+
125+
func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
126+
close(hj.hijacked)
127+
return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil
128+
}
129+
130+
func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) {
131+
var serverConn *websocket.Conn
132+
tt := testTransport{
133+
h: func(w http.ResponseWriter, r *http.Request) {
134+
serverConn = acceptWebSocket(t, r, w, nil)
135+
},
136+
}
137+
138+
dialOpts := &websocket.DialOptions{
139+
HTTPClient: &http.Client{
140+
Transport: tt,
141+
},
142+
}
143+
144+
clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts)
145+
assert.Success(t, "websocket.Dial", err)
146+
147+
if randBool(t) {
148+
return serverConn, clientConn
149+
}
150+
return clientConn, serverConn
151+
}
152+
153+
type testTransport struct {
154+
h http.HandlerFunc
155+
}
156+
157+
func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) {
158+
clientConn, serverConn := net.Pipe()
159+
160+
hj := testHijacker{
161+
ResponseRecorder: httptest.NewRecorder(),
162+
serverConn: serverConn,
163+
hijacked: make(chan struct{}),
164+
}
165+
166+
done := make(chan struct{})
167+
t.h.ServeHTTP(hj, r)
168+
169+
select {
170+
case <-hj.hijacked:
171+
resp := hj.ResponseRecorder.Result()
172+
resp.Body = clientConn
173+
return resp, nil
174+
case <-done:
175+
return hj.ResponseRecorder.Result(), nil
176+
}
177+
}

0 commit comments

Comments
 (0)