Skip to content

Commit cc1c15a

Browse files
committed
fix: rewrite close handshake flow when initiated from other side
1 parent 8a6704c commit cc1c15a

File tree

5 files changed

+234
-90
lines changed

5 files changed

+234
-90
lines changed

close.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
100100
func (c *Conn) Close(code StatusCode, reason string) (err error) {
101101
defer errd.Wrap(&err, "failed to close WebSocket")
102102

103-
if !c.casClosing() {
103+
if c.casClosing() {
104104
err = c.waitGoroutines()
105105
if err != nil {
106106
return err
@@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
133133
func (c *Conn) CloseNow() (err error) {
134134
defer errd.Wrap(&err, "failed to immediately close WebSocket")
135135

136-
if !c.casClosing() {
136+
if c.casClosing() {
137137
err = c.waitGoroutines()
138138
if err != nil {
139139
return err
@@ -206,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error {
206206
}
207207
defer c.readMu.unlock()
208208

209-
if c.readCloseErr != nil {
210-
return c.readCloseErr
211-
}
212-
213209
for i := int64(0); i < c.msgReader.payloadLength; i++ {
214210
_, err := c.br.ReadByte()
215211
if err != nil {
@@ -333,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
333329
}
334330

335331
func (c *Conn) casClosing() bool {
336-
c.closeMu.Lock()
337-
defer c.closeMu.Unlock()
338-
if !c.closing {
339-
c.closing = true
340-
return true
341-
}
342-
return false
332+
return c.closing.Swap(true)
343333
}
344334

345335
func (c *Conn) isClosed() bool {

conn.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,27 @@ type Conn struct {
6161
readHeaderBuf [8]byte
6262
readControlBuf [maxControlPayload]byte
6363
msgReader *msgReader
64-
readCloseErr error
6564

6665
// Write state.
6766
msgWriter *msgWriter
6867
writeFrameMu *mu
6968
writeBuf []byte
7069
writeHeaderBuf [8]byte
7170
writeHeader header
72-
closeSent bool
71+
72+
// Close handshake state.
73+
closeStateMu sync.RWMutex
74+
closeReceivedErr error
75+
closeSentErr error
7376

7477
// CloseRead state.
7578
closeReadMu sync.Mutex
7679
closeReadCtx context.Context
7780
closeReadDone chan struct{}
7881

79-
closed chan struct{}
82+
closing atomic.Bool
8083
closeMu sync.Mutex
81-
closing bool
84+
closed chan struct{}
8285

8386
pingCounter atomic.Int64
8487
activePingsMu sync.Mutex

conn_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"net"
1112
"net/http"
1213
"net/http/httptest"
1314
"os"
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
625626
}()
626627
}
627628
}
629+
630+
func TestConnClosePropagation(t *testing.T) {
631+
t.Parallel()
632+
633+
want := []byte("hello")
634+
keepWriting := func(c *websocket.Conn) <-chan error {
635+
return xsync.Go(func() error {
636+
for {
637+
err := c.Write(context.Background(), websocket.MessageText, want)
638+
if err != nil {
639+
return err
640+
}
641+
}
642+
})
643+
}
644+
keepReading := func(c *websocket.Conn) <-chan error {
645+
return xsync.Go(func() error {
646+
for {
647+
_, got, err := c.Read(context.Background())
648+
if err != nil {
649+
return err
650+
}
651+
if !bytes.Equal(want, got) {
652+
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
653+
}
654+
}
655+
})
656+
}
657+
checkReadErr := func(t *testing.T, err error) {
658+
var ce websocket.CloseError
659+
if errors.As(err, &ce) {
660+
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
661+
} else {
662+
assert.ErrorIs(t, net.ErrClosed, err)
663+
}
664+
}
665+
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
666+
for _, c := range conn {
667+
// Check write error.
668+
err := c.Write(context.Background(), websocket.MessageText, want)
669+
assert.ErrorIs(t, net.ErrClosed, err)
670+
671+
// Check read error (output depends on when read is called in relation to connection closure).
672+
_, _, err = c.Read(context.Background())
673+
checkReadErr(t, err)
674+
}
675+
}
676+
677+
t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
678+
tt, this, other := newConnTest(t, nil, nil)
679+
680+
_ = this.CloseRead(tt.ctx)
681+
thisWriteErr := keepWriting(this)
682+
683+
_, got, err := other.Read(tt.ctx)
684+
assert.Success(t, err)
685+
assert.Equal(t, "msg", want, got)
686+
687+
err = other.Close(websocket.StatusNormalClosure, "")
688+
assert.Success(t, err)
689+
690+
select {
691+
case err := <-thisWriteErr:
692+
assert.ErrorIs(t, net.ErrClosed, err)
693+
case <-tt.ctx.Done():
694+
t.Fatal(tt.ctx.Err())
695+
}
696+
697+
checkConnErrs(t, this, other)
698+
})
699+
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
700+
tt, this, other := newConnTest(t, nil, nil)
701+
702+
_ = this.CloseRead(tt.ctx)
703+
thisWriteErr := keepWriting(this)
704+
otherReadErr := keepReading(other)
705+
706+
err := this.Close(websocket.StatusNormalClosure, "")
707+
assert.Success(t, err)
708+
709+
select {
710+
case err := <-thisWriteErr:
711+
assert.ErrorIs(t, net.ErrClosed, err)
712+
case <-tt.ctx.Done():
713+
t.Fatal(tt.ctx.Err())
714+
}
715+
716+
select {
717+
case err := <-otherReadErr:
718+
checkReadErr(t, err)
719+
case <-tt.ctx.Done():
720+
t.Fatal(tt.ctx.Err())
721+
}
722+
723+
checkConnErrs(t, this, other)
724+
})
725+
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
726+
tt, this, other := newConnTest(t, nil, nil)
727+
728+
_ = other.CloseRead(tt.ctx)
729+
errs := keepReading(this)
730+
731+
err := other.Write(tt.ctx, websocket.MessageText, want)
732+
assert.Success(t, err)
733+
734+
err = other.Close(websocket.StatusNormalClosure, "")
735+
assert.Success(t, err)
736+
737+
select {
738+
case err := <-errs:
739+
checkReadErr(t, err)
740+
case <-tt.ctx.Done():
741+
t.Fatal(tt.ctx.Err())
742+
}
743+
744+
checkConnErrs(t, this, other)
745+
})
746+
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
747+
tt, this, other := newConnTest(t, nil, nil)
748+
749+
thisReadErr := keepReading(this)
750+
otherReadErr := keepReading(other)
751+
752+
err := other.Write(tt.ctx, websocket.MessageText, want)
753+
assert.Success(t, err)
754+
755+
err = this.Close(websocket.StatusNormalClosure, "")
756+
assert.Success(t, err)
757+
758+
select {
759+
case err := <-thisReadErr:
760+
checkReadErr(t, err)
761+
case <-tt.ctx.Done():
762+
t.Fatal(tt.ctx.Err())
763+
}
764+
765+
select {
766+
case err := <-otherReadErr:
767+
checkReadErr(t, err)
768+
case <-tt.ctx.Done():
769+
t.Fatal(tt.ctx.Err())
770+
}
771+
772+
checkConnErrs(t, this, other)
773+
})
774+
}

read.go

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,6 @@ func (c *Conn) readRSV1Illegal(h header) bool {
181181
}
182182

183183
func (c *Conn) readLoop(ctx context.Context) (header, error) {
184-
if c.readCloseErr != nil {
185-
select {
186-
case <-c.closed:
187-
return header{}, net.ErrClosed
188-
default:
189-
}
190-
return header{}, c.readCloseErr
191-
}
192-
193184
for {
194185
h, err := c.readFrameHeader(ctx)
195186
if err != nil {
@@ -226,57 +217,59 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
226217
}
227218
}
228219

229-
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
220+
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
230221
select {
231222
case <-c.closed:
232-
return header{}, net.ErrClosed
223+
return nil, net.ErrClosed
233224
case c.readTimeout <- ctx:
234225
}
235226

236-
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
237-
if err != nil {
227+
c.closeStateMu.Lock()
228+
closeReceivedErr := c.closeReceivedErr
229+
c.closeStateMu.Unlock()
230+
if closeReceivedErr != nil {
231+
return nil, closeReceivedErr
232+
}
233+
234+
return func() {
238235
select {
239236
case <-c.closed:
240-
return header{}, net.ErrClosed
241-
case <-ctx.Done():
242-
return header{}, ctx.Err()
243-
default:
244-
return header{}, err
237+
if *err != nil {
238+
*err = net.ErrClosed
239+
}
240+
case c.writeTimeout <- context.Background():
245241
}
242+
if *err != nil && ctx.Err() != nil {
243+
*err = ctx.Err()
244+
}
245+
}, nil
246+
}
247+
248+
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
249+
readDone, err := c.prepareRead(ctx, &err)
250+
if err != nil {
251+
return header{}, err
246252
}
253+
defer readDone()
247254

248-
select {
249-
case <-c.closed:
250-
return header{}, net.ErrClosed
251-
case c.readTimeout <- context.Background():
255+
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
256+
if err != nil {
257+
return header{}, err
252258
}
253259

254260
return h, nil
255261
}
256262

257-
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
258-
select {
259-
case <-c.closed:
260-
return 0, net.ErrClosed
261-
case c.readTimeout <- ctx:
263+
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
264+
readDone, err := c.prepareRead(ctx, &err)
265+
if err != nil {
266+
return 0, err
262267
}
268+
defer readDone()
263269

264270
n, err := io.ReadFull(c.br, p)
265271
if err != nil {
266-
select {
267-
case <-c.closed:
268-
return n, net.ErrClosed
269-
case <-ctx.Done():
270-
return n, ctx.Err()
271-
default:
272-
return n, fmt.Errorf("failed to read frame payload: %w", err)
273-
}
274-
}
275-
276-
select {
277-
case <-c.closed:
278-
return n, net.ErrClosed
279-
case c.readTimeout <- context.Background():
272+
return n, fmt.Errorf("failed to read frame payload: %w", err)
280273
}
281274

282275
return n, err
@@ -333,18 +326,20 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
333326
return err
334327
}
335328

336-
if c.readCloseErr == nil {
337-
c.readCloseErr = ce
338-
}
339-
340329
err = fmt.Errorf("received close frame: %w", ce)
341-
if err2 := c.writeClose(ce.Code, ce.Reason); errors.Is(err2, errCloseSent) {
342-
// The close handshake has already been initiated, connection
343-
// close should be handled elsewhere.
344-
return err
330+
c.closeStateMu.Lock()
331+
c.closeReceivedErr = err
332+
closeSent := c.closeSentErr != nil
333+
c.closeStateMu.Unlock()
334+
335+
if !closeSent {
336+
c.readMu.unlock()
337+
_ = c.writeClose(ce.Code, ce.Reason)
338+
}
339+
if !c.casClosing() {
340+
c.readMu.unlock()
341+
_ = c.close()
345342
}
346-
c.readMu.unlock()
347-
c.close()
348343
return err
349344
}
350345

0 commit comments

Comments
 (0)