@@ -18,11 +18,15 @@ package rpc
18
18
19
19
import (
20
20
"context"
21
+ "io"
21
22
"net"
22
23
"net/http"
23
24
"net/http/httptest"
25
+ "net/http/httputil"
26
+ "net/url"
24
27
"reflect"
25
28
"strings"
29
+ "sync/atomic"
26
30
"testing"
27
31
"time"
28
32
@@ -188,6 +192,63 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
188
192
}
189
193
}
190
194
195
+ func TestClientWebsocketSevered (t * testing.T ) {
196
+ t .Parallel ()
197
+
198
+ var (
199
+ server = wsPingTestServer (t , nil )
200
+ ctx = context .Background ()
201
+ )
202
+ defer server .Shutdown (ctx )
203
+
204
+ u , err := url .Parse ("http://" + server .Addr )
205
+ if err != nil {
206
+ t .Fatal (err )
207
+ }
208
+ rproxy := httputil .NewSingleHostReverseProxy (u )
209
+ var severable * severableReadWriteCloser
210
+ rproxy .ModifyResponse = func (response * http.Response ) error {
211
+ severable = & severableReadWriteCloser {ReadWriteCloser : response .Body .(io.ReadWriteCloser )}
212
+ response .Body = severable
213
+ return nil
214
+ }
215
+ frontendProxy := httptest .NewServer (rproxy )
216
+ defer frontendProxy .Close ()
217
+
218
+ wsURL := "ws:" + strings .TrimPrefix (frontendProxy .URL , "http:" )
219
+ client , err := DialWebsocket (ctx , wsURL , "" )
220
+ if err != nil {
221
+ t .Fatalf ("client dial error: %v" , err )
222
+ }
223
+ defer client .Close ()
224
+
225
+ resultChan := make (chan int )
226
+ sub , err := client .EthSubscribe (ctx , resultChan , "foo" )
227
+ if err != nil {
228
+ t .Fatalf ("client subscribe error: %v" , err )
229
+ }
230
+
231
+ // sever the connection
232
+ severable .Sever ()
233
+
234
+ // Wait for subscription error.
235
+ timeout := time .NewTimer (3 * wsPingInterval )
236
+ defer timeout .Stop ()
237
+ for {
238
+ select {
239
+ case err := <- sub .Err ():
240
+ t .Log ("client subscription error:" , err )
241
+ return
242
+ case result := <- resultChan :
243
+ t .Error ("unexpected result:" , result )
244
+ return
245
+ case <- timeout .C :
246
+ t .Error ("didn't get any error within the test timeout" )
247
+ return
248
+ }
249
+ }
250
+ }
251
+
191
252
// wsPingTestServer runs a WebSocket server which accepts a single subscription request.
192
253
// When a value arrives on sendPing, the server sends a ping frame, waits for a matching
193
254
// pong and finally delivers a single subscription result.
@@ -290,3 +351,31 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-
290
351
}
291
352
}
292
353
}
354
+
355
+ // severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty.
356
+ type severableReadWriteCloser struct {
357
+ io.ReadWriteCloser
358
+ severed int32 // atomic
359
+ }
360
+
361
+ func (s * severableReadWriteCloser ) Sever () {
362
+ atomic .StoreInt32 (& s .severed , 1 )
363
+ }
364
+
365
+ func (s * severableReadWriteCloser ) Read (p []byte ) (n int , err error ) {
366
+ if atomic .LoadInt32 (& s .severed ) > 0 {
367
+ return 0 , nil
368
+ }
369
+ return s .ReadWriteCloser .Read (p )
370
+ }
371
+
372
+ func (s * severableReadWriteCloser ) Write (p []byte ) (n int , err error ) {
373
+ if atomic .LoadInt32 (& s .severed ) > 0 {
374
+ return len (p ), nil
375
+ }
376
+ return s .ReadWriteCloser .Write (p )
377
+ }
378
+
379
+ func (s * severableReadWriteCloser ) Close () error {
380
+ return s .ReadWriteCloser .Close ()
381
+ }
0 commit comments