Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions pkg/wsclient/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type WSConfig struct {
WriteBufferSize int `json:"writeBufferSize,omitempty"`
InitialDelay time.Duration `json:"initialDelay,omitempty"`
MaximumDelay time.Duration `json:"maximumDelay,omitempty"`
DelayFactor float64 `json:"delayFactor,omitempty"`
InitialConnectAttempts int `json:"initialConnectAttempts,omitempty"`
DisableReconnect bool `json:"disableReconnect"`
AuthUsername string `json:"authUsername,omitempty"`
Expand All @@ -59,6 +60,14 @@ type WSConfig struct {
ReceiveExt bool
}

type WSWrapConfig struct {
HeartbeatInterval time.Duration `json:"heartbeatInterval,omitempty"`
ThrottleRequestsPerSecond int `json:"requestsPerSecond,omitempty"`
ThrottleBurst int `json:"burst,omitempty"`
// This one cannot be set in JSON - must be configured on the code interface
ReceiveExt bool
}

// WSPayload allows API consumers of this package to stream data, and inspect the message
// type, rather than just being passed the bytes directly.
type WSPayload struct {
Expand Down Expand Up @@ -98,7 +107,7 @@ type wsClient struct {
initialRetryAttempts int
wsdialer *websocket.Dialer
wsconn *websocket.Conn
retry retry.Retry
connRetry retry.Retry
closed bool
useReceiveExt bool
receive chan []byte
Expand All @@ -122,6 +131,7 @@ type WSPreConnectHandler func(ctx context.Context, w WSClient) error
// WSPostConnectHandler will be called after every connect/reconnect. Can send data over ws, but must not block listening for data on the ws.
type WSPostConnectHandler func(ctx context.Context, w WSClient) error

// Creates a new outbound client that can be connected to a remote server
func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandler, afterConnect WSPostConnectHandler) (WSClient, error) {
l := log.L(ctx)
wsURL, err := buildWSUrl(ctx, config)
Expand All @@ -138,9 +148,10 @@ func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandle
TLSClientConfig: config.TLSClientConfig,
HandshakeTimeout: config.ConnectionTimeout,
},
retry: retry.Retry{
connRetry: retry.Retry{
InitialDelay: config.InitialDelay,
MaximumDelay: config.MaximumDelay,
Factor: config.DelayFactor,
},
initialRetryAttempts: config.InitialConnectAttempts,
headers: make(http.Header),
Expand All @@ -153,11 +164,7 @@ func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandle
disableReconnect: config.DisableReconnect,
rateLimiter: ffresty.GetRateLimiter(config.ThrottleRequestsPerSecond, config.ThrottleBurst),
}
if w.useReceiveExt {
w.receiveExt = make(chan *WSPayload)
} else {
w.receive = make(chan []byte)
}
w.setupReceiveChannel()
for k, v := range config.HTTPHeaders {
if vs, ok := v.(string); ok {
w.headers.Set(k, vs)
Expand All @@ -182,6 +189,40 @@ func New(ctx context.Context, config *WSConfig, beforeConnect WSPreConnectHandle
return w, nil
}

// Wrap an existing connection (including an inbound server connection) with heartbeating and throttling.
// No reconnect functions are supported when wrapping an existing connection like this, but the supplied
// callback will be invoked when the connection closes (allowing cleanup/tracking).
func Wrap(ctx context.Context, config WSWrapConfig, wsconn *websocket.Conn, onClose func()) WSClient {
w := &wsClient{
ctx: ctx,
url: wsconn.LocalAddr().String(),
wsconn: wsconn,
disableReconnect: true,
heartbeatInterval: config.HeartbeatInterval,
rateLimiter: ffresty.GetRateLimiter(config.ThrottleRequestsPerSecond, config.ThrottleBurst),
useReceiveExt: config.ReceiveExt,
send: make(chan []byte),
closing: make(chan struct{}),
}
w.setupReceiveChannel()
w.pongReceivedOrReset(false)
w.wsconn.SetPongHandler(w.pongHandler)
log.L(ctx).Infof("WS %s wrapped", w.url)
go func() {
w.receiveReconnectLoop()
onClose()
}()
return w
}

func (w *wsClient) setupReceiveChannel() {
if w.useReceiveExt {
w.receiveExt = make(chan *WSPayload)
} else {
w.receive = make(chan []byte)
}
}

func (w *wsClient) Connect() error {

if err := w.connect(true); err != nil {
Expand Down Expand Up @@ -291,7 +332,7 @@ func buildWSUrl(ctx context.Context, config *WSConfig) (string, error) {

func (w *wsClient) connect(initial bool) error {
l := log.L(w.ctx)
return w.retry.DoCustomLog(w.ctx, func(attempt int) (retry bool, err error) {
return w.connRetry.DoCustomLog(w.ctx, func(attempt int) (retry bool, err error) {
if w.closed {
return false, i18n.NewError(w.ctx, i18n.MsgWSClosing)
}
Expand Down Expand Up @@ -436,6 +477,7 @@ func (w *wsClient) sendLoop(receiverDone chan struct{}) {
l.Errorf("WS %s closing: %s", w.url, err)
disconnecting = true
} else if wsconn != nil {
l.Debugf("WS %s send heartbeat ping", w.url)
if err := wsconn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
l.Errorf("WS %s heartbeat send failed: %s", w.url, err)
disconnecting = true
Expand Down
50 changes: 50 additions & 0 deletions pkg/wsclient/wsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import (
"time"

"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)

Expand Down Expand Up @@ -838,3 +840,51 @@ func TestRateLimiterFailure(t *testing.T) {
// Close the client
wsc.Close()
}

func TestWSWrap(t *testing.T) {
ctx := context.Background()
logrus.SetLevel(logrus.DebugLevel)

passWS := make(chan (*websocket.Conn))
svr := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
upgrader := &websocket.Upgrader{WriteBufferSize: 1024, ReadBufferSize: 1024}
ws, err := upgrader.Upgrade(res, req, http.Header{})
require.NoError(t, err)
passWS <- ws
}))
defer svr.Close()

clientDone := make(chan struct{})
go func() {
defer close(clientDone)

wsc, err := New(ctx, &WSConfig{HTTPURL: svr.URL}, nil, nil)
require.NoError(t, err)
err = wsc.Connect()
require.NoError(t, err)

wsc.Send(ctx, []byte(`hello`))
msg1 := <-wsc.Receive()
require.Equal(t, `hi`, string(msg1))

wsc.Close()

}()

// Get the conn
rawWSC := <-passWS

// Wrap it
serverDone := make(chan struct{})
wsc := Wrap(ctx, WSWrapConfig{}, rawWSC, func() {
close(serverDone)
})

msg1 := <-wsc.Receive()
require.Equal(t, `hello`, string(msg1))
err := wsc.Send(ctx, []byte(`hi`))
require.NoError(t, err)

<-clientDone
<-serverDone
}
10 changes: 10 additions & 0 deletions pkg/wsclient/wsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
defaultBufferSize = "16Kb"
defaultHeartbeatInterval = "30s" // up to a minute to detect a dead connection
defaultConnectionTimeout = 45 * time.Second // 45 seconds - the built in default for gorilla/websocket
defaultRetryBackoffFactor = 2.0
)

const (
Expand All @@ -49,6 +50,8 @@ const (
WSConfigKeyHeartbeatInterval = "ws.heartbeatInterval"
// WSConnectionTimeout is the amount of time to wait while attempting to establish a connection (or automatic reconnection)
WSConfigKeyConnectionTimeout = "ws.connectionTimeout"
// WSConfigDelayFactor the exponential backoff factor for delay
WSConfigDelayFactor = "retry.factor"
)

// InitConfig ensures the config is initialized for HTTP too, as WS and HTTP
Expand All @@ -62,6 +65,12 @@ func InitConfig(conf config.Section) {
conf.AddKnownKey(WSConfigURL)
conf.AddKnownKey(WSConfigKeyHeartbeatInterval, defaultHeartbeatInterval)
conf.AddKnownKey(WSConfigKeyConnectionTimeout, defaultConnectionTimeout)
conf.AddKnownKey(WSConfigDelayFactor, defaultRetryBackoffFactor)
InitConfigWrap(conf)
}

func InitConfigWrap(conf config.Section) {
conf.AddKnownKey(WSConfigKeyHeartbeatInterval, defaultHeartbeatInterval)
}

func GenerateConfig(ctx context.Context, conf config.Section) (*WSConfig, error) {
Expand All @@ -73,6 +82,7 @@ func GenerateConfig(ctx context.Context, conf config.Section) (*WSConfig, error)
WriteBufferSize: int(conf.GetByteSize(WSConfigKeyWriteBufferSize)),
InitialDelay: conf.GetDuration(ffresty.HTTPConfigRetryInitDelay),
MaximumDelay: conf.GetDuration(ffresty.HTTPConfigRetryMaxDelay),
DelayFactor: conf.GetFloat64(WSConfigDelayFactor),
InitialConnectAttempts: conf.GetInt(WSConfigKeyInitialConnectAttempts),
HTTPHeaders: conf.GetObject(ffresty.HTTPConfigHeaders),
AuthUsername: conf.GetString(ffresty.HTTPConfigAuthUsername),
Expand Down