@@ -451,7 +451,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
451
451
return err
452
452
}
453
453
454
- func (c * Conn ) prepWrite (messageType int ) error {
454
+ // beginMessage prepares a connection and message writer for a new message.
455
+ func (c * Conn ) beginMessage (mw * messageWriter , messageType int ) error {
455
456
// Close previous writer if not already closed by the application. It's
456
457
// probably better to return an error in this situation, but we cannot
457
458
// change this without breaking existing applications.
@@ -471,6 +472,10 @@ func (c *Conn) prepWrite(messageType int) error {
471
472
return err
472
473
}
473
474
475
+ mw .c = c
476
+ mw .frameType = messageType
477
+ mw .pos = maxFrameHeaderSize
478
+
474
479
if c .writeBuf == nil {
475
480
wpd , ok := c .writePool .Get ().(writePoolData )
476
481
if ok {
@@ -491,16 +496,11 @@ func (c *Conn) prepWrite(messageType int) error {
491
496
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
492
497
// PongMessage) are supported.
493
498
func (c * Conn ) NextWriter (messageType int ) (io.WriteCloser , error ) {
494
- if err := c .prepWrite (messageType ); err != nil {
499
+ var mw messageWriter
500
+ if err := c .beginMessage (& mw , messageType ); err != nil {
495
501
return nil , err
496
502
}
497
-
498
- mw := & messageWriter {
499
- c : c ,
500
- frameType : messageType ,
501
- pos : maxFrameHeaderSize ,
502
- }
503
- c .writer = mw
503
+ c .writer = & mw
504
504
if c .newCompressionWriter != nil && c .enableWriteCompression && isData (messageType ) {
505
505
w := c .newCompressionWriter (c .writer , c .compressionLevel )
506
506
mw .compress = true
@@ -517,10 +517,16 @@ type messageWriter struct {
517
517
err error
518
518
}
519
519
520
- func (w * messageWriter ) fatal (err error ) error {
520
+ func (w * messageWriter ) endMessage (err error ) error {
521
521
if w .err != nil {
522
- w .err = err
523
- w .c .writer = nil
522
+ return err
523
+ }
524
+ c := w .c
525
+ w .err = err
526
+ c .writer = nil
527
+ if c .writePool != nil {
528
+ c .writePool .Put (writePoolData {buf : c .writeBuf })
529
+ c .writeBuf = nil
524
530
}
525
531
return err
526
532
}
@@ -534,7 +540,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
534
540
// Check for invalid control frames.
535
541
if isControl (w .frameType ) &&
536
542
(! final || length > maxControlFramePayloadSize ) {
537
- return w .fatal (errInvalidControlFrame )
543
+ return w .endMessage (errInvalidControlFrame )
538
544
}
539
545
540
546
b0 := byte (w .frameType )
@@ -579,7 +585,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
579
585
copy (c .writeBuf [maxFrameHeaderSize - 4 :], key [:])
580
586
maskBytes (key , 0 , c .writeBuf [maxFrameHeaderSize :w .pos ])
581
587
if len (extra ) > 0 {
582
- return c .writeFatal (errors .New ("websocket: internal error, extra used in client mode" ))
588
+ return w . endMessage ( c .writeFatal (errors .New ("websocket: internal error, extra used in client mode" ) ))
583
589
}
584
590
}
585
591
@@ -600,15 +606,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
600
606
c .isWriting = false
601
607
602
608
if err != nil {
603
- return w .fatal (err )
609
+ return w .endMessage (err )
604
610
}
605
611
606
612
if final {
607
- c .writer = nil
608
- if c .writePool != nil {
609
- c .writePool .Put (writePoolData {buf : c .writeBuf })
610
- c .writeBuf = nil
611
- }
613
+ w .endMessage (errWriteClosed )
612
614
return nil
613
615
}
614
616
@@ -709,7 +711,6 @@ func (w *messageWriter) Close() error {
709
711
if err := w .flushFrame (true , nil ); err != nil {
710
712
return err
711
713
}
712
- w .err = errWriteClosed
713
714
return nil
714
715
}
715
716
@@ -742,10 +743,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
742
743
if c .isServer && (c .newCompressionWriter == nil || ! c .enableWriteCompression ) {
743
744
// Fast path with no allocations and single frame.
744
745
745
- if err := c .prepWrite (messageType ); err != nil {
746
+ var mw messageWriter
747
+ if err := c .beginMessage (& mw , messageType ); err != nil {
746
748
return err
747
749
}
748
- mw := messageWriter {c : c , frameType : messageType , pos : maxFrameHeaderSize }
749
750
n := copy (c .writeBuf [mw .pos :], data )
750
751
mw .pos += n
751
752
data = data [n :]
0 commit comments