Skip to content

Commit 2436baf

Browse files
Add support for wrapping an existing WS connection
Signed-off-by: Peter Broadhurst <peter.broadhurst@kaleido.io>
1 parent 698d96a commit 2436baf

File tree

3 files changed

+103
-8
lines changed

3 files changed

+103
-8
lines changed

pkg/wsclient/wsclient.go

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ type WSConfig struct {
5959
ReceiveExt bool
6060
}
6161

62+
type WSWrapConfig struct {
63+
HeartbeatInterval time.Duration `json:"heartbeatInterval,omitempty"`
64+
ThrottleRequestsPerSecond int `json:"requestsPerSecond,omitempty"`
65+
ThrottleBurst int `json:"burst,omitempty"`
66+
// This one cannot be set in JSON - must be configured on the code interface
67+
ReceiveExt bool
68+
}
69+
6270
// WSPayload allows API consumers of this package to stream data, and inspect the message
6371
// type, rather than just being passed the bytes directly.
6472
type WSPayload struct {
@@ -98,7 +106,7 @@ type wsClient struct {
98106
initialRetryAttempts int
99107
wsdialer *websocket.Dialer
100108
wsconn *websocket.Conn
101-
retry retry.Retry
109+
connRetry retry.Retry
102110
closed bool
103111
useReceiveExt bool
104112
receive chan []byte
@@ -122,6 +130,7 @@ type WSPreConnectHandler func(ctx context.Context, w WSClient) error
122130
// WSPostConnectHandler will be called after every connect/reconnect. Can send data over ws, but must not block listening for data on the ws.
123131
type WSPostConnectHandler func(ctx context.Context, w WSClient) error
124132

133+
// Creates a new outbound client that can be connected to a remote server
125134
func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandler, afterConnect WSPostConnectHandler) (WSClient, error) {
126135
l := log.L(ctx)
127136
wsURL, err := buildWSUrl(ctx, config)
@@ -138,7 +147,7 @@ func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandle
138147
TLSClientConfig: config.TLSClientConfig,
139148
HandshakeTimeout: config.ConnectionTimeout,
140149
},
141-
retry: retry.Retry{
150+
connRetry: retry.Retry{
142151
InitialDelay: config.InitialDelay,
143152
MaximumDelay: config.MaximumDelay,
144153
},
@@ -153,11 +162,7 @@ func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandle
153162
disableReconnect: config.DisableReconnect,
154163
rateLimiter: ffresty.GetRateLimiter(config.ThrottleRequestsPerSecond, config.ThrottleBurst),
155164
}
156-
if w.useReceiveExt {
157-
w.receiveExt = make(chan *WSPayload)
158-
} else {
159-
w.receive = make(chan []byte)
160-
}
165+
w.setupReceiveChannel()
161166
for k, v := range config.HTTPHeaders {
162167
if vs, ok := v.(string); ok {
163168
w.headers.Set(k, vs)
@@ -182,6 +187,40 @@ func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandle
182187
return w, nil
183188
}
184189

190+
// Wrap an existing connection (including an inbound server connection) with heartbeating and throttling.
191+
// No reconnect functions are supported when wrapping an existing connection like this, but the supplied
192+
// callback will be invoked when the connection closes (allowing cleanup/tracking).
193+
func Wrap(ctx context.Context, config WSWrapConfig, wsconn *websocket.Conn, onClose func()) WSClient {
194+
w := &wsClient{
195+
ctx: ctx,
196+
url: wsconn.LocalAddr().String(),
197+
wsconn: wsconn,
198+
disableReconnect: true,
199+
heartbeatInterval: config.HeartbeatInterval,
200+
rateLimiter: ffresty.GetRateLimiter(config.ThrottleRequestsPerSecond, config.ThrottleBurst),
201+
useReceiveExt: config.ReceiveExt,
202+
send: make(chan []byte),
203+
closing: make(chan struct{}),
204+
}
205+
w.setupReceiveChannel()
206+
w.pongReceivedOrReset(false)
207+
w.wsconn.SetPongHandler(w.pongHandler)
208+
log.L(ctx).Infof("WS %s wrapped", w.url)
209+
go func() {
210+
w.receiveReconnectLoop()
211+
onClose()
212+
}()
213+
return w
214+
}
215+
216+
func (w *wsClient) setupReceiveChannel() {
217+
if w.useReceiveExt {
218+
w.receiveExt = make(chan *WSPayload)
219+
} else {
220+
w.receive = make(chan []byte)
221+
}
222+
}
223+
185224
func (w *wsClient) Connect() error {
186225

187226
if err := w.connect(true); err != nil {
@@ -291,7 +330,7 @@ func buildWSUrl(ctx context.Context, config *WSConfig) (string, error) {
291330

292331
func (w *wsClient) connect(initial bool) error {
293332
l := log.L(w.ctx)
294-
return w.retry.DoCustomLog(w.ctx, func(attempt int) (retry bool, err error) {
333+
return w.connRetry.DoCustomLog(w.ctx, func(attempt int) (retry bool, err error) {
295334
if w.closed {
296335
return false, i18n.NewError(w.ctx, i18n.MsgWSClosing)
297336
}
@@ -436,6 +475,7 @@ func (w *wsClient) sendLoop(receiverDone chan struct{}) {
436475
l.Errorf("WS %s closing: %s", w.url, err)
437476
disconnecting = true
438477
} else if wsconn != nil {
478+
l.Debugf("WS %s send heartbeat ping", w.url)
439479
if err := wsconn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
440480
l.Errorf("WS %s heartbeat send failed: %s", w.url, err)
441481
disconnecting = true

pkg/wsclient/wsclient_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ import (
3030
"time"
3131

3232
"github.com/gorilla/websocket"
33+
"github.com/sirupsen/logrus"
3334
"github.com/stretchr/testify/assert"
35+
"github.com/stretchr/testify/require"
3436
"golang.org/x/time/rate"
3537
)
3638

@@ -838,3 +840,51 @@ func TestRateLimiterFailure(t *testing.T) {
838840
// Close the client
839841
wsc.Close()
840842
}
843+
844+
func TestWSWrap(t *testing.T) {
845+
ctx := context.Background()
846+
logrus.SetLevel(logrus.DebugLevel)
847+
848+
passWS := make(chan (*websocket.Conn))
849+
svr := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
850+
upgrader := &websocket.Upgrader{WriteBufferSize: 1024, ReadBufferSize: 1024}
851+
ws, err := upgrader.Upgrade(res, req, http.Header{})
852+
require.NoError(t, err)
853+
passWS <- ws
854+
}))
855+
defer svr.Close()
856+
857+
clientDone := make(chan struct{})
858+
go func() {
859+
defer close(clientDone)
860+
861+
wsc, err := New(ctx, &WSConfig{HTTPURL: svr.URL}, nil, nil)
862+
require.NoError(t, err)
863+
err = wsc.Connect()
864+
require.NoError(t, err)
865+
866+
wsc.Send(ctx, []byte(`hello`))
867+
msg1 := <-wsc.Receive()
868+
require.Equal(t, `hi`, string(msg1))
869+
870+
wsc.Close()
871+
872+
}()
873+
874+
// Get the conn
875+
rawWSC := <-passWS
876+
877+
// Wrap it
878+
serverDone := make(chan struct{})
879+
wsc := Wrap(ctx, WSWrapConfig{}, rawWSC, func() {
880+
close(serverDone)
881+
})
882+
883+
msg1 := <-wsc.Receive()
884+
require.Equal(t, `hello`, string(msg1))
885+
err := wsc.Send(ctx, []byte(`hi`))
886+
require.NoError(t, err)
887+
888+
<-clientDone
889+
<-serverDone
890+
}

pkg/wsclient/wsconfig.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ func InitConfig(conf config.Section) {
6262
conf.AddKnownKey(WSConfigURL)
6363
conf.AddKnownKey(WSConfigKeyHeartbeatInterval, defaultHeartbeatInterval)
6464
conf.AddKnownKey(WSConfigKeyConnectionTimeout, defaultConnectionTimeout)
65+
InitConfigWrap(conf)
66+
}
67+
68+
func InitConfigWrap(conf config.Section) {
69+
conf.AddKnownKey(WSConfigKeyHeartbeatInterval, defaultHeartbeatInterval)
6570
}
6671

6772
func GenerateConfig(ctx context.Context, conf config.Section) (*WSConfig, error) {

0 commit comments

Comments
 (0)