@@ -5,6 +5,7 @@ package websocket
5
5
6
6
import (
7
7
"bytes"
8
+ "context"
8
9
"crypto/sha1"
9
10
"encoding/base64"
10
11
"errors"
@@ -14,7 +15,7 @@ import (
14
15
"net/http"
15
16
"net/textproto"
16
17
"net/url"
17
- "path/filepath "
18
+ "path"
18
19
"strings"
19
20
20
21
"github.com/coder/websocket/internal/errd"
@@ -41,8 +42,8 @@ type AcceptOptions struct {
41
42
// One would set this field to []string{"example.com"} to authorize example.com to connect.
42
43
//
43
44
// Each pattern is matched case insensitively against the request origin host
44
- // with filepath .Match.
45
- // See https://golang.org/pkg/path/filepath/ #Match
45
+ // with path .Match.
46
+ // See https://golang.org/pkg/path/#Match
46
47
//
47
48
// Please ensure you understand the ramifications of enabling this.
48
49
// If used incorrectly your WebSocket server will be open to CSRF attacks.
@@ -62,6 +63,22 @@ type AcceptOptions struct {
62
63
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
63
64
// for CompressionContextTakeover.
64
65
CompressionThreshold int
66
+
67
+ // OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
68
+ //
69
+ // The payload contains the application data of the ping frame.
70
+ // If the callback returns false, the subsequent pong frame will not be sent.
71
+ // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
72
+ OnPingReceived func (ctx context.Context , payload []byte ) bool
73
+
74
+ // OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
75
+ //
76
+ // The payload contains the application data of the pong frame.
77
+ // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
78
+ //
79
+ // Unlike OnPingReceived, this callback does not return a value because a pong frame
80
+ // is a response to a ping and does not trigger any further frame transmission.
81
+ OnPongReceived func (ctx context.Context , payload []byte )
65
82
}
66
83
67
84
func (opts * AcceptOptions ) cloneWithDefaults () * AcceptOptions {
@@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
79
96
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
80
97
//
81
98
// Accept will write a response to w on all errors.
99
+ //
100
+ // Note that using the http.Request Context after Accept returns may lead to
101
+ // unexpected behavior (see http.Hijacker).
82
102
func Accept (w http.ResponseWriter , r * http.Request , opts * AcceptOptions ) (* Conn , error ) {
83
103
return accept (w , r , opts )
84
104
}
@@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
96
116
if ! opts .InsecureSkipVerify {
97
117
err = authenticateOrigin (r , opts .OriginPatterns )
98
118
if err != nil {
99
- if errors .Is (err , filepath .ErrBadPattern ) {
119
+ if errors .Is (err , path .ErrBadPattern ) {
100
120
log .Printf ("websocket: %v" , err )
101
121
err = errors .New (http .StatusText (http .StatusForbidden ))
102
122
}
@@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
105
125
}
106
126
}
107
127
108
- hj , ok := w .(http. Hijacker )
128
+ hj , ok := hijacker ( w )
109
129
if ! ok {
110
130
err = errors .New ("http.ResponseWriter does not implement http.Hijacker" )
111
131
http .Error (w , http .StatusText (http .StatusNotImplemented ), http .StatusNotImplemented )
@@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
153
173
client : false ,
154
174
copts : copts ,
155
175
flateThreshold : opts .CompressionThreshold ,
176
+ onPingReceived : opts .OnPingReceived ,
177
+ onPongReceived : opts .OnPongReceived ,
156
178
157
179
br : brw .Reader ,
158
180
bw : brw .Writer ,
@@ -221,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
221
243
for _ , hostPattern := range originHosts {
222
244
matched , err := match (hostPattern , u .Host )
223
245
if err != nil {
224
- return fmt .Errorf ("failed to parse filepath pattern %q: %w" , hostPattern , err )
246
+ return fmt .Errorf ("failed to parse path pattern %q: %w" , hostPattern , err )
225
247
}
226
248
if matched {
227
249
return nil
@@ -234,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
234
256
}
235
257
236
258
func match (pattern , s string ) (bool , error ) {
237
- return filepath .Match (strings .ToLower (pattern ), strings .ToLower (s ))
259
+ return path .Match (strings .ToLower (pattern ), strings .ToLower (s ))
238
260
}
239
261
240
262
func selectSubprotocol (r * http.Request , subprotocols []string ) string {
0 commit comments