Skip to content

Commit 3130e8d

Browse files
Steven Scottgaryburd
authored andcommitted
Return write buffer to pool on write error (#427)
Fix bug where connection did not return the write buffer to the pool after a write error. Add test for the same. Rename messsageWriter.fatal method to endMessage and consolidate all message cleanup code there. This ensures that the buffer is returned to pool on all code paths. Rename Conn.prepMessage to beginMessage for symmetry with endMessage. Move some duplicated code at calls to prepMessage to beginMessage. Bonus improvement: Adjust message and buffer size in TestWriteBufferPool to test that pool works with fragmented messages.
1 parent cdd40f5 commit 3130e8d

File tree

2 files changed

+81
-26
lines changed

2 files changed

+81
-26
lines changed

conn.go

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
451451
return err
452452
}
453453

454-
func (c *Conn) prepWrite(messageType int) error {
454+
// beginMessage prepares a connection and message writer for a new message.
455+
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
455456
// Close previous writer if not already closed by the application. It's
456457
// probably better to return an error in this situation, but we cannot
457458
// change this without breaking existing applications.
@@ -471,6 +472,10 @@ func (c *Conn) prepWrite(messageType int) error {
471472
return err
472473
}
473474

475+
mw.c = c
476+
mw.frameType = messageType
477+
mw.pos = maxFrameHeaderSize
478+
474479
if c.writeBuf == nil {
475480
wpd, ok := c.writePool.Get().(writePoolData)
476481
if ok {
@@ -491,16 +496,11 @@ func (c *Conn) prepWrite(messageType int) error {
491496
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
492497
// PongMessage) are supported.
493498
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
494-
if err := c.prepWrite(messageType); err != nil {
499+
var mw messageWriter
500+
if err := c.beginMessage(&mw, messageType); err != nil {
495501
return nil, err
496502
}
497-
498-
mw := &messageWriter{
499-
c: c,
500-
frameType: messageType,
501-
pos: maxFrameHeaderSize,
502-
}
503-
c.writer = mw
503+
c.writer = &mw
504504
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
505505
w := c.newCompressionWriter(c.writer, c.compressionLevel)
506506
mw.compress = true
@@ -517,10 +517,16 @@ type messageWriter struct {
517517
err error
518518
}
519519

520-
func (w *messageWriter) fatal(err error) error {
520+
func (w *messageWriter) endMessage(err error) error {
521521
if w.err != nil {
522-
w.err = err
523-
w.c.writer = nil
522+
return err
523+
}
524+
c := w.c
525+
w.err = err
526+
c.writer = nil
527+
if c.writePool != nil {
528+
c.writePool.Put(writePoolData{buf: c.writeBuf})
529+
c.writeBuf = nil
524530
}
525531
return err
526532
}
@@ -534,7 +540,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
534540
// Check for invalid control frames.
535541
if isControl(w.frameType) &&
536542
(!final || length > maxControlFramePayloadSize) {
537-
return w.fatal(errInvalidControlFrame)
543+
return w.endMessage(errInvalidControlFrame)
538544
}
539545

540546
b0 := byte(w.frameType)
@@ -579,7 +585,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
579585
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
580586
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
581587
if len(extra) > 0 {
582-
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
588+
return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
583589
}
584590
}
585591

@@ -600,15 +606,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
600606
c.isWriting = false
601607

602608
if err != nil {
603-
return w.fatal(err)
609+
return w.endMessage(err)
604610
}
605611

606612
if final {
607-
c.writer = nil
608-
if c.writePool != nil {
609-
c.writePool.Put(writePoolData{buf: c.writeBuf})
610-
c.writeBuf = nil
611-
}
613+
w.endMessage(errWriteClosed)
612614
return nil
613615
}
614616

@@ -709,7 +711,6 @@ func (w *messageWriter) Close() error {
709711
if err := w.flushFrame(true, nil); err != nil {
710712
return err
711713
}
712-
w.err = errWriteClosed
713714
return nil
714715
}
715716

@@ -742,10 +743,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
742743
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
743744
// Fast path with no allocations and single frame.
744745

745-
if err := c.prepWrite(messageType); err != nil {
746+
var mw messageWriter
747+
if err := c.beginMessage(&mw, messageType); err != nil {
746748
return err
747749
}
748-
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
749750
n := copy(c.writeBuf[mw.pos:], data)
750751
mw.pos += n
751752
data = data[n:]

conn_test.go

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,16 @@ func (p *simpleBufferPool) Put(v interface{}) {
196196
}
197197

198198
func TestWriteBufferPool(t *testing.T) {
199+
const message = "Now is the time for all good people to come to the aid of the party."
200+
199201
var buf bytes.Buffer
200202
var pool simpleBufferPool
201-
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
202203
rc := newTestConn(&buf, nil, false)
203204

205+
// Specify writeBufferSize smaller than message size to ensure that pooling
206+
// works with fragmented messages.
207+
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
208+
204209
if wc.writeBuf != nil {
205210
t.Fatal("writeBuf not nil after create")
206211
}
@@ -218,8 +223,6 @@ func TestWriteBufferPool(t *testing.T) {
218223

219224
writeBufAddr := &wc.writeBuf[0]
220225

221-
const message = "Hello World!"
222-
223226
if _, err := io.WriteString(w, message); err != nil {
224227
t.Fatalf("io.WriteString(w, message) returned %v", err)
225228
}
@@ -269,6 +272,7 @@ func TestWriteBufferPool(t *testing.T) {
269272
}
270273
}
271274

275+
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
272276
func TestWriteBufferPoolSync(t *testing.T) {
273277
var buf bytes.Buffer
274278
var pool sync.Pool
@@ -290,6 +294,56 @@ func TestWriteBufferPoolSync(t *testing.T) {
290294
}
291295
}
292296

297+
// errorWriter is an io.Writer than returns an error on all writes.
298+
type errorWriter struct{}
299+
300+
func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }
301+
302+
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
303+
// on write.
304+
func TestWriteBufferPoolError(t *testing.T) {
305+
306+
// Part 1: Test NextWriter/Write/Close
307+
308+
var pool simpleBufferPool
309+
wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
310+
311+
w, err := wc.NextWriter(TextMessage)
312+
if err != nil {
313+
t.Fatalf("wc.NextWriter() returned %v", err)
314+
}
315+
316+
if wc.writeBuf == nil {
317+
t.Fatal("writeBuf is nil after NextWriter")
318+
}
319+
320+
writeBufAddr := &wc.writeBuf[0]
321+
322+
if _, err := io.WriteString(w, "Hello"); err != nil {
323+
t.Fatalf("io.WriteString(w, message) returned %v", err)
324+
}
325+
326+
if err := w.Close(); err == nil {
327+
t.Fatalf("w.Close() did not return error")
328+
}
329+
330+
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
331+
t.Fatal("writeBuf not returned to pool")
332+
}
333+
334+
// Part 2: Test WriteMessage
335+
336+
wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
337+
338+
if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
339+
t.Fatalf("wc.WriteMessage did not return error")
340+
}
341+
342+
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
343+
t.Fatal("writeBuf not returned to pool")
344+
}
345+
}
346+
293347
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
294348
const bufSize = 512
295349

0 commit comments

Comments
 (0)