Skip to content

Commit 51ececb

Browse files
authored
rpc: set pong read deadline (#23556)
This PR adds a 30s timeout for the remote part to answer a ping message, thus detecting (silent) disconnnects
1 parent 57a3fab commit 51ececb

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

rpc/websocket.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const (
3737
wsWriteBuffer = 1024
3838
wsPingInterval = 60 * time.Second
3939
wsPingWriteTimeout = 5 * time.Second
40+
wsPongTimeout = 30 * time.Second
4041
wsMessageSizeLimit = 15 * 1024 * 1024
4142
)
4243

@@ -241,6 +242,10 @@ type websocketCodec struct {
241242

242243
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
243244
conn.SetReadLimit(wsMessageSizeLimit)
245+
conn.SetPongHandler(func(appData string) error {
246+
conn.SetReadDeadline(time.Time{})
247+
return nil
248+
})
244249
wc := &websocketCodec{
245250
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
246251
conn: conn,
@@ -287,6 +292,7 @@ func (wc *websocketCodec) pingLoop() {
287292
wc.jsonCodec.encMu.Lock()
288293
wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
289294
wc.conn.WriteMessage(websocket.PingMessage, nil)
295+
wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout))
290296
wc.jsonCodec.encMu.Unlock()
291297
timer.Reset(wsPingInterval)
292298
}

rpc/websocket_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ package rpc
1818

1919
import (
2020
"context"
21+
"io"
2122
"net"
2223
"net/http"
2324
"net/http/httptest"
25+
"net/http/httputil"
26+
"net/url"
2427
"reflect"
2528
"strings"
29+
"sync/atomic"
2630
"testing"
2731
"time"
2832

@@ -188,6 +192,63 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
188192
}
189193
}
190194

195+
func TestClientWebsocketSevered(t *testing.T) {
196+
t.Parallel()
197+
198+
var (
199+
server = wsPingTestServer(t, nil)
200+
ctx = context.Background()
201+
)
202+
defer server.Shutdown(ctx)
203+
204+
u, err := url.Parse("http://" + server.Addr)
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
rproxy := httputil.NewSingleHostReverseProxy(u)
209+
var severable *severableReadWriteCloser
210+
rproxy.ModifyResponse = func(response *http.Response) error {
211+
severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)}
212+
response.Body = severable
213+
return nil
214+
}
215+
frontendProxy := httptest.NewServer(rproxy)
216+
defer frontendProxy.Close()
217+
218+
wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:")
219+
client, err := DialWebsocket(ctx, wsURL, "")
220+
if err != nil {
221+
t.Fatalf("client dial error: %v", err)
222+
}
223+
defer client.Close()
224+
225+
resultChan := make(chan int)
226+
sub, err := client.EthSubscribe(ctx, resultChan, "foo")
227+
if err != nil {
228+
t.Fatalf("client subscribe error: %v", err)
229+
}
230+
231+
// sever the connection
232+
severable.Sever()
233+
234+
// Wait for subscription error.
235+
timeout := time.NewTimer(3 * wsPingInterval)
236+
defer timeout.Stop()
237+
for {
238+
select {
239+
case err := <-sub.Err():
240+
t.Log("client subscription error:", err)
241+
return
242+
case result := <-resultChan:
243+
t.Error("unexpected result:", result)
244+
return
245+
case <-timeout.C:
246+
t.Error("didn't get any error within the test timeout")
247+
return
248+
}
249+
}
250+
}
251+
191252
// wsPingTestServer runs a WebSocket server which accepts a single subscription request.
192253
// When a value arrives on sendPing, the server sends a ping frame, waits for a matching
193254
// pong and finally delivers a single subscription result.
@@ -290,3 +351,31 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-
290351
}
291352
}
292353
}
354+
355+
// severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty.
356+
type severableReadWriteCloser struct {
357+
io.ReadWriteCloser
358+
severed int32 // atomic
359+
}
360+
361+
func (s *severableReadWriteCloser) Sever() {
362+
atomic.StoreInt32(&s.severed, 1)
363+
}
364+
365+
func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) {
366+
if atomic.LoadInt32(&s.severed) > 0 {
367+
return 0, nil
368+
}
369+
return s.ReadWriteCloser.Read(p)
370+
}
371+
372+
func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) {
373+
if atomic.LoadInt32(&s.severed) > 0 {
374+
return len(p), nil
375+
}
376+
return s.ReadWriteCloser.Write(p)
377+
}
378+
379+
func (s *severableReadWriteCloser) Close() error {
380+
return s.ReadWriteCloser.Close()
381+
}

0 commit comments

Comments
 (0)