@@ -18,11 +18,15 @@ package rpc
1818
1919import (
2020 "context"
21+ "io"
2122 "net"
2223 "net/http"
2324 "net/http/httptest"
25+ "net/http/httputil"
26+ "net/url"
2427 "reflect"
2528 "strings"
29+ "sync/atomic"
2630 "testing"
2731 "time"
2832
@@ -188,6 +192,63 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
188192 }
189193}
190194
195+ func TestClientWebsocketSevered (t * testing.T ) {
196+ t .Parallel ()
197+
198+ var (
199+ server = wsPingTestServer (t , nil )
200+ ctx = context .Background ()
201+ )
202+ defer server .Shutdown (ctx )
203+
204+ u , err := url .Parse ("http://" + server .Addr )
205+ if err != nil {
206+ t .Fatal (err )
207+ }
208+ rproxy := httputil .NewSingleHostReverseProxy (u )
209+ var severable * severableReadWriteCloser
210+ rproxy .ModifyResponse = func (response * http.Response ) error {
211+ severable = & severableReadWriteCloser {ReadWriteCloser : response .Body .(io.ReadWriteCloser )}
212+ response .Body = severable
213+ return nil
214+ }
215+ frontendProxy := httptest .NewServer (rproxy )
216+ defer frontendProxy .Close ()
217+
218+ wsURL := "ws:" + strings .TrimPrefix (frontendProxy .URL , "http:" )
219+ client , err := DialWebsocket (ctx , wsURL , "" )
220+ if err != nil {
221+ t .Fatalf ("client dial error: %v" , err )
222+ }
223+ defer client .Close ()
224+
225+ resultChan := make (chan int )
226+ sub , err := client .EthSubscribe (ctx , resultChan , "foo" )
227+ if err != nil {
228+ t .Fatalf ("client subscribe error: %v" , err )
229+ }
230+
231+ // sever the connection
232+ severable .Sever ()
233+
234+ // Wait for subscription error.
235+ timeout := time .NewTimer (3 * wsPingInterval )
236+ defer timeout .Stop ()
237+ for {
238+ select {
239+ case err := <- sub .Err ():
240+ t .Log ("client subscription error:" , err )
241+ return
242+ case result := <- resultChan :
243+ t .Error ("unexpected result:" , result )
244+ return
245+ case <- timeout .C :
246+ t .Error ("didn't get any error within the test timeout" )
247+ return
248+ }
249+ }
250+ }
251+
191252// wsPingTestServer runs a WebSocket server which accepts a single subscription request.
192253// When a value arrives on sendPing, the server sends a ping frame, waits for a matching
193254// pong and finally delivers a single subscription result.
@@ -290,3 +351,31 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-
290351 }
291352 }
292353}
354+
355+ // severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty.
356+ type severableReadWriteCloser struct {
357+ io.ReadWriteCloser
358+ severed int32 // atomic
359+ }
360+
361+ func (s * severableReadWriteCloser ) Sever () {
362+ atomic .StoreInt32 (& s .severed , 1 )
363+ }
364+
365+ func (s * severableReadWriteCloser ) Read (p []byte ) (n int , err error ) {
366+ if atomic .LoadInt32 (& s .severed ) > 0 {
367+ return 0 , nil
368+ }
369+ return s .ReadWriteCloser .Read (p )
370+ }
371+
372+ func (s * severableReadWriteCloser ) Write (p []byte ) (n int , err error ) {
373+ if atomic .LoadInt32 (& s .severed ) > 0 {
374+ return len (p ), nil
375+ }
376+ return s .ReadWriteCloser .Write (p )
377+ }
378+
379+ func (s * severableReadWriteCloser ) Close () error {
380+ return s .ReadWriteCloser .Close ()
381+ }
0 commit comments