@@ -12,72 +12,60 @@ import (
12
12
"golang.org/x/xerrors"
13
13
)
14
14
15
- // AcceptOption is an option that can be passed to Accept.
16
- // The implementations of this interface are printable.
17
- type AcceptOption interface {
18
- acceptOption ()
19
- }
20
-
21
- type acceptSubprotocols []string
22
-
23
- func (o acceptSubprotocols ) acceptOption () {}
24
-
25
- // AcceptSubprotocols lists the websocket subprotocols that Accept will negotiate with a client.
26
- // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
27
- // reject it, close the connection if c.Subprotocol() == "".
28
- func AcceptSubprotocols (protocols ... string ) AcceptOption {
29
- return acceptSubprotocols (protocols )
30
- }
31
-
32
- type acceptInsecureOrigin struct {}
33
-
34
- func (o acceptInsecureOrigin ) acceptOption () {}
35
-
36
- // AcceptInsecureOrigin disables Accept's origin verification
37
- // behaviour. By default Accept only allows the handshake to
38
- // succeed if the javascript that is initiating the handshake
39
- // is on the same domain as the server. This is to prevent CSRF
40
- // when secure data is stored in cookies.
41
- //
42
- // See https://stackoverflow.com/a/37837709/4283659
43
- //
44
- // Use this if you want a WebSocket server any javascript can
45
- // connect to or you want to perform Origin verification yourself
46
- // and allow some whitelist of domains.
47
- //
48
- // Ensure you understand exactly what the above means before you use
49
- // this option in conjugation with cookies containing secure data.
50
- func AcceptInsecureOrigin () AcceptOption {
51
- return acceptInsecureOrigin {}
15
+ // AcceptOptions represents the options available to pass to Accept.
16
+ type AcceptOptions struct {
17
+ // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client.
18
+ // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
19
+ // reject it, close the connection if c.Subprotocol() == "".
20
+ Subprotocols []string
21
+
22
+ // InsecureSkipVerify disables Accept's origin verification
23
+ // behaviour. By default Accept only allows the handshake to
24
+ // succeed if the javascript that is initiating the handshake
25
+ // is on the same domain as the server. This is to prevent CSRF
26
+ // when secure data is stored in a cookie as there is no same
27
+ // origin policy for WebSockets. In other words, javascript from
28
+ // any domain can perform a WebSocket dial on an arbitrary server.
29
+ // This dial will include cookies which means the arbitrary javascript
30
+ // can perform actions as the authenticated user.
31
+ //
32
+ // See https://stackoverflow.com/a/37837709/4283659
33
+ //
34
+ // The only time you need this is if your javascript is running on a different domain
35
+ // than your WebSocket server.
36
+ // Please think carefully about whether you really need this option before you use it.
37
+ // If you do, remember if you store secure data in cookies, you wil need to verify the
38
+ // Origin header.
39
+ InsecureSkipVerify bool
52
40
}
53
41
54
42
func verifyClientRequest (w http.ResponseWriter , r * http.Request ) error {
55
43
if ! headerValuesContainsToken (r .Header , "Connection" , "Upgrade" ) {
56
- err := xerrors .Errorf ("websocket: protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
44
+ err := xerrors .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
57
45
http .Error (w , err .Error (), http .StatusBadRequest )
58
46
return err
59
47
}
60
48
61
49
if ! headerValuesContainsToken (r .Header , "Upgrade" , "WebSocket" ) {
62
- err := xerrors .Errorf ("websocket: protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
50
+ err := xerrors .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
63
51
http .Error (w , err .Error (), http .StatusBadRequest )
64
52
return err
65
53
}
66
54
67
55
if r .Method != "GET" {
68
- err := xerrors .Errorf ("websocket: protocol violation: handshake request method %q is not GET" , r .Method )
56
+ err := xerrors .Errorf ("websocket protocol violation: handshake request method %q is not GET" , r .Method )
69
57
http .Error (w , err .Error (), http .StatusBadRequest )
70
58
return err
71
59
}
72
60
73
61
if r .Header .Get ("Sec-WebSocket-Version" ) != "13" {
74
- err := xerrors .Errorf ("websocket: unsupported protocol version: %q" , r .Header .Get ("Sec-WebSocket-Version" ))
62
+ err := xerrors .Errorf ("unsupported websocket protocol version: %q" , r .Header .Get ("Sec-WebSocket-Version" ))
75
63
http .Error (w , err .Error (), http .StatusBadRequest )
76
64
return err
77
65
}
78
66
79
67
if r .Header .Get ("Sec-WebSocket-Key" ) == "" {
80
- err := xerrors .New ("websocket: protocol violation: missing Sec-WebSocket-Key" )
68
+ err := xerrors .New ("websocket protocol violation: missing Sec-WebSocket-Key" )
81
69
http .Error (w , err .Error (), http .StatusBadRequest )
82
70
return err
83
71
}
@@ -88,26 +76,22 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
88
76
// Accept accepts a WebSocket handshake from a client and upgrades the
89
77
// the connection to WebSocket.
90
78
// Accept will reject the handshake if the Origin is not the same as the Host unless
91
- // the AcceptInsecureOrigin option is passed.
92
- // Accept uses w to write the handshake response so the timeouts on the http.Server apply.
93
- func Accept (w http.ResponseWriter , r * http.Request , opts ... AcceptOption ) (* Conn , error ) {
94
- var subprotocols []string
95
- verifyOrigin := true
96
- for _ , opt := range opts {
97
- switch opt := opt .(type ) {
98
- case acceptInsecureOrigin :
99
- verifyOrigin = false
100
- case acceptSubprotocols :
101
- subprotocols = []string (opt )
102
- }
79
+ // the InsecureSkipVerify option is set.
80
+ func Accept (w http.ResponseWriter , r * http.Request , opts AcceptOptions ) (* Conn , error ) {
81
+ c , err := accept (w , r , opts )
82
+ if err != nil {
83
+ return nil , xerrors .Errorf ("failed to accept websocket connection: %w" , err )
103
84
}
85
+ return c , nil
86
+ }
104
87
88
+ func accept (w http.ResponseWriter , r * http.Request , opts AcceptOptions ) (* Conn , error ) {
105
89
err := verifyClientRequest (w , r )
106
90
if err != nil {
107
91
return nil , err
108
92
}
109
93
110
- if verifyOrigin {
94
+ if ! opts . InsecureSkipVerify {
111
95
err = authenticateOrigin (r )
112
96
if err != nil {
113
97
http .Error (w , err .Error (), http .StatusForbidden )
@@ -117,7 +101,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
117
101
118
102
hj , ok := w .(http.Hijacker )
119
103
if ! ok {
120
- err = xerrors .New ("websocket: response writer does not implement http.Hijacker" )
104
+ err = xerrors .New ("response writer must implement http.Hijacker" )
121
105
http .Error (w , http .StatusText (http .StatusInternalServerError ), http .StatusInternalServerError )
122
106
return nil , err
123
107
}
@@ -127,7 +111,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
127
111
128
112
handleKey (w , r )
129
113
130
- subproto := selectSubprotocol (r , subprotocols )
114
+ subproto := selectSubprotocol (r , opts . Subprotocols )
131
115
if subproto != "" {
132
116
w .Header ().Set ("Sec-WebSocket-Protocol" , subproto )
133
117
}
@@ -136,7 +120,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
136
120
137
121
netConn , brw , err := hj .Hijack ()
138
122
if err != nil {
139
- err = xerrors .Errorf ("websocket: failed to hijack connection: %w" , err )
123
+ err = xerrors .Errorf ("failed to hijack connection: %w" , err )
140
124
http .Error (w , http .StatusText (http .StatusInternalServerError ), http .StatusInternalServerError )
141
125
return nil , err
142
126
}
@@ -190,5 +174,5 @@ func authenticateOrigin(r *http.Request) error {
190
174
if strings .EqualFold (u .Host , r .Host ) {
191
175
return nil
192
176
}
193
- return xerrors .Errorf ("request origin %q is not authorized" , origin )
177
+ return xerrors .Errorf ("request origin %q is not authorized for host %q " , origin , r . Host )
194
178
}
0 commit comments