@@ -851,6 +851,13 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
851
851
// complete.
852
852
func (c * Conn ) Close (code StatusCode , reason string ) error {
853
853
err := c .exportedClose (code , reason , true )
854
+ var ec errClosing
855
+ if errors .As (err , & ec ) {
856
+ <- c .closed
857
+ // We wait until the connection closes.
858
+ // We use writeClose and not exportedClose to avoid a second failed to marshal close frame error.
859
+ err = c .writeClose (nil , ec .ce , true )
860
+ }
854
861
if err != nil {
855
862
return fmt .Errorf ("failed to close websocket connection: %w" , err )
856
863
}
@@ -878,15 +885,31 @@ func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) err
878
885
return c .writeClose (p , fmt .Errorf ("sent close: %w" , ce ), handshake )
879
886
}
880
887
888
+ type errClosing struct {
889
+ ce error
890
+ }
891
+
892
+ func (e errClosing ) Error () string {
893
+ return "already closing connection"
894
+ }
895
+
881
896
func (c * Conn ) writeClose (p []byte , ce error , handshake bool ) error {
882
- select {
883
- case <- c .closed :
884
- return fmt .Errorf ("tried to close with %v but connection already closed: %w" , ce , c .closeErr )
885
- default :
897
+ if c .isClosed () {
898
+ return fmt .Errorf ("tried to close with %q but connection already closed: %w" , ce , c .closeErr )
886
899
}
887
900
888
901
if ! c .closing .CAS (0 , 1 ) {
889
- return fmt .Errorf ("another goroutine is closing" )
902
+ // Normally, we would want to wait until the connection is closed,
903
+ // at least for when a user calls into Close, so we handle that case in
904
+ // the exported Close function.
905
+ //
906
+ // But for internal library usage, we always want to return early, e.g.
907
+ // if we are performing a close handshake and the peer sends their close frame,
908
+ // we do not want to block here waiting for c.closed to close because it won't,
909
+ // at least not until we return since the gorouine that will close it is this one.
910
+ return errClosing {
911
+ ce : ce ,
912
+ }
890
913
}
891
914
892
915
// No matter what happens next, close error should be set.
0 commit comments