@@ -5,6 +5,7 @@ package websocket
55
66import (
77 "bytes"
8+ "context"
89 "crypto/sha1"
910 "encoding/base64"
1011 "errors"
@@ -14,7 +15,7 @@ import (
1415 "net/http"
1516 "net/textproto"
1617 "net/url"
17- "path/filepath "
18+ "path"
1819 "strings"
1920
2021 "github.com/coder/websocket/internal/errd"
@@ -41,8 +42,8 @@ type AcceptOptions struct {
4142 // One would set this field to []string{"example.com"} to authorize example.com to connect.
4243 //
4344 // Each pattern is matched case insensitively against the request origin host
44- // with filepath .Match.
45- // See https://golang.org/pkg/path/filepath/ #Match
45+ // with path .Match.
46+ // See https://golang.org/pkg/path/#Match
4647 //
4748 // Please ensure you understand the ramifications of enabling this.
4849 // If used incorrectly your WebSocket server will be open to CSRF attacks.
@@ -62,6 +63,22 @@ type AcceptOptions struct {
6263 // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
6364 // for CompressionContextTakeover.
6465 CompressionThreshold int
66+
67+ // OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
68+ //
69+ // The payload contains the application data of the ping frame.
70+ // If the callback returns false, the subsequent pong frame will not be sent.
71+ // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
72+ OnPingReceived func (ctx context.Context , payload []byte ) bool
73+
74+ // OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
75+ //
76+ // The payload contains the application data of the pong frame.
77+ // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
78+ //
79+ // Unlike OnPingReceived, this callback does not return a value because a pong frame
80+ // is a response to a ping and does not trigger any further frame transmission.
81+ OnPongReceived func (ctx context.Context , payload []byte )
6582}
6683
6784func (opts * AcceptOptions ) cloneWithDefaults () * AcceptOptions {
@@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
7996// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
8097//
8198// Accept will write a response to w on all errors.
99+ //
100+ // Note that using the http.Request Context after Accept returns may lead to
101+ // unexpected behavior (see http.Hijacker).
82102func Accept (w http.ResponseWriter , r * http.Request , opts * AcceptOptions ) (* Conn , error ) {
83103 return accept (w , r , opts )
84104}
@@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
96116 if ! opts .InsecureSkipVerify {
97117 err = authenticateOrigin (r , opts .OriginPatterns )
98118 if err != nil {
99- if errors .Is (err , filepath .ErrBadPattern ) {
119+ if errors .Is (err , path .ErrBadPattern ) {
100120 log .Printf ("websocket: %v" , err )
101121 err = errors .New (http .StatusText (http .StatusForbidden ))
102122 }
@@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
105125 }
106126 }
107127
108- hj , ok := w .(http. Hijacker )
128+ hj , ok := hijacker ( w )
109129 if ! ok {
110130 err = errors .New ("http.ResponseWriter does not implement http.Hijacker" )
111131 http .Error (w , http .StatusText (http .StatusNotImplemented ), http .StatusNotImplemented )
@@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
153173 client : false ,
154174 copts : copts ,
155175 flateThreshold : opts .CompressionThreshold ,
176+ onPingReceived : opts .OnPingReceived ,
177+ onPongReceived : opts .OnPongReceived ,
156178
157179 br : brw .Reader ,
158180 bw : brw .Writer ,
@@ -221,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
221243 for _ , hostPattern := range originHosts {
222244 matched , err := match (hostPattern , u .Host )
223245 if err != nil {
224- return fmt .Errorf ("failed to parse filepath pattern %q: %w" , hostPattern , err )
246+ return fmt .Errorf ("failed to parse path pattern %q: %w" , hostPattern , err )
225247 }
226248 if matched {
227249 return nil
@@ -234,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
234256}
235257
236258func match (pattern , s string ) (bool , error ) {
237- return filepath .Match (strings .ToLower (pattern ), strings .ToLower (s ))
259+ return path .Match (strings .ToLower (pattern ), strings .ToLower (s ))
238260}
239261
240262func selectSubprotocol (r * http.Request , subprotocols []string ) string {
0 commit comments