Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io/ioutil"
"math/rand"
"net"
"reflect"
"strconv"
"sync"
"time"
Expand Down Expand Up @@ -219,7 +220,7 @@ var validReceivedCloseCodes = map[int]bool{
CloseTLSHandshake: false,
}

func isValidReceivedCloseCode(code int) bool {
func isValidCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
}

Expand Down Expand Up @@ -325,10 +326,53 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol
}

// Close closes the underlying network connection without sending or waiting
// for a close message.
func (c *Conn) Close() error {
return c.conn.Close()
// Close sends close frame and waits for one in response
// it expects two args. `closeCode int` and `closeMessage string` in order
// it uses variadic args to maintain backwards compatibility
func (c *Conn) Close(args ...interface{}) error {
closeCode := CloseNoStatusReceived
message := ""
ok := false
if len(args) == 2 {
closeCode, ok = args[0].(int)
if !ok {
closeCode = CloseNoStatusReceived
}
message, ok = args[1].(string)
if !ok {
message = ""
}
}
err := c.Shutdown(closeCode, message)
if err != nil {
return err
}
c.conn.Close()
return nil
}

// Shutdown sends a close frame and waits for one in response
func (c *Conn) Shutdown(closeCode int, closeMessage string) error {
if !isValidCloseCode(closeCode) {
// we do not shutdown connection
return errors.New("invalid close code received")
}
if !utf8.ValidString(closeMessage) {
return errors.New("invalid utf8 payload for shutdown message")
}
message := FormatCloseMessage(closeCode, closeMessage)
err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
if err != nil {
return err
}
timeStart := time.Now()
c.conn.SetReadDeadline(time.Now().Add(time.Minute))
for _, _, err := c.ReadMessage(); reflect.TypeOf(err) != reflect.TypeOf(&CloseError{}) ; {
if timeStart.Sub(time.Now()) > time.Minute {
break
}
}
return nil
}

// LocalAddr returns the local network address.
Expand Down Expand Up @@ -496,6 +540,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {

var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
Expand Down Expand Up @@ -902,7 +947,7 @@ func (c *Conn) advanceFrame() (int, error) {
closeText := ""
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
if !isValidCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
}
closeText = string(payload[2:])
Expand Down