diff --git a/README.md b/README.md index a47fa0c2..b8dde3bd 100644 --- a/README.md +++ b/README.md @@ -389,18 +389,27 @@ own logging system. ### Websocket ping-pong -The websocket package currently supports client-initiated pings only. +The websocket package supports configuring ping pong for both endpoints. -If your setup requires the server to be the initiator of a ping-pong (e.g. for web-based charge points), -you may disable ping-pong entirely and just rely on the heartbeat mechanism: +By default, the client sends a ping every 54 seconds and waits for a pong for 60 seconds, before timing out. +The values can be configured as follows: +```go +cfg := ws.NewClientTimeoutConfig() +cfg.PingPeriod = 10 * time.Second +cfg.PongWait = 20 * time.Second +websocketClient.SetTimeoutConfig(cfg) +``` +By default, the server does not send out any pings and waits for a ping from the client for 60 seconds, before timing out. +To configure the server to send out pings, the `PingPeriod` and `PongWait` must be set to a value greater than 0: ```go cfg := ws.NewServerTimeoutConfig() -cfg.PingWait = 0 // this instructs the server to wait forever +cfg.PingPeriod = 10 * time.Second +cfg.PongWait = 20 * time.Second websocketServer.SetTimeoutConfig(cfg) ``` -> A server-initiated ping may be supported in a future release. +To disable sending ping messages, set the `PingPeriod` value to `0`. ## OCPP 2.0.1 Usage diff --git a/example/1.6/cp/charge_point_sim.go b/example/1.6/cp/charge_point_sim.go index eb4b3434..ffad0c1f 100644 --- a/example/1.6/cp/charge_point_sim.go +++ b/example/1.6/cp/charge_point_sim.go @@ -64,10 +64,10 @@ func setupTlsChargePoint(chargePointID string) ocpp16.ChargePoint { } } // Create client with TLS config - client := ws.NewTLSClient(&tls.Config{ + client := ws.NewClient(ws.WithClientTLSConfig(&tls.Config{ RootCAs: certPool, Certificates: clientCertificates, - }) + })) return ocpp16.NewChargePoint(chargePointID, nil, client) } diff --git a/example/1.6/cs/central_system_sim.go b/example/1.6/cs/central_system_sim.go index 6417abf7..8eeee105 100644 --- a/example/1.6/cs/central_system_sim.go +++ b/example/1.6/cs/central_system_sim.go @@ -67,10 +67,10 @@ func setupTlsCentralSystem() ocpp16.CentralSystem { if !ok { log.Fatalf("no required %v found", envVarServerCertificateKey) } - server := ws.NewTLSServer(certificate, key, &tls.Config{ + server := ws.NewServer(ws.WithServerTLSConfig(certificate, key, &tls.Config{ ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: certPool, - }) + })) return ocpp16.NewCentralSystem(nil, server) } diff --git a/example/2.0.1/chargingstation/charging_station_sim.go b/example/2.0.1/chargingstation/charging_station_sim.go index 45b94790..091723ea 100644 --- a/example/2.0.1/chargingstation/charging_station_sim.go +++ b/example/2.0.1/chargingstation/charging_station_sim.go @@ -66,10 +66,10 @@ func setupTlsChargingStation(chargingStationID string) ocpp2.ChargingStation { } } // Create client with TLS config - client := ws.NewTLSClient(&tls.Config{ + client := ws.NewClient(ws.WithClientTLSConfig(&tls.Config{ RootCAs: certPool, Certificates: clientCertificates, - }) + })) return ocpp2.NewChargingStation(chargingStationID, nil, client) } diff --git a/example/2.0.1/csms/csms_sim.go b/example/2.0.1/csms/csms_sim.go index c6ddd050..57b88ec9 100644 --- a/example/2.0.1/csms/csms_sim.go +++ b/example/2.0.1/csms/csms_sim.go @@ -70,10 +70,10 @@ func setupTlsCentralSystem() ocpp2.CSMS { if !ok { log.Fatalf("no required %v found", envVarServerCertificateKey) } - server := ws.NewTLSServer(certificate, key, &tls.Config{ + server := ws.NewServer(ws.WithServerTLSConfig(certificate, key, &tls.Config{ ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: certPool, - }) + })) return ocpp2.NewCSMS(nil, server) } diff --git a/ocpp1.6/v16.go b/ocpp1.6/v16.go index e9a0721c..4fada30f 100644 --- a/ocpp1.6/v16.go +++ b/ocpp1.6/v16.go @@ -153,13 +153,13 @@ type ChargePoint interface { // if !ok { // log.Fatal("couldn't parse PEM certificate") // } -// cp := NewClient("someUniqueId", nil, ws.NewTLSClient(&tls.Config{ +// cp := NewClient("someUniqueId", nil, ws.NewClient(ws.WithClientTLSConfig(&tls.Config{ // RootCAs: certPool, -// }) +// })) // // For more advanced options, or if a customer networking/occpj layer is required, -// please refer to ocppj.Client and ws.WsClient. -func NewChargePoint(id string, endpoint *ocppj.Client, client ws.WsClient) ChargePoint { +// please refer to ocppj.Client and ws.Client. +func NewChargePoint(id string, endpoint *ocppj.Client, client ws.Client) ChargePoint { if client == nil { client = ws.NewClient() } @@ -338,8 +338,8 @@ type CentralSystem interface { // // If you need a TLS server, you may use the following: // -// cs := NewServer(nil, ws.NewTLSServer("certificatePath", "privateKeyPath")) -func NewCentralSystem(endpoint *ocppj.Server, server ws.WsServer) CentralSystem { +// cs := NewServer(nil, ws.NewServer(ws.WithServerTLSConfig("certificatePath", "privateKeyPath", nil))) +func NewCentralSystem(endpoint *ocppj.Server, server ws.Server) CentralSystem { if server == nil { server = ws.NewServer() } diff --git a/ocpp1.6_test/ocpp16_test.go b/ocpp1.6_test/ocpp16_test.go index 9a701f58..a46ee6ee 100644 --- a/ocpp1.6_test/ocpp16_test.go +++ b/ocpp1.6_test/ocpp16_test.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "fmt" "net" - "net/http" "reflect" "testing" "time" @@ -50,6 +49,10 @@ func (websocket MockWebSocket) TLSConnectionState() *tls.ConnectionState { return nil } +func (websocket MockWebSocket) IsConnected() bool { + return true +} + func NewMockWebSocket(id string) MockWebSocket { return MockWebSocket{id: id} } @@ -57,7 +60,7 @@ func NewMockWebSocket(id string) MockWebSocket { // ---------------------- MOCK WEBSOCKET SERVER ---------------------- type MockWebsocketServer struct { mock.Mock - ws.WsServer + ws.Server MessageHandler func(ws ws.Channel, data []byte) error NewClientHandler func(ws ws.Channel) CheckClientHandler ws.CheckClientHandler @@ -77,11 +80,11 @@ func (websocketServer *MockWebsocketServer) Write(webSocketId string, data []byt return args.Error(0) } -func (websocketServer *MockWebsocketServer) SetMessageHandler(handler func(ws ws.Channel, data []byte) error) { +func (websocketServer *MockWebsocketServer) SetMessageHandler(handler ws.MessageHandler) { websocketServer.MessageHandler = handler } -func (websocketServer *MockWebsocketServer) SetNewClientHandler(handler func(ws ws.Channel)) { +func (websocketServer *MockWebsocketServer) SetNewClientHandler(handler ws.ConnectedHandler) { websocketServer.NewClientHandler = handler } @@ -96,14 +99,14 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client websocketServer.MethodCalled("NewClient", websocketId, client) } -func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { +func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler ws.CheckClientHandler) { websocketServer.CheckClientHandler = handler } // ---------------------- MOCK WEBSOCKET CLIENT ---------------------- type MockWebsocketClient struct { mock.Mock - ws.WsClient + ws.Client MessageHandler func(data []byte) error ReconnectedHandler func() DisconnectedHandler func(err error) @@ -445,41 +448,6 @@ func (smartChargingListener *MockChargePointSmartChargingListener) OnGetComposit } // ---------------------- COMMON UTILITY METHODS ---------------------- -func NewWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) *ws.Server { - wsServer := ws.Server{} - wsServer.SetMessageHandler(func(ws ws.Channel, data []byte) error { - assert.NotNil(t, ws) - assert.NotNil(t, data) - if onMessage != nil { - response, err := onMessage(data) - assert.Nil(t, err) - if response != nil { - err = wsServer.Write(ws.ID(), data) - assert.Nil(t, err) - } - } - return nil - }) - return &wsServer -} - -func NewWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) *ws.Client { - wsClient := ws.Client{} - wsClient.SetMessageHandler(func(data []byte) error { - assert.NotNil(t, data) - if onMessage != nil { - response, err := onMessage(data) - assert.Nil(t, err) - if response != nil { - err = wsClient.Write(data) - assert.Nil(t, err) - } - } - return nil - }) - return &wsClient -} - type expectedCentralSystemOptions struct { clientId string rawWrittenMessage []byte diff --git a/ocpp2.0.1/v2.go b/ocpp2.0.1/v2.go index 338c7eb6..3bc38ac3 100644 --- a/ocpp2.0.1/v2.go +++ b/ocpp2.0.1/v2.go @@ -202,13 +202,13 @@ type ChargingStation interface { // if !ok { // log.Fatal("couldn't parse PEM certificate") // } -// cs := NewChargingStation("someUniqueId", nil, ws.NewTLSClient(&tls.Config{ +// cs := NewChargingStation("someUniqueId", nil, ws.NewClient(ws.WithClientTLSConfig(&tls.Config{ // RootCAs: certPool, -// }) +// })) // // For more advanced options, or if a custom networking/occpj layer is required, -// please refer to ocppj.Client and ws.WsClient. -func NewChargingStation(id string, endpoint *ocppj.Client, client ws.WsClient) ChargingStation { +// please refer to ocppj.Client and ws.Client. +func NewChargingStation(id string, endpoint *ocppj.Client, client ws.Client) ChargingStation { if client == nil { client = ws.NewClient() } @@ -414,8 +414,8 @@ type CSMS interface { // // If you need a TLS server, you may use the following: // -// csms := NewCSMS(nil, ws.NewTLSServer("certificatePath", "privateKeyPath")) -func NewCSMS(endpoint *ocppj.Server, server ws.WsServer) CSMS { +// csms := NewCSMS(nil, ws.NewServer(ws.WithServerTLSConfig("certificatePath", "privateKeyPath", nil))) +func NewCSMS(endpoint *ocppj.Server, server ws.Server) CSMS { if server == nil { server = ws.NewServer() } diff --git a/ocpp2.0.1_test/ocpp2_test.go b/ocpp2.0.1_test/ocpp2_test.go index edf56201..ee769483 100644 --- a/ocpp2.0.1_test/ocpp2_test.go +++ b/ocpp2.0.1_test/ocpp2_test.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "fmt" "net" - "net/http" "reflect" "testing" "time" @@ -59,6 +58,10 @@ func (websocket MockWebSocket) TLSConnectionState() *tls.ConnectionState { return nil } +func (websocket MockWebSocket) IsConnected() bool { + return true +} + func NewMockWebSocket(id string) MockWebSocket { return MockWebSocket{id: id} } @@ -67,7 +70,7 @@ func NewMockWebSocket(id string) MockWebSocket { type MockWebsocketServer struct { mock.Mock - ws.WsServer + ws.Server MessageHandler func(ws ws.Channel, data []byte) error NewClientHandler func(ws ws.Channel) CheckClientHandler ws.CheckClientHandler @@ -87,11 +90,11 @@ func (websocketServer *MockWebsocketServer) Write(webSocketId string, data []byt return args.Error(0) } -func (websocketServer *MockWebsocketServer) SetMessageHandler(handler func(ws ws.Channel, data []byte) error) { +func (websocketServer *MockWebsocketServer) SetMessageHandler(handler ws.MessageHandler) { websocketServer.MessageHandler = handler } -func (websocketServer *MockWebsocketServer) SetNewClientHandler(handler func(ws ws.Channel)) { +func (websocketServer *MockWebsocketServer) SetNewClientHandler(handler ws.ConnectedHandler) { websocketServer.NewClientHandler = handler } @@ -106,7 +109,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client websocketServer.MethodCalled("NewClient", websocketId, client) } -func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { +func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler ws.CheckClientHandler) { websocketServer.CheckClientHandler = handler } @@ -114,7 +117,7 @@ func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(i type MockWebsocketClient struct { mock.Mock - ws.WsClient + ws.Client MessageHandler func(data []byte) error ReconnectedHandler func() DisconnectedHandler func(err error) @@ -812,42 +815,6 @@ func (handler *MockCSMSTransactionsHandler) OnTransactionEvent(chargingStationID } // ---------------------- COMMON UTILITY METHODS ---------------------- - -func NewWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) *ws.Server { - wsServer := ws.Server{} - wsServer.SetMessageHandler(func(ws ws.Channel, data []byte) error { - assert.NotNil(t, ws) - assert.NotNil(t, data) - if onMessage != nil { - response, err := onMessage(data) - assert.Nil(t, err) - if response != nil { - err = wsServer.Write(ws.ID(), data) - assert.Nil(t, err) - } - } - return nil - }) - return &wsServer -} - -func NewWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) *ws.Client { - wsClient := ws.Client{} - wsClient.SetMessageHandler(func(data []byte) error { - assert.NotNil(t, data) - if onMessage != nil { - response, err := onMessage(data) - assert.Nil(t, err) - if response != nil { - err = wsClient.Write(data) - assert.Nil(t, err) - } - } - return nil - }) - return &wsClient -} - type expectedCSMSOptions struct { clientId string rawWrittenMessage []byte diff --git a/ocppj/central_system_test.go b/ocppj/central_system_test.go index e2506f68..e8eba034 100644 --- a/ocppj/central_system_test.go +++ b/ocppj/central_system_test.go @@ -400,7 +400,7 @@ func (suite *OcppJTestSuite) TestCentralSystemNewClientHandler() { suite.mockServer.NewClientHandler(channel) ok := <-connectedC assert.True(t, ok) - // Client state was created + // client state was created _, ok = suite.serverRequestMap.Get(mockClientID) assert.True(t, ok) } diff --git a/ocppj/client.go b/ocppj/client.go index 496bbc91..a66f0b6e 100644 --- a/ocppj/client.go +++ b/ocppj/client.go @@ -13,7 +13,7 @@ import ( // During message exchange, the two roles may be reversed (depending on the message direction), but a client struct remains associated to a charge point/charging station. type Client struct { Endpoint - client ws.WsClient + client ws.Client Id string requestHandler func(request ocpp.Request, requestId string, action string) responseHandler func(response ocpp.Response, requestId string) @@ -35,7 +35,7 @@ type Client struct { // // The wsClient parameter cannot be nil. Refer to the ws package for information on how to create and // customize a websocket client. -func NewClient(id string, wsClient ws.WsClient, dispatcher ClientDispatcher, stateHandler ClientState, profiles ...*ocpp.Profile) *Client { +func NewClient(id string, wsClient ws.Client, dispatcher ClientDispatcher, stateHandler ClientState, profiles ...*ocpp.Profile) *Client { endpoint := Endpoint{} if wsClient == nil { panic("wsClient parameter cannot be nil") diff --git a/ocppj/dispatcher.go b/ocppj/dispatcher.go index 05006d9c..4f62f878 100644 --- a/ocppj/dispatcher.go +++ b/ocppj/dispatcher.go @@ -54,7 +54,7 @@ type ClientDispatcher interface { // Sets the network client, so the dispatcher may send requests using the networking layer directly. // // This needs to be set before calling the Start method. If not, sending requests will fail. - SetNetworkClient(client ws.WsClient) + SetNetworkClient(client ws.Client) // Sets the state manager for pending requests in the dispatcher. // // The state should only be accessed by the dispatcher while running. @@ -90,7 +90,7 @@ type DefaultClientDispatcher struct { requestChannel chan bool readyForDispatch chan bool pendingRequestState ClientState - network ws.WsClient + network ws.Client mutex sync.RWMutex onRequestCancel func(requestID string, request ocpp.Request, err *ocpp.Error) timer *time.Timer @@ -149,7 +149,7 @@ func (d *DefaultClientDispatcher) Stop() { // TODO: clear pending requests? } -func (d *DefaultClientDispatcher) SetNetworkClient(client ws.WsClient) { +func (d *DefaultClientDispatcher) SetNetworkClient(client ws.Client) { d.network = client } @@ -335,7 +335,7 @@ type ServerDispatcher interface { // Sets the network server, so the dispatcher may send requests using the networking layer directly. // // This needs to be set before calling the Start method. If not, sending requests will fail. - SetNetworkServer(server ws.WsServer) + SetNetworkServer(server ws.Server) // Sets the state manager for pending requests in the dispatcher. // // The state should only be accessed by the dispatcher while running. @@ -371,7 +371,7 @@ type DefaultServerDispatcher struct { running bool stoppedC chan struct{} onRequestCancel CanceledRequestHandler - network ws.WsServer + network ws.Server mutex sync.RWMutex } @@ -442,7 +442,7 @@ func (d *DefaultServerDispatcher) DeleteClient(clientID string) { } } -func (d *DefaultServerDispatcher) SetNetworkServer(server ws.WsServer) { +func (d *DefaultServerDispatcher) SetNetworkServer(server ws.Server) { d.network = server } @@ -491,7 +491,7 @@ func (d *DefaultServerDispatcher) messagePump() { for { select { case <-d.stoppedC: - // Server was stopped + // server was stopped d.queueMap.Init() log.Info("stopped processing requests") return @@ -547,7 +547,7 @@ func (d *DefaultServerDispatcher) messagePump() { clientCtx.cancel() clientContextMap[clientID] = clientTimeoutContext{} } - // Client can now transmit again + // client can now transmit again clientQueue, ok = d.queueMap.Get(clientID) if ok { // Ready to transmit @@ -620,7 +620,7 @@ func (d *DefaultServerDispatcher) waitForTimeout(clientID string, clientCtx clie log.Debugf("timeout canceled for %s", clientID) } case <-d.stoppedC: - // Server was stopped, every pending timeout gets canceled + // server was stopped, every pending timeout gets canceled } } diff --git a/ocppj/mocks/mock_ClientDispatcher.go b/ocppj/mocks/mock_ClientDispatcher.go index afa36f5c..0dd86eff 100644 --- a/ocppj/mocks/mock_ClientDispatcher.go +++ b/ocppj/mocks/mock_ClientDispatcher.go @@ -260,7 +260,7 @@ func (_c *MockClientDispatcher_SendRequest_Call) RunAndReturn(run func(ocppj.Req } // SetNetworkClient provides a mock function with given fields: client -func (_m *MockClientDispatcher) SetNetworkClient(client ws.WsClient) { +func (_m *MockClientDispatcher) SetNetworkClient(client ws.Client) { _m.Called(client) } @@ -270,14 +270,14 @@ type MockClientDispatcher_SetNetworkClient_Call struct { } // SetNetworkClient is a helper method to define mock.On call -// - client ws.WsClient +// - client ws.Client func (_e *MockClientDispatcher_Expecter) SetNetworkClient(client interface{}) *MockClientDispatcher_SetNetworkClient_Call { return &MockClientDispatcher_SetNetworkClient_Call{Call: _e.mock.On("SetNetworkClient", client)} } -func (_c *MockClientDispatcher_SetNetworkClient_Call) Run(run func(client ws.WsClient)) *MockClientDispatcher_SetNetworkClient_Call { +func (_c *MockClientDispatcher_SetNetworkClient_Call) Run(run func(client ws.Client)) *MockClientDispatcher_SetNetworkClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(ws.WsClient)) + run(args[0].(ws.Client)) }) return _c } @@ -287,7 +287,7 @@ func (_c *MockClientDispatcher_SetNetworkClient_Call) Return() *MockClientDispat return _c } -func (_c *MockClientDispatcher_SetNetworkClient_Call) RunAndReturn(run func(ws.WsClient)) *MockClientDispatcher_SetNetworkClient_Call { +func (_c *MockClientDispatcher_SetNetworkClient_Call) RunAndReturn(run func(ws.Client)) *MockClientDispatcher_SetNetworkClient_Call { _c.Run(run) return _c } diff --git a/ocppj/mocks/mock_ServerDispatcher.go b/ocppj/mocks/mock_ServerDispatcher.go index 6ea739c4..85d0f206 100644 --- a/ocppj/mocks/mock_ServerDispatcher.go +++ b/ocppj/mocks/mock_ServerDispatcher.go @@ -217,7 +217,7 @@ func (_c *MockServerDispatcher_SendRequest_Call) RunAndReturn(run func(string, o } // SetNetworkServer provides a mock function with given fields: server -func (_m *MockServerDispatcher) SetNetworkServer(server ws.WsServer) { +func (_m *MockServerDispatcher) SetNetworkServer(server ws.Server) { _m.Called(server) } @@ -227,14 +227,14 @@ type MockServerDispatcher_SetNetworkServer_Call struct { } // SetNetworkServer is a helper method to define mock.On call -// - server ws.WsServer +// - server ws.Server func (_e *MockServerDispatcher_Expecter) SetNetworkServer(server interface{}) *MockServerDispatcher_SetNetworkServer_Call { return &MockServerDispatcher_SetNetworkServer_Call{Call: _e.mock.On("SetNetworkServer", server)} } -func (_c *MockServerDispatcher_SetNetworkServer_Call) Run(run func(server ws.WsServer)) *MockServerDispatcher_SetNetworkServer_Call { +func (_c *MockServerDispatcher_SetNetworkServer_Call) Run(run func(server ws.Server)) *MockServerDispatcher_SetNetworkServer_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(ws.WsServer)) + run(args[0].(ws.Server)) }) return _c } @@ -244,7 +244,7 @@ func (_c *MockServerDispatcher_SetNetworkServer_Call) Return() *MockServerDispat return _c } -func (_c *MockServerDispatcher_SetNetworkServer_Call) RunAndReturn(run func(ws.WsServer)) *MockServerDispatcher_SetNetworkServer_Call { +func (_c *MockServerDispatcher_SetNetworkServer_Call) RunAndReturn(run func(ws.Server)) *MockServerDispatcher_SetNetworkServer_Call { _c.Run(run) return _c } diff --git a/ocppj/ocppj_test.go b/ocppj/ocppj_test.go index e01cb257..371741ec 100644 --- a/ocppj/ocppj_test.go +++ b/ocppj/ocppj_test.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "fmt" "net" - "net/http" "reflect" "testing" @@ -42,6 +41,10 @@ func (websocket MockWebSocket) TLSConnectionState() *tls.ConnectionState { return nil } +func (websocket MockWebSocket) IsConnected() bool { + return true +} + func NewMockWebSocket(id string) MockWebSocket { return MockWebSocket{id: id} } @@ -50,8 +53,8 @@ func NewMockWebSocket(id string) MockWebSocket { type MockWebsocketServer struct { mock.Mock - ws.WsServer - MessageHandler func(ws ws.Channel, data []byte) error + ws.Server + MessageHandler ws.MessageHandler NewClientHandler func(ws ws.Channel) CheckClientHandler ws.CheckClientHandler DisconnectedClientHandler func(ws ws.Channel) @@ -71,11 +74,11 @@ func (websocketServer *MockWebsocketServer) Write(webSocketId string, data []byt return args.Error(0) } -func (websocketServer *MockWebsocketServer) SetMessageHandler(handler func(ws ws.Channel, data []byte) error) { +func (websocketServer *MockWebsocketServer) SetMessageHandler(handler ws.MessageHandler) { websocketServer.MessageHandler = handler } -func (websocketServer *MockWebsocketServer) SetNewClientHandler(handler func(ws ws.Channel)) { +func (websocketServer *MockWebsocketServer) SetNewClientHandler(handler ws.ConnectedHandler) { websocketServer.NewClientHandler = handler } @@ -103,7 +106,7 @@ func (websocketServer *MockWebsocketServer) NewClient(websocketId string, client websocketServer.MethodCalled("NewClient", websocketId, client) } -func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { +func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler ws.CheckClientHandler) { websocketServer.CheckClientHandler = handler } @@ -111,7 +114,7 @@ func (websocketServer *MockWebsocketServer) SetCheckClientHandler(handler func(i type MockWebsocketClient struct { mock.Mock - ws.WsClient + ws.Client MessageHandler func(data []byte) error ReconnectedHandler func() DisconnectedHandler func(err error) @@ -232,8 +235,8 @@ func (m *MockUnsupportedResponse) GetFeatureName() string { // ---------------------- COMMON UTILITY METHODS ---------------------- -func NewWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) *ws.Server { - wsServer := ws.Server{} +func NewWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) ws.Server { + wsServer := ws.NewServer() wsServer.SetMessageHandler(func(ws ws.Channel, data []byte) error { assert.NotNil(t, ws) assert.NotNil(t, data) @@ -247,11 +250,11 @@ func NewWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error } return nil }) - return &wsServer + return wsServer } -func NewWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) *ws.Client { - wsClient := ws.Client{} +func NewWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) ws.Client { + wsClient := ws.NewClient() wsClient.SetMessageHandler(func(data []byte) error { assert.NotNil(t, data) if onMessage != nil { @@ -264,7 +267,7 @@ func NewWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error } return nil }) - return &wsClient + return wsClient } func ParseCall(endpoint *ocppj.Endpoint, state ocppj.ClientState, json string, t *testing.T) *ocppj.Call { diff --git a/ocppj/server.go b/ocppj/server.go index 1653e9cb..b7f3974c 100644 --- a/ocppj/server.go +++ b/ocppj/server.go @@ -13,7 +13,7 @@ import ( // During message exchange, the two roles may be reversed (depending on the message direction), but a server struct remains associated to a central system. type Server struct { Endpoint - server ws.WsServer + server ws.Server checkClientHandler ws.CheckClientHandler newClientHandler ClientHandler disconnectedClientHandler ClientHandler @@ -40,7 +40,7 @@ type InvalidMessageHook func(client ws.Channel, err *ocpp.Error, rawJson string, // s := ocppj.NewServer(ws.NewServer(), nil, nil) // // The dispatcher's associated ClientState will be set during initialization. -func NewServer(wsServer ws.WsServer, dispatcher ServerDispatcher, stateHandler ServerState, profiles ...*ocpp.Profile) *Server { +func NewServer(wsServer ws.Server, dispatcher ServerDispatcher, stateHandler ServerState, profiles ...*ocpp.Profile) *Server { if dispatcher == nil { dispatcher = NewDefaultServerDispatcher(NewFIFOQueueMap(0)) } diff --git a/ws/client.go b/ws/client.go new file mode 100644 index 00000000..1d224c80 --- /dev/null +++ b/ws/client.go @@ -0,0 +1,429 @@ +package ws + +import ( + "crypto/tls" + "encoding/base64" + "fmt" + "io" + "math/rand" + "net/http" + "net/url" + "path" + "time" + + "github.com/gorilla/websocket" +) + +// ---------------------- CLIENT ---------------------- + +// Client defines a websocket client, needed to connect to a websocket server. +// The offered API are of asynchronous nature, and each incoming message is handled using callbacks. +// +// To create a new ws client, use: +// +// client := NewClient() +// +// If you need a secure websocket client instead, pass a tls.Config to the NewClient function: +// +// certPool, err := x509.SystemCertPool() +// if err != nil { +// log.Fatal(err) +// } +// // You may add more trusted certificates to the pool before creating the TLS Config +// client := NewClient(&tls.Config{ +// RootCAs: certPool, +// }) +// +// To add additional dial options, use: +// +// client.AddOption(func(*websocket.Dialer) { +// // Your option ... +// }) +// +// To add basic HTTP authentication, use: +// +// client.SetBasicAuth("username","password") +// +// If you need to set a specific timeout configuration, refer to the SetTimeoutConfig method. +// +// Using Start and Stop you can respectively open/close a websocket to a websocket server. +// +// To receive incoming messages, you will need to set your own handler using SetMessageHandler. +// To write data on the open socket, simply call the Write function. +type Client interface { + // Starts the client and attempts to connect to the server on a specified URL. + // If the connection fails, an error is returned. + // + // For example: + // err := client.Start("ws://localhost:8887/ws/1234") + // + // The function returns immediately, after the connection has been established. + // Incoming messages are passed automatically to the callback function, so no explicit read operation is required. + // + // To stop a running client, call the Stop function. + Start(url string) error + // Starts the client and attempts to connect to the server on a specified URL. + // If the connection fails, it keeps retrying with Backoff strategy from TimeoutConfig. + // + // For example: + // client.StartWithRetries("ws://localhost:8887/ws/1234") + // + // The function returns only when the connection has been established. + // Incoming messages are passed automatically to the callback function, so no explicit read operation is required. + // + // To stop a running client, call the Stop function. + StartWithRetries(url string) + // Stop closes the output of the websocket Channel, effectively closing the connection to the server with a normal closure. + Stop() + // Errors returns a channel for error messages. If it doesn't exist it es created. + // The channel is closed by the client when stopped. + // + // It is recommended to invoke this function before starting a client. + // Creating the error channel while the client is running may lead to unexpected behavior. + Errors() <-chan error + // Sets a callback function for all incoming messages. + SetMessageHandler(handler func(data []byte) error) + // Set custom timeout configuration parameters. If not passed, a default ClientTimeoutConfig struct will be used. + // + // This function must be called before connecting to the server, otherwise it may lead to unexpected behavior. + SetTimeoutConfig(config ClientTimeoutConfig) + // Sets a callback function for receiving notifications about an unexpected disconnection from the server. + // The callback is invoked even if the automatic reconnection mechanism is active. + // + // If the client was stopped using the Stop function, the callback will NOT be invoked. + SetDisconnectedHandler(handler func(err error)) + // Sets a callback function for receiving notifications whenever the connection to the server is re-established. + // Connections are re-established automatically thanks to the auto-reconnection mechanism. + // + // If set, the DisconnectedHandler will always be invoked before the Reconnected callback is invoked. + SetReconnectedHandler(handler func()) + // IsConnected Returns information about the current connection status. + // If the client is currently attempting to auto-reconnect to the server, the function returns false. + IsConnected() bool + // Sends a message to the server over the websocket. + // + // The data is queued and will be sent asynchronously in the background. + Write(data []byte) error + // Adds a websocket option to the client. + AddOption(option interface{}) + // SetRequestedSubProtocol will negotiate the specified sub-protocol during the websocket handshake. + // Internally this creates a dialer option and invokes the AddOption method on the client. + // + // Duplicates generated by invoking this method multiple times will be ignored. + SetRequestedSubProtocol(subProto string) + // SetBasicAuth adds basic authentication credentials, to use when connecting to the server. + // The credentials are automatically encoded in base64. + SetBasicAuth(username string, password string) + // SetHeaderValue sets a value on the HTTP header sent when opening a websocket connection to the server. + // + // The function overwrites previous header fields with the same key. + SetHeaderValue(key string, value string) +} + +// client is the default implementation of a Websocket client. +// +// Use the NewClient function to create a new client. +type client struct { + webSocket *webSocket + url url.URL + messageHandler func(data []byte) error + dialOptions []func(*websocket.Dialer) + header http.Header + timeoutConfig ClientTimeoutConfig + onDisconnected func(err error) + onReconnected func() + errC chan error + reconnectC chan struct{} // used for signaling, that a reconnection attempt should be interrupted +} + +// ClientOpt is a function that can be used to set options on a client during creation. +type ClientOpt func(c *client) + +// WithClientTLSConfig sets the TLS configuration for the client. +// If the passed tlsConfig is nil, the client will not use TLS. +func WithClientTLSConfig(tlsConfig *tls.Config) ClientOpt { + return func(c *client) { + if tlsConfig != nil { + c.dialOptions = append(c.dialOptions, func(dialer *websocket.Dialer) { + dialer.TLSClientConfig = tlsConfig + }) + } + } +} + +// NewClient creates a new websocket client. +// +// If the optional tlsConfig is not nil, and the server supports secure communication, +// the websocket channel will use TLS. +// +// Additional options may be added using the AddOption function. +// +// Basic authentication can be set using the SetBasicAuth function. +// +// By default, the client will not negotiate any sub-protocol. This value needs to be set via the +// respective SetRequestedSubProtocol method. +// +// To set a client certificate, you may do: +// +// certificate, _ := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) +// clientCertificates := []tls.Certificate{certificate} +// client := ws.NewClient(ws.WithClientTLSConfig(&tls.Config{ +// RootCAs: certPool, +// Certificates: clientCertificates, +// })) +// +// You can set any other TLS option within the same constructor config as well. +// For example, if you wish to test connecting to a server having a +// self-signed certificate (do not use in production!), pass: +// +// InsecureSkipVerify: true +func NewClient(opts ...ClientOpt) Client { + c := &client{ + dialOptions: []func(*websocket.Dialer){}, + timeoutConfig: NewClientTimeoutConfig(), + reconnectC: make(chan struct{}, 1), + header: http.Header{}, + } + for _, o := range opts { + o(c) + } + return c +} + +func (c *client) SetMessageHandler(handler func(data []byte) error) { + c.messageHandler = handler +} + +func (c *client) SetTimeoutConfig(config ClientTimeoutConfig) { + c.timeoutConfig = config +} + +func (c *client) SetDisconnectedHandler(handler func(err error)) { + c.onDisconnected = handler +} + +func (c *client) SetReconnectedHandler(handler func()) { + c.onReconnected = handler +} + +func (c *client) AddOption(option interface{}) { + dialOption, ok := option.(func(*websocket.Dialer)) + if ok { + c.dialOptions = append(c.dialOptions, dialOption) + } +} + +func (c *client) SetRequestedSubProtocol(subProto string) { + opt := func(dialer *websocket.Dialer) { + alreadyExists := false + for _, proto := range dialer.Subprotocols { + if proto == subProto { + alreadyExists = true + break + } + } + if !alreadyExists { + dialer.Subprotocols = append(dialer.Subprotocols, subProto) + } + } + c.AddOption(opt) +} + +func (c *client) SetBasicAuth(username string, password string) { + c.header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(username+":"+password))) +} + +func (c *client) SetHeaderValue(key string, value string) { + c.header.Set(key, value) +} + +func (c *client) getReadTimeout() time.Time { + if c.timeoutConfig.PongWait == 0 { + return time.Time{} + } + return time.Now().Add(c.timeoutConfig.PongWait) +} + +func (c *client) handleReconnection() { + log.Info("started automatic reconnection handler") + delay := c.timeoutConfig.RetryBackOffWaitMinimum + time.Duration(rand.Intn(c.timeoutConfig.RetryBackOffRandomRange+1))*time.Second + reconnectionAttempts := 1 + for { + // Wait before reconnecting + select { + case <-time.After(delay): + case <-c.reconnectC: + log.Info("automatic reconnection aborted") + return + } + + log.Info("reconnecting... attempt", reconnectionAttempts) + err := c.Start(c.url.String()) + if err == nil { + // Re-connection was successful + log.Info("reconnected successfully to server") + if c.onReconnected != nil { + c.onReconnected() + } + return + } + c.error(fmt.Errorf("reconnection failed: %w", err)) + + if reconnectionAttempts < c.timeoutConfig.RetryBackOffRepeatTimes { + // Re-connection failed, double the delay + delay *= 2 + delay += time.Duration(rand.Intn(c.timeoutConfig.RetryBackOffRandomRange+1)) * time.Second + } + reconnectionAttempts += 1 + } +} + +func (c *client) IsConnected() bool { + if c.webSocket == nil { + return false + } + return c.webSocket.IsConnected() +} + +func (c *client) Write(data []byte) error { + if !c.IsConnected() { + return fmt.Errorf("client is currently not connected, cannot send data") + } + log.Debugf("queuing data for server") + return c.webSocket.Write(data) +} + +func (c *client) StartWithRetries(urlStr string) { + err := c.Start(urlStr) + if err != nil { + log.Info("Connection error:", err) + c.handleReconnection() + } +} + +func (c *client) Start(urlStr string) error { + u, err := url.Parse(urlStr) + if err != nil { + return err + } + c.url = *u + if c.reconnectC == nil { + c.reconnectC = make(chan struct{}, 1) + } + + dialer := websocket.Dialer{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: c.timeoutConfig.HandshakeTimeout, + Subprotocols: []string{}, + } + for _, option := range c.dialOptions { + option(&dialer) + } + // Connect + log.Info("connecting to server") + ws, resp, err := dialer.Dial(urlStr, c.header) + if err != nil { + if resp != nil { + httpError := HttpConnectionError{Message: err.Error(), HttpStatus: resp.Status, HttpCode: resp.StatusCode} + // Parse http response details + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if body != nil { + httpError.Details = string(body) + } + err = httpError + } + return err + } + + // The id of the charge point is the final path element + id := path.Base(u.Path) + + // Create web socket, state is automatically set to connected + c.webSocket = newWebSocket( + id, + ws, + resp.TLS, + NewDefaultWebSocketConfig( + c.timeoutConfig.WriteWait, + 0, + c.timeoutConfig.PingPeriod, + c.timeoutConfig.PongWait, + ), + c.handleMessage, + c.handleDisconnect, + func(_ Channel, err error) { + c.error(err) + }, + ) + log.Infof("connected to server as %s", id) + // Start reader and write routine + c.webSocket.run() + return nil +} + +func (c *client) Stop() { + log.Infof("closing connection to server") + if c.IsConnected() { + // Attempt to gracefully shut down the connection + err := c.webSocket.Close(websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}) + if err != nil { + c.error(err) + } + } + // Notify reconnection goroutine to stop (if any) + select { + case <-c.reconnectC: + // Already closed, ignore + break + default: + // Channel is open, signal reconnection to stop + c.reconnectC <- struct{}{} + } + // Close error channel if any + select { + case <-c.errC: + // Already closed, ignore + break + default: + // Channel is open, close it + if c.errC != nil { + close(c.errC) + } + } + // Connection will close asynchronously and invoke the onDisconnected handler +} + +func (c *client) Errors() <-chan error { + if c.errC == nil { + c.errC = make(chan error, 1) + } + return c.errC +} + +// --------- Internal callbacks webSocket -> client --------- +func (c *client) handleMessage(_ Channel, data []byte) error { + if c.messageHandler != nil { + return c.messageHandler(data) + } + return fmt.Errorf("no message handler set") +} + +func (c *client) handleDisconnect(_ Channel, err error) { + if c.onDisconnected != nil { + // Notify upper layer of disconnect + c.onDisconnected(err) + } + if err != nil { + // Disconnect was forced, do reconnect + c.handleReconnection() + } +} + +func (c *client) error(err error) { + log.Error(err) + if c.errC != nil { + c.errC <- err + } +} diff --git a/ws/mocks/mock_Channel.go b/ws/mocks/mock_Channel.go index 00d32e51..cf5d228e 100644 --- a/ws/mocks/mock_Channel.go +++ b/ws/mocks/mock_Channel.go @@ -68,6 +68,51 @@ func (_c *MockChannel_ID_Call) RunAndReturn(run func() string) *MockChannel_ID_C return _c } +// IsConnected provides a mock function with no fields +func (_m *MockChannel) IsConnected() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsConnected") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockChannel_IsConnected_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsConnected' +type MockChannel_IsConnected_Call struct { + *mock.Call +} + +// IsConnected is a helper method to define mock.On call +func (_e *MockChannel_Expecter) IsConnected() *MockChannel_IsConnected_Call { + return &MockChannel_IsConnected_Call{Call: _e.mock.On("IsConnected")} +} + +func (_c *MockChannel_IsConnected_Call) Run(run func()) *MockChannel_IsConnected_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockChannel_IsConnected_Call) Return(_a0 bool) *MockChannel_IsConnected_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockChannel_IsConnected_Call) RunAndReturn(run func() bool) *MockChannel_IsConnected_Call { + _c.Call.Return(run) + return _c +} + // RemoteAddr provides a mock function with no fields func (_m *MockChannel) RemoteAddr() net.Addr { ret := _m.Called() diff --git a/ws/mocks/mock_Client.go b/ws/mocks/mock_Client.go new file mode 100644 index 00000000..3d0dc510 --- /dev/null +++ b/ws/mocks/mock_Client.go @@ -0,0 +1,550 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import ( + ws "github.com/lorenzodonini/ocpp-go/ws" + mock "github.com/stretchr/testify/mock" +) + +// MockClient is an autogenerated mock type for the Client type +type MockClient struct { + mock.Mock +} + +type MockClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClient) EXPECT() *MockClient_Expecter { + return &MockClient_Expecter{mock: &_m.Mock} +} + +// AddOption provides a mock function with given fields: option +func (_m *MockClient) AddOption(option interface{}) { + _m.Called(option) +} + +// MockClient_AddOption_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddOption' +type MockClient_AddOption_Call struct { + *mock.Call +} + +// AddOption is a helper method to define mock.On call +// - option interface{} +func (_e *MockClient_Expecter) AddOption(option interface{}) *MockClient_AddOption_Call { + return &MockClient_AddOption_Call{Call: _e.mock.On("AddOption", option)} +} + +func (_c *MockClient_AddOption_Call) Run(run func(option interface{})) *MockClient_AddOption_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockClient_AddOption_Call) Return() *MockClient_AddOption_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_AddOption_Call) RunAndReturn(run func(interface{})) *MockClient_AddOption_Call { + _c.Run(run) + return _c +} + +// Errors provides a mock function with no fields +func (_m *MockClient) Errors() <-chan error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Errors") + } + + var r0 <-chan error + if rf, ok := ret.Get(0).(func() <-chan error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan error) + } + } + + return r0 +} + +// MockClient_Errors_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Errors' +type MockClient_Errors_Call struct { + *mock.Call +} + +// Errors is a helper method to define mock.On call +func (_e *MockClient_Expecter) Errors() *MockClient_Errors_Call { + return &MockClient_Errors_Call{Call: _e.mock.On("Errors")} +} + +func (_c *MockClient_Errors_Call) Run(run func()) *MockClient_Errors_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Errors_Call) Return(_a0 <-chan error) *MockClient_Errors_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_Errors_Call) RunAndReturn(run func() <-chan error) *MockClient_Errors_Call { + _c.Call.Return(run) + return _c +} + +// IsConnected provides a mock function with no fields +func (_m *MockClient) IsConnected() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsConnected") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockClient_IsConnected_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsConnected' +type MockClient_IsConnected_Call struct { + *mock.Call +} + +// IsConnected is a helper method to define mock.On call +func (_e *MockClient_Expecter) IsConnected() *MockClient_IsConnected_Call { + return &MockClient_IsConnected_Call{Call: _e.mock.On("IsConnected")} +} + +func (_c *MockClient_IsConnected_Call) Run(run func()) *MockClient_IsConnected_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_IsConnected_Call) Return(_a0 bool) *MockClient_IsConnected_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_IsConnected_Call) RunAndReturn(run func() bool) *MockClient_IsConnected_Call { + _c.Call.Return(run) + return _c +} + +// SetBasicAuth provides a mock function with given fields: username, password +func (_m *MockClient) SetBasicAuth(username string, password string) { + _m.Called(username, password) +} + +// MockClient_SetBasicAuth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetBasicAuth' +type MockClient_SetBasicAuth_Call struct { + *mock.Call +} + +// SetBasicAuth is a helper method to define mock.On call +// - username string +// - password string +func (_e *MockClient_Expecter) SetBasicAuth(username interface{}, password interface{}) *MockClient_SetBasicAuth_Call { + return &MockClient_SetBasicAuth_Call{Call: _e.mock.On("SetBasicAuth", username, password)} +} + +func (_c *MockClient_SetBasicAuth_Call) Run(run func(username string, password string)) *MockClient_SetBasicAuth_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(string)) + }) + return _c +} + +func (_c *MockClient_SetBasicAuth_Call) Return() *MockClient_SetBasicAuth_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_SetBasicAuth_Call) RunAndReturn(run func(string, string)) *MockClient_SetBasicAuth_Call { + _c.Run(run) + return _c +} + +// SetDisconnectedHandler provides a mock function with given fields: handler +func (_m *MockClient) SetDisconnectedHandler(handler func(error)) { + _m.Called(handler) +} + +// MockClient_SetDisconnectedHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDisconnectedHandler' +type MockClient_SetDisconnectedHandler_Call struct { + *mock.Call +} + +// SetDisconnectedHandler is a helper method to define mock.On call +// - handler func(error) +func (_e *MockClient_Expecter) SetDisconnectedHandler(handler interface{}) *MockClient_SetDisconnectedHandler_Call { + return &MockClient_SetDisconnectedHandler_Call{Call: _e.mock.On("SetDisconnectedHandler", handler)} +} + +func (_c *MockClient_SetDisconnectedHandler_Call) Run(run func(handler func(error))) *MockClient_SetDisconnectedHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func(error))) + }) + return _c +} + +func (_c *MockClient_SetDisconnectedHandler_Call) Return() *MockClient_SetDisconnectedHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_SetDisconnectedHandler_Call) RunAndReturn(run func(func(error))) *MockClient_SetDisconnectedHandler_Call { + _c.Run(run) + return _c +} + +// SetHeaderValue provides a mock function with given fields: key, value +func (_m *MockClient) SetHeaderValue(key string, value string) { + _m.Called(key, value) +} + +// MockClient_SetHeaderValue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeaderValue' +type MockClient_SetHeaderValue_Call struct { + *mock.Call +} + +// SetHeaderValue is a helper method to define mock.On call +// - key string +// - value string +func (_e *MockClient_Expecter) SetHeaderValue(key interface{}, value interface{}) *MockClient_SetHeaderValue_Call { + return &MockClient_SetHeaderValue_Call{Call: _e.mock.On("SetHeaderValue", key, value)} +} + +func (_c *MockClient_SetHeaderValue_Call) Run(run func(key string, value string)) *MockClient_SetHeaderValue_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(string)) + }) + return _c +} + +func (_c *MockClient_SetHeaderValue_Call) Return() *MockClient_SetHeaderValue_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_SetHeaderValue_Call) RunAndReturn(run func(string, string)) *MockClient_SetHeaderValue_Call { + _c.Run(run) + return _c +} + +// SetMessageHandler provides a mock function with given fields: handler +func (_m *MockClient) SetMessageHandler(handler func([]byte) error) { + _m.Called(handler) +} + +// MockClient_SetMessageHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMessageHandler' +type MockClient_SetMessageHandler_Call struct { + *mock.Call +} + +// SetMessageHandler is a helper method to define mock.On call +// - handler func([]byte) error +func (_e *MockClient_Expecter) SetMessageHandler(handler interface{}) *MockClient_SetMessageHandler_Call { + return &MockClient_SetMessageHandler_Call{Call: _e.mock.On("SetMessageHandler", handler)} +} + +func (_c *MockClient_SetMessageHandler_Call) Run(run func(handler func([]byte) error)) *MockClient_SetMessageHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func([]byte) error)) + }) + return _c +} + +func (_c *MockClient_SetMessageHandler_Call) Return() *MockClient_SetMessageHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_SetMessageHandler_Call) RunAndReturn(run func(func([]byte) error)) *MockClient_SetMessageHandler_Call { + _c.Run(run) + return _c +} + +// SetReconnectedHandler provides a mock function with given fields: handler +func (_m *MockClient) SetReconnectedHandler(handler func()) { + _m.Called(handler) +} + +// MockClient_SetReconnectedHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetReconnectedHandler' +type MockClient_SetReconnectedHandler_Call struct { + *mock.Call +} + +// SetReconnectedHandler is a helper method to define mock.On call +// - handler func() +func (_e *MockClient_Expecter) SetReconnectedHandler(handler interface{}) *MockClient_SetReconnectedHandler_Call { + return &MockClient_SetReconnectedHandler_Call{Call: _e.mock.On("SetReconnectedHandler", handler)} +} + +func (_c *MockClient_SetReconnectedHandler_Call) Run(run func(handler func())) *MockClient_SetReconnectedHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func())) + }) + return _c +} + +func (_c *MockClient_SetReconnectedHandler_Call) Return() *MockClient_SetReconnectedHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_SetReconnectedHandler_Call) RunAndReturn(run func(func())) *MockClient_SetReconnectedHandler_Call { + _c.Run(run) + return _c +} + +// SetRequestedSubProtocol provides a mock function with given fields: subProto +func (_m *MockClient) SetRequestedSubProtocol(subProto string) { + _m.Called(subProto) +} + +// MockClient_SetRequestedSubProtocol_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRequestedSubProtocol' +type MockClient_SetRequestedSubProtocol_Call struct { + *mock.Call +} + +// SetRequestedSubProtocol is a helper method to define mock.On call +// - subProto string +func (_e *MockClient_Expecter) SetRequestedSubProtocol(subProto interface{}) *MockClient_SetRequestedSubProtocol_Call { + return &MockClient_SetRequestedSubProtocol_Call{Call: _e.mock.On("SetRequestedSubProtocol", subProto)} +} + +func (_c *MockClient_SetRequestedSubProtocol_Call) Run(run func(subProto string)) *MockClient_SetRequestedSubProtocol_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockClient_SetRequestedSubProtocol_Call) Return() *MockClient_SetRequestedSubProtocol_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_SetRequestedSubProtocol_Call) RunAndReturn(run func(string)) *MockClient_SetRequestedSubProtocol_Call { + _c.Run(run) + return _c +} + +// SetTimeoutConfig provides a mock function with given fields: config +func (_m *MockClient) SetTimeoutConfig(config ws.ClientTimeoutConfig) { + _m.Called(config) +} + +// MockClient_SetTimeoutConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTimeoutConfig' +type MockClient_SetTimeoutConfig_Call struct { + *mock.Call +} + +// SetTimeoutConfig is a helper method to define mock.On call +// - config ws.ClientTimeoutConfig +func (_e *MockClient_Expecter) SetTimeoutConfig(config interface{}) *MockClient_SetTimeoutConfig_Call { + return &MockClient_SetTimeoutConfig_Call{Call: _e.mock.On("SetTimeoutConfig", config)} +} + +func (_c *MockClient_SetTimeoutConfig_Call) Run(run func(config ws.ClientTimeoutConfig)) *MockClient_SetTimeoutConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.ClientTimeoutConfig)) + }) + return _c +} + +func (_c *MockClient_SetTimeoutConfig_Call) Return() *MockClient_SetTimeoutConfig_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_SetTimeoutConfig_Call) RunAndReturn(run func(ws.ClientTimeoutConfig)) *MockClient_SetTimeoutConfig_Call { + _c.Run(run) + return _c +} + +// Start provides a mock function with given fields: url +func (_m *MockClient) Start(url string) error { + ret := _m.Called(url) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(url) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type MockClient_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +// - url string +func (_e *MockClient_Expecter) Start(url interface{}) *MockClient_Start_Call { + return &MockClient_Start_Call{Call: _e.mock.On("Start", url)} +} + +func (_c *MockClient_Start_Call) Run(run func(url string)) *MockClient_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockClient_Start_Call) Return(_a0 error) *MockClient_Start_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_Start_Call) RunAndReturn(run func(string) error) *MockClient_Start_Call { + _c.Call.Return(run) + return _c +} + +// StartWithRetries provides a mock function with given fields: url +func (_m *MockClient) StartWithRetries(url string) { + _m.Called(url) +} + +// MockClient_StartWithRetries_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StartWithRetries' +type MockClient_StartWithRetries_Call struct { + *mock.Call +} + +// StartWithRetries is a helper method to define mock.On call +// - url string +func (_e *MockClient_Expecter) StartWithRetries(url interface{}) *MockClient_StartWithRetries_Call { + return &MockClient_StartWithRetries_Call{Call: _e.mock.On("StartWithRetries", url)} +} + +func (_c *MockClient_StartWithRetries_Call) Run(run func(url string)) *MockClient_StartWithRetries_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockClient_StartWithRetries_Call) Return() *MockClient_StartWithRetries_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_StartWithRetries_Call) RunAndReturn(run func(string)) *MockClient_StartWithRetries_Call { + _c.Run(run) + return _c +} + +// Stop provides a mock function with no fields +func (_m *MockClient) Stop() { + _m.Called() +} + +// MockClient_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type MockClient_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *MockClient_Expecter) Stop() *MockClient_Stop_Call { + return &MockClient_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *MockClient_Stop_Call) Run(run func()) *MockClient_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Stop_Call) Return() *MockClient_Stop_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_Stop_Call) RunAndReturn(run func()) *MockClient_Stop_Call { + _c.Run(run) + return _c +} + +// Write provides a mock function with given fields: data +func (_m *MockClient) Write(data []byte) error { + ret := _m.Called(data) + + if len(ret) == 0 { + panic("no return value specified for Write") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]byte) error); ok { + r0 = rf(data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockClient_Write_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Write' +type MockClient_Write_Call struct { + *mock.Call +} + +// Write is a helper method to define mock.On call +// - data []byte +func (_e *MockClient_Expecter) Write(data interface{}) *MockClient_Write_Call { + return &MockClient_Write_Call{Call: _e.mock.On("Write", data)} +} + +func (_c *MockClient_Write_Call) Run(run func(data []byte)) *MockClient_Write_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockClient_Write_Call) Return(_a0 error) *MockClient_Write_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_Write_Call) RunAndReturn(run func([]byte) error) *MockClient_Write_Call { + _c.Call.Return(run) + return _c +} + +// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClient { + mock := &MockClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/mocks/mock_ClientOpt.go b/ws/mocks/mock_ClientOpt.go new file mode 100644 index 00000000..2272797c --- /dev/null +++ b/ws/mocks/mock_ClientOpt.go @@ -0,0 +1,65 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// MockClientOpt is an autogenerated mock type for the ClientOpt type +type MockClientOpt struct { + mock.Mock +} + +type MockClientOpt_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClientOpt) EXPECT() *MockClientOpt_Expecter { + return &MockClientOpt_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: c +func (_m *MockClientOpt) Execute(c *ws.client) { + _m.Called(c) +} + +// MockClientOpt_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type MockClientOpt_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - c *ws.client +func (_e *MockClientOpt_Expecter) Execute(c interface{}) *MockClientOpt_Execute_Call { + return &MockClientOpt_Execute_Call{Call: _e.mock.On("Execute", c)} +} + +func (_c *MockClientOpt_Execute_Call) Run(run func(c *ws.client)) *MockClientOpt_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*ws.client)) + }) + return _c +} + +func (_c *MockClientOpt_Execute_Call) Return() *MockClientOpt_Execute_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClientOpt_Execute_Call) RunAndReturn(run func(*ws.client)) *MockClientOpt_Execute_Call { + _c.Run(run) + return _c +} + +// NewMockClientOpt creates a new instance of MockClientOpt. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockClientOpt(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClientOpt { + mock := &MockClientOpt{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/mocks/mock_ConnectedHandler.go b/ws/mocks/mock_ConnectedHandler.go new file mode 100644 index 00000000..51c89a66 --- /dev/null +++ b/ws/mocks/mock_ConnectedHandler.go @@ -0,0 +1,68 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import ( + ws "github.com/lorenzodonini/ocpp-go/ws" + mock "github.com/stretchr/testify/mock" +) + +// MockConnectedHandler is an autogenerated mock type for the ConnectedHandler type +type MockConnectedHandler struct { + mock.Mock +} + +type MockConnectedHandler_Expecter struct { + mock *mock.Mock +} + +func (_m *MockConnectedHandler) EXPECT() *MockConnectedHandler_Expecter { + return &MockConnectedHandler_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: c +func (_m *MockConnectedHandler) Execute(c ws.Channel) { + _m.Called(c) +} + +// MockConnectedHandler_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type MockConnectedHandler_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - c ws.Channel +func (_e *MockConnectedHandler_Expecter) Execute(c interface{}) *MockConnectedHandler_Execute_Call { + return &MockConnectedHandler_Execute_Call{Call: _e.mock.On("Execute", c)} +} + +func (_c *MockConnectedHandler_Execute_Call) Run(run func(c ws.Channel)) *MockConnectedHandler_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.Channel)) + }) + return _c +} + +func (_c *MockConnectedHandler_Execute_Call) Return() *MockConnectedHandler_Execute_Call { + _c.Call.Return() + return _c +} + +func (_c *MockConnectedHandler_Execute_Call) RunAndReturn(run func(ws.Channel)) *MockConnectedHandler_Execute_Call { + _c.Run(run) + return _c +} + +// NewMockConnectedHandler creates a new instance of MockConnectedHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockConnectedHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *MockConnectedHandler { + mock := &MockConnectedHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/mocks/mock_DisconnectedHandler.go b/ws/mocks/mock_DisconnectedHandler.go new file mode 100644 index 00000000..a2026741 --- /dev/null +++ b/ws/mocks/mock_DisconnectedHandler.go @@ -0,0 +1,69 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import ( + ws "github.com/lorenzodonini/ocpp-go/ws" + mock "github.com/stretchr/testify/mock" +) + +// MockDisconnectedHandler is an autogenerated mock type for the DisconnectedHandler type +type MockDisconnectedHandler struct { + mock.Mock +} + +type MockDisconnectedHandler_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDisconnectedHandler) EXPECT() *MockDisconnectedHandler_Expecter { + return &MockDisconnectedHandler_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: c, err +func (_m *MockDisconnectedHandler) Execute(c ws.Channel, err error) { + _m.Called(c, err) +} + +// MockDisconnectedHandler_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type MockDisconnectedHandler_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - c ws.Channel +// - err error +func (_e *MockDisconnectedHandler_Expecter) Execute(c interface{}, err interface{}) *MockDisconnectedHandler_Execute_Call { + return &MockDisconnectedHandler_Execute_Call{Call: _e.mock.On("Execute", c, err)} +} + +func (_c *MockDisconnectedHandler_Execute_Call) Run(run func(c ws.Channel, err error)) *MockDisconnectedHandler_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.Channel), args[1].(error)) + }) + return _c +} + +func (_c *MockDisconnectedHandler_Execute_Call) Return() *MockDisconnectedHandler_Execute_Call { + _c.Call.Return() + return _c +} + +func (_c *MockDisconnectedHandler_Execute_Call) RunAndReturn(run func(ws.Channel, error)) *MockDisconnectedHandler_Execute_Call { + _c.Run(run) + return _c +} + +// NewMockDisconnectedHandler creates a new instance of MockDisconnectedHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockDisconnectedHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDisconnectedHandler { + mock := &MockDisconnectedHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/mocks/mock_ErrorHandler.go b/ws/mocks/mock_ErrorHandler.go new file mode 100644 index 00000000..2f289861 --- /dev/null +++ b/ws/mocks/mock_ErrorHandler.go @@ -0,0 +1,69 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import ( + ws "github.com/lorenzodonini/ocpp-go/ws" + mock "github.com/stretchr/testify/mock" +) + +// MockErrorHandler is an autogenerated mock type for the ErrorHandler type +type MockErrorHandler struct { + mock.Mock +} + +type MockErrorHandler_Expecter struct { + mock *mock.Mock +} + +func (_m *MockErrorHandler) EXPECT() *MockErrorHandler_Expecter { + return &MockErrorHandler_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: c, err +func (_m *MockErrorHandler) Execute(c ws.Channel, err error) { + _m.Called(c, err) +} + +// MockErrorHandler_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type MockErrorHandler_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - c ws.Channel +// - err error +func (_e *MockErrorHandler_Expecter) Execute(c interface{}, err interface{}) *MockErrorHandler_Execute_Call { + return &MockErrorHandler_Execute_Call{Call: _e.mock.On("Execute", c, err)} +} + +func (_c *MockErrorHandler_Execute_Call) Run(run func(c ws.Channel, err error)) *MockErrorHandler_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.Channel), args[1].(error)) + }) + return _c +} + +func (_c *MockErrorHandler_Execute_Call) Return() *MockErrorHandler_Execute_Call { + _c.Call.Return() + return _c +} + +func (_c *MockErrorHandler_Execute_Call) RunAndReturn(run func(ws.Channel, error)) *MockErrorHandler_Execute_Call { + _c.Run(run) + return _c +} + +// NewMockErrorHandler creates a new instance of MockErrorHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockErrorHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *MockErrorHandler { + mock := &MockErrorHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/mocks/mock_MessageHandler.go b/ws/mocks/mock_MessageHandler.go new file mode 100644 index 00000000..7f1000c5 --- /dev/null +++ b/ws/mocks/mock_MessageHandler.go @@ -0,0 +1,82 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import ( + ws "github.com/lorenzodonini/ocpp-go/ws" + mock "github.com/stretchr/testify/mock" +) + +// MockMessageHandler is an autogenerated mock type for the MessageHandler type +type MockMessageHandler struct { + mock.Mock +} + +type MockMessageHandler_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMessageHandler) EXPECT() *MockMessageHandler_Expecter { + return &MockMessageHandler_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: c, data +func (_m *MockMessageHandler) Execute(c ws.Channel, data []byte) error { + ret := _m.Called(c, data) + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 error + if rf, ok := ret.Get(0).(func(ws.Channel, []byte) error); ok { + r0 = rf(c, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMessageHandler_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type MockMessageHandler_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - c ws.Channel +// - data []byte +func (_e *MockMessageHandler_Expecter) Execute(c interface{}, data interface{}) *MockMessageHandler_Execute_Call { + return &MockMessageHandler_Execute_Call{Call: _e.mock.On("Execute", c, data)} +} + +func (_c *MockMessageHandler_Execute_Call) Run(run func(c ws.Channel, data []byte)) *MockMessageHandler_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.Channel), args[1].([]byte)) + }) + return _c +} + +func (_c *MockMessageHandler_Execute_Call) Return(_a0 error) *MockMessageHandler_Execute_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMessageHandler_Execute_Call) RunAndReturn(run func(ws.Channel, []byte) error) *MockMessageHandler_Execute_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMessageHandler creates a new instance of MockMessageHandler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockMessageHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMessageHandler { + mock := &MockMessageHandler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/mocks/mock_Server.go b/ws/mocks/mock_Server.go new file mode 100644 index 00000000..e95b47ab --- /dev/null +++ b/ws/mocks/mock_Server.go @@ -0,0 +1,617 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import ( + net "net" + http "net/http" + + mock "github.com/stretchr/testify/mock" + + websocket "github.com/gorilla/websocket" + + ws "github.com/lorenzodonini/ocpp-go/ws" +) + +// MockServer is an autogenerated mock type for the Server type +type MockServer struct { + mock.Mock +} + +type MockServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockServer) EXPECT() *MockServer_Expecter { + return &MockServer_Expecter{mock: &_m.Mock} +} + +// AddSupportedSubprotocol provides a mock function with given fields: subProto +func (_m *MockServer) AddSupportedSubprotocol(subProto string) { + _m.Called(subProto) +} + +// MockServer_AddSupportedSubprotocol_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddSupportedSubprotocol' +type MockServer_AddSupportedSubprotocol_Call struct { + *mock.Call +} + +// AddSupportedSubprotocol is a helper method to define mock.On call +// - subProto string +func (_e *MockServer_Expecter) AddSupportedSubprotocol(subProto interface{}) *MockServer_AddSupportedSubprotocol_Call { + return &MockServer_AddSupportedSubprotocol_Call{Call: _e.mock.On("AddSupportedSubprotocol", subProto)} +} + +func (_c *MockServer_AddSupportedSubprotocol_Call) Run(run func(subProto string)) *MockServer_AddSupportedSubprotocol_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockServer_AddSupportedSubprotocol_Call) Return() *MockServer_AddSupportedSubprotocol_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_AddSupportedSubprotocol_Call) RunAndReturn(run func(string)) *MockServer_AddSupportedSubprotocol_Call { + _c.Run(run) + return _c +} + +// Addr provides a mock function with no fields +func (_m *MockServer) Addr() *net.TCPAddr { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Addr") + } + + var r0 *net.TCPAddr + if rf, ok := ret.Get(0).(func() *net.TCPAddr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*net.TCPAddr) + } + } + + return r0 +} + +// MockServer_Addr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Addr' +type MockServer_Addr_Call struct { + *mock.Call +} + +// Addr is a helper method to define mock.On call +func (_e *MockServer_Expecter) Addr() *MockServer_Addr_Call { + return &MockServer_Addr_Call{Call: _e.mock.On("Addr")} +} + +func (_c *MockServer_Addr_Call) Run(run func()) *MockServer_Addr_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockServer_Addr_Call) Return(_a0 *net.TCPAddr) *MockServer_Addr_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockServer_Addr_Call) RunAndReturn(run func() *net.TCPAddr) *MockServer_Addr_Call { + _c.Call.Return(run) + return _c +} + +// Errors provides a mock function with no fields +func (_m *MockServer) Errors() <-chan error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Errors") + } + + var r0 <-chan error + if rf, ok := ret.Get(0).(func() <-chan error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan error) + } + } + + return r0 +} + +// MockServer_Errors_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Errors' +type MockServer_Errors_Call struct { + *mock.Call +} + +// Errors is a helper method to define mock.On call +func (_e *MockServer_Expecter) Errors() *MockServer_Errors_Call { + return &MockServer_Errors_Call{Call: _e.mock.On("Errors")} +} + +func (_c *MockServer_Errors_Call) Run(run func()) *MockServer_Errors_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockServer_Errors_Call) Return(_a0 <-chan error) *MockServer_Errors_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockServer_Errors_Call) RunAndReturn(run func() <-chan error) *MockServer_Errors_Call { + _c.Call.Return(run) + return _c +} + +// GetChannel provides a mock function with given fields: websocketId +func (_m *MockServer) GetChannel(websocketId string) (ws.Channel, bool) { + ret := _m.Called(websocketId) + + if len(ret) == 0 { + panic("no return value specified for GetChannel") + } + + var r0 ws.Channel + var r1 bool + if rf, ok := ret.Get(0).(func(string) (ws.Channel, bool)); ok { + return rf(websocketId) + } + if rf, ok := ret.Get(0).(func(string) ws.Channel); ok { + r0 = rf(websocketId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(ws.Channel) + } + } + + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(websocketId) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +// MockServer_GetChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannel' +type MockServer_GetChannel_Call struct { + *mock.Call +} + +// GetChannel is a helper method to define mock.On call +// - websocketId string +func (_e *MockServer_Expecter) GetChannel(websocketId interface{}) *MockServer_GetChannel_Call { + return &MockServer_GetChannel_Call{Call: _e.mock.On("GetChannel", websocketId)} +} + +func (_c *MockServer_GetChannel_Call) Run(run func(websocketId string)) *MockServer_GetChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockServer_GetChannel_Call) Return(_a0 ws.Channel, _a1 bool) *MockServer_GetChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockServer_GetChannel_Call) RunAndReturn(run func(string) (ws.Channel, bool)) *MockServer_GetChannel_Call { + _c.Call.Return(run) + return _c +} + +// SetBasicAuthHandler provides a mock function with given fields: handler +func (_m *MockServer) SetBasicAuthHandler(handler func(string, string) bool) { + _m.Called(handler) +} + +// MockServer_SetBasicAuthHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetBasicAuthHandler' +type MockServer_SetBasicAuthHandler_Call struct { + *mock.Call +} + +// SetBasicAuthHandler is a helper method to define mock.On call +// - handler func(string , string) bool +func (_e *MockServer_Expecter) SetBasicAuthHandler(handler interface{}) *MockServer_SetBasicAuthHandler_Call { + return &MockServer_SetBasicAuthHandler_Call{Call: _e.mock.On("SetBasicAuthHandler", handler)} +} + +func (_c *MockServer_SetBasicAuthHandler_Call) Run(run func(handler func(string, string) bool)) *MockServer_SetBasicAuthHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func(string, string) bool)) + }) + return _c +} + +func (_c *MockServer_SetBasicAuthHandler_Call) Return() *MockServer_SetBasicAuthHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_SetBasicAuthHandler_Call) RunAndReturn(run func(func(string, string) bool)) *MockServer_SetBasicAuthHandler_Call { + _c.Run(run) + return _c +} + +// SetCheckClientHandler provides a mock function with given fields: handler +func (_m *MockServer) SetCheckClientHandler(handler ws.CheckClientHandler) { + _m.Called(handler) +} + +// MockServer_SetCheckClientHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetCheckClientHandler' +type MockServer_SetCheckClientHandler_Call struct { + *mock.Call +} + +// SetCheckClientHandler is a helper method to define mock.On call +// - handler ws.CheckClientHandler +func (_e *MockServer_Expecter) SetCheckClientHandler(handler interface{}) *MockServer_SetCheckClientHandler_Call { + return &MockServer_SetCheckClientHandler_Call{Call: _e.mock.On("SetCheckClientHandler", handler)} +} + +func (_c *MockServer_SetCheckClientHandler_Call) Run(run func(handler ws.CheckClientHandler)) *MockServer_SetCheckClientHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.CheckClientHandler)) + }) + return _c +} + +func (_c *MockServer_SetCheckClientHandler_Call) Return() *MockServer_SetCheckClientHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_SetCheckClientHandler_Call) RunAndReturn(run func(ws.CheckClientHandler)) *MockServer_SetCheckClientHandler_Call { + _c.Run(run) + return _c +} + +// SetCheckOriginHandler provides a mock function with given fields: handler +func (_m *MockServer) SetCheckOriginHandler(handler func(*http.Request) bool) { + _m.Called(handler) +} + +// MockServer_SetCheckOriginHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetCheckOriginHandler' +type MockServer_SetCheckOriginHandler_Call struct { + *mock.Call +} + +// SetCheckOriginHandler is a helper method to define mock.On call +// - handler func(*http.Request) bool +func (_e *MockServer_Expecter) SetCheckOriginHandler(handler interface{}) *MockServer_SetCheckOriginHandler_Call { + return &MockServer_SetCheckOriginHandler_Call{Call: _e.mock.On("SetCheckOriginHandler", handler)} +} + +func (_c *MockServer_SetCheckOriginHandler_Call) Run(run func(handler func(*http.Request) bool)) *MockServer_SetCheckOriginHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func(*http.Request) bool)) + }) + return _c +} + +func (_c *MockServer_SetCheckOriginHandler_Call) Return() *MockServer_SetCheckOriginHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_SetCheckOriginHandler_Call) RunAndReturn(run func(func(*http.Request) bool)) *MockServer_SetCheckOriginHandler_Call { + _c.Run(run) + return _c +} + +// SetDisconnectedClientHandler provides a mock function with given fields: handler +func (_m *MockServer) SetDisconnectedClientHandler(handler func(ws.Channel)) { + _m.Called(handler) +} + +// MockServer_SetDisconnectedClientHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDisconnectedClientHandler' +type MockServer_SetDisconnectedClientHandler_Call struct { + *mock.Call +} + +// SetDisconnectedClientHandler is a helper method to define mock.On call +// - handler func(ws.Channel) +func (_e *MockServer_Expecter) SetDisconnectedClientHandler(handler interface{}) *MockServer_SetDisconnectedClientHandler_Call { + return &MockServer_SetDisconnectedClientHandler_Call{Call: _e.mock.On("SetDisconnectedClientHandler", handler)} +} + +func (_c *MockServer_SetDisconnectedClientHandler_Call) Run(run func(handler func(ws.Channel))) *MockServer_SetDisconnectedClientHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func(ws.Channel))) + }) + return _c +} + +func (_c *MockServer_SetDisconnectedClientHandler_Call) Return() *MockServer_SetDisconnectedClientHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_SetDisconnectedClientHandler_Call) RunAndReturn(run func(func(ws.Channel))) *MockServer_SetDisconnectedClientHandler_Call { + _c.Run(run) + return _c +} + +// SetMessageHandler provides a mock function with given fields: handler +func (_m *MockServer) SetMessageHandler(handler ws.MessageHandler) { + _m.Called(handler) +} + +// MockServer_SetMessageHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMessageHandler' +type MockServer_SetMessageHandler_Call struct { + *mock.Call +} + +// SetMessageHandler is a helper method to define mock.On call +// - handler ws.MessageHandler +func (_e *MockServer_Expecter) SetMessageHandler(handler interface{}) *MockServer_SetMessageHandler_Call { + return &MockServer_SetMessageHandler_Call{Call: _e.mock.On("SetMessageHandler", handler)} +} + +func (_c *MockServer_SetMessageHandler_Call) Run(run func(handler ws.MessageHandler)) *MockServer_SetMessageHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.MessageHandler)) + }) + return _c +} + +func (_c *MockServer_SetMessageHandler_Call) Return() *MockServer_SetMessageHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_SetMessageHandler_Call) RunAndReturn(run func(ws.MessageHandler)) *MockServer_SetMessageHandler_Call { + _c.Run(run) + return _c +} + +// SetNewClientHandler provides a mock function with given fields: handler +func (_m *MockServer) SetNewClientHandler(handler ws.ConnectedHandler) { + _m.Called(handler) +} + +// MockServer_SetNewClientHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetNewClientHandler' +type MockServer_SetNewClientHandler_Call struct { + *mock.Call +} + +// SetNewClientHandler is a helper method to define mock.On call +// - handler ws.ConnectedHandler +func (_e *MockServer_Expecter) SetNewClientHandler(handler interface{}) *MockServer_SetNewClientHandler_Call { + return &MockServer_SetNewClientHandler_Call{Call: _e.mock.On("SetNewClientHandler", handler)} +} + +func (_c *MockServer_SetNewClientHandler_Call) Run(run func(handler ws.ConnectedHandler)) *MockServer_SetNewClientHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.ConnectedHandler)) + }) + return _c +} + +func (_c *MockServer_SetNewClientHandler_Call) Return() *MockServer_SetNewClientHandler_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_SetNewClientHandler_Call) RunAndReturn(run func(ws.ConnectedHandler)) *MockServer_SetNewClientHandler_Call { + _c.Run(run) + return _c +} + +// SetTimeoutConfig provides a mock function with given fields: config +func (_m *MockServer) SetTimeoutConfig(config ws.ServerTimeoutConfig) { + _m.Called(config) +} + +// MockServer_SetTimeoutConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTimeoutConfig' +type MockServer_SetTimeoutConfig_Call struct { + *mock.Call +} + +// SetTimeoutConfig is a helper method to define mock.On call +// - config ws.ServerTimeoutConfig +func (_e *MockServer_Expecter) SetTimeoutConfig(config interface{}) *MockServer_SetTimeoutConfig_Call { + return &MockServer_SetTimeoutConfig_Call{Call: _e.mock.On("SetTimeoutConfig", config)} +} + +func (_c *MockServer_SetTimeoutConfig_Call) Run(run func(config ws.ServerTimeoutConfig)) *MockServer_SetTimeoutConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(ws.ServerTimeoutConfig)) + }) + return _c +} + +func (_c *MockServer_SetTimeoutConfig_Call) Return() *MockServer_SetTimeoutConfig_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_SetTimeoutConfig_Call) RunAndReturn(run func(ws.ServerTimeoutConfig)) *MockServer_SetTimeoutConfig_Call { + _c.Run(run) + return _c +} + +// Start provides a mock function with given fields: port, listenPath +func (_m *MockServer) Start(port int, listenPath string) { + _m.Called(port, listenPath) +} + +// MockServer_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type MockServer_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +// - port int +// - listenPath string +func (_e *MockServer_Expecter) Start(port interface{}, listenPath interface{}) *MockServer_Start_Call { + return &MockServer_Start_Call{Call: _e.mock.On("Start", port, listenPath)} +} + +func (_c *MockServer_Start_Call) Run(run func(port int, listenPath string)) *MockServer_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int), args[1].(string)) + }) + return _c +} + +func (_c *MockServer_Start_Call) Return() *MockServer_Start_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_Start_Call) RunAndReturn(run func(int, string)) *MockServer_Start_Call { + _c.Run(run) + return _c +} + +// Stop provides a mock function with no fields +func (_m *MockServer) Stop() { + _m.Called() +} + +// MockServer_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type MockServer_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *MockServer_Expecter) Stop() *MockServer_Stop_Call { + return &MockServer_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *MockServer_Stop_Call) Run(run func()) *MockServer_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockServer_Stop_Call) Return() *MockServer_Stop_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServer_Stop_Call) RunAndReturn(run func()) *MockServer_Stop_Call { + _c.Run(run) + return _c +} + +// StopConnection provides a mock function with given fields: id, closeError +func (_m *MockServer) StopConnection(id string, closeError websocket.CloseError) error { + ret := _m.Called(id, closeError) + + if len(ret) == 0 { + panic("no return value specified for StopConnection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, websocket.CloseError) error); ok { + r0 = rf(id, closeError) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockServer_StopConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopConnection' +type MockServer_StopConnection_Call struct { + *mock.Call +} + +// StopConnection is a helper method to define mock.On call +// - id string +// - closeError websocket.CloseError +func (_e *MockServer_Expecter) StopConnection(id interface{}, closeError interface{}) *MockServer_StopConnection_Call { + return &MockServer_StopConnection_Call{Call: _e.mock.On("StopConnection", id, closeError)} +} + +func (_c *MockServer_StopConnection_Call) Run(run func(id string, closeError websocket.CloseError)) *MockServer_StopConnection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(websocket.CloseError)) + }) + return _c +} + +func (_c *MockServer_StopConnection_Call) Return(_a0 error) *MockServer_StopConnection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockServer_StopConnection_Call) RunAndReturn(run func(string, websocket.CloseError) error) *MockServer_StopConnection_Call { + _c.Call.Return(run) + return _c +} + +// Write provides a mock function with given fields: webSocketId, data +func (_m *MockServer) Write(webSocketId string, data []byte) error { + ret := _m.Called(webSocketId, data) + + if len(ret) == 0 { + panic("no return value specified for Write") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, []byte) error); ok { + r0 = rf(webSocketId, data) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockServer_Write_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Write' +type MockServer_Write_Call struct { + *mock.Call +} + +// Write is a helper method to define mock.On call +// - webSocketId string +// - data []byte +func (_e *MockServer_Expecter) Write(webSocketId interface{}, data interface{}) *MockServer_Write_Call { + return &MockServer_Write_Call{Call: _e.mock.On("Write", webSocketId, data)} +} + +func (_c *MockServer_Write_Call) Run(run func(webSocketId string, data []byte)) *MockServer_Write_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].([]byte)) + }) + return _c +} + +func (_c *MockServer_Write_Call) Return(_a0 error) *MockServer_Write_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockServer_Write_Call) RunAndReturn(run func(string, []byte) error) *MockServer_Write_Call { + _c.Call.Return(run) + return _c +} + +// NewMockServer creates a new instance of MockServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockServer { + mock := &MockServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/mocks/mock_ServerOpt.go b/ws/mocks/mock_ServerOpt.go new file mode 100644 index 00000000..8c1ce686 --- /dev/null +++ b/ws/mocks/mock_ServerOpt.go @@ -0,0 +1,65 @@ +// Code generated by mockery v2.51.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// MockServerOpt is an autogenerated mock type for the ServerOpt type +type MockServerOpt struct { + mock.Mock +} + +type MockServerOpt_Expecter struct { + mock *mock.Mock +} + +func (_m *MockServerOpt) EXPECT() *MockServerOpt_Expecter { + return &MockServerOpt_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: s +func (_m *MockServerOpt) Execute(s *ws.server) { + _m.Called(s) +} + +// MockServerOpt_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type MockServerOpt_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - s *ws.server +func (_e *MockServerOpt_Expecter) Execute(s interface{}) *MockServerOpt_Execute_Call { + return &MockServerOpt_Execute_Call{Call: _e.mock.On("Execute", s)} +} + +func (_c *MockServerOpt_Execute_Call) Run(run func(s *ws.server)) *MockServerOpt_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*ws.server)) + }) + return _c +} + +func (_c *MockServerOpt_Execute_Call) Return() *MockServerOpt_Execute_Call { + _c.Call.Return() + return _c +} + +func (_c *MockServerOpt_Execute_Call) RunAndReturn(run func(*ws.server)) *MockServerOpt_Execute_Call { + _c.Run(run) + return _c +} + +// NewMockServerOpt creates a new instance of MockServerOpt. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockServerOpt(t interface { + mock.TestingT + Cleanup(func()) +}) *MockServerOpt { + mock := &MockServerOpt{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/ws/network_test.go b/ws/network_test.go index 395bee1d..5cfc639a 100644 --- a/ws/network_test.go +++ b/ws/network_test.go @@ -12,8 +12,6 @@ import ( "github.com/caarlos0/env/v11" "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" toxiproxy "github.com/Shopify/toxiproxy/client" @@ -30,8 +28,8 @@ type NetworkTestSuite struct { suite.Suite proxy *toxiproxy.Proxy proxyPort int - server *Server - client *Client + server *server + client *client } func (s *NetworkTestSuite) SetupSuite() { @@ -47,9 +45,8 @@ func (s *NetworkTestSuite) SetupSuite() { if oldProxy != nil { s.Require().NoError(oldProxy.Delete()) } - p, err := client.CreateProxy("ocpp", cfg.ProxyOcppListener, cfg.ProxyOcppUpstream) - s.Require().NoError(err) + s.NoError(err) s.proxy = p } @@ -63,18 +60,20 @@ func (s *NetworkTestSuite) SetupTest() { } func (s *NetworkTestSuite) TearDownTest() { - s.server = nil - s.client = nil + if s.client != nil { + s.client.Stop() + } + if s.server != nil { + s.server.Stop() + } } func (s *NetworkTestSuite) TestClientConnectionFailed() { - t := s.T() - s.server = newWebsocketServer(t, nil) s.server.SetNewClientHandler(func(ws Channel) { - assert.Fail(t, "should not accept new clients") + s.Fail("should not accept new clients") }) go s.server.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) // Test client host := s.proxy.Listen @@ -85,29 +84,27 @@ func (s *NetworkTestSuite) TestClientConnectionFailed() { defer s.proxy.Enable() // Attempt connection err := s.client.Start(u.String()) - require.Error(t, err) - netError, ok := err.(*net.OpError) - require.True(t, ok) - require.NotNil(t, netError.Err) - sysError, ok := netError.Err.(*os.SyscallError) - require.True(t, ok) - assert.Equal(t, "connect", sysError.Syscall) - assert.Equal(t, syscall.ECONNREFUSED, sysError.Err) - // Cleanup - s.server.Stop() + s.Error(err) + var netError *net.OpError + ok := s.ErrorAs(err, &netError) + s.True(ok) + s.Error(netError.Err) + var sysError *os.SyscallError + ok = s.ErrorAs(netError.Err, &sysError) + s.True(ok) + s.Equal("connect", sysError.Syscall) + s.Equal(syscall.ECONNREFUSED, sysError.Err) } func (s *NetworkTestSuite) TestClientConnectionFailedTimeout() { - t := s.T() // Set timeouts for test s.client.timeoutConfig.HandshakeTimeout = 2 * time.Second // Setup - s.server = newWebsocketServer(t, nil) s.server.SetNewClientHandler(func(ws Channel) { - assert.Fail(t, "should not accept new clients") + s.Fail("should not accept new clients") }) go s.server.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) // Test client host := s.proxy.Listen @@ -118,151 +115,149 @@ func (s *NetworkTestSuite) TestClientConnectionFailedTimeout() { "timeout": 3000, // 3 seconds }) defer s.proxy.RemoveToxic("connectTimeout") - require.NoError(t, err) + s.NoError(err) // Attempt connection err = s.client.Start(u.String()) - require.Error(t, err) - netError, ok := err.(*net.OpError) - require.True(t, ok) - require.NotNil(t, netError.Err) - assert.True(t, strings.Contains(netError.Error(), "timeout")) - assert.True(t, netError.Timeout()) - // Cleanup - s.server.Stop() + s.Error(err) + var netError *net.OpError + ok := s.ErrorAs(err, &netError) + s.True(ok) + s.Error(netError.Err) + s.True(strings.Contains(netError.Error(), "timeout")) + s.True(netError.Timeout()) } func (s *NetworkTestSuite) TestClientAutoReconnect() { - t := s.T() // Set timeouts for test s.client.timeoutConfig.RetryBackOffWaitMinimum = 1 * time.Second s.client.timeoutConfig.RetryBackOffRandomRange = 1 // seconds // Setup - serverOnDisconnected := make(chan bool, 1) - clientOnDisconnected := make(chan bool, 1) - reconnected := make(chan bool, 1) - s.server = newWebsocketServer(t, nil) + serverOnDisconnected := make(chan struct{}, 1) + clientOnDisconnected := make(chan struct{}, 1) + reconnected := make(chan struct{}, 1) s.server.SetNewClientHandler(func(ws Channel) { - assert.NotNil(t, ws) + s.NotNil(ws) conn := s.server.connections[ws.ID()] - require.NotNil(t, conn) + s.NotNil(conn) }) s.server.SetDisconnectedClientHandler(func(ws Channel) { - serverOnDisconnected <- true + serverOnDisconnected <- struct{}{} }) go s.server.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) // Test bench s.client.SetDisconnectedHandler(func(err error) { - assert.NotNil(t, err) - closeError, ok := err.(*websocket.CloseError) - require.True(t, ok) - assert.Equal(t, websocket.CloseAbnormalClosure, closeError.Code) - assert.False(t, s.client.IsConnected()) - clientOnDisconnected <- true + s.Error(err) + var closeError *websocket.CloseError + ok := s.ErrorAs(err, &closeError) + s.True(ok) + s.Equal(websocket.CloseAbnormalClosure, closeError.Code) + s.False(s.client.IsConnected()) + clientOnDisconnected <- struct{}{} }) s.client.SetReconnectedHandler(func() { time.Sleep(time.Duration(s.client.timeoutConfig.RetryBackOffRandomRange)*time.Second + 50*time.Millisecond) // Make sure we reconnected after backoff - reconnected <- true + reconnected <- struct{}{} }) // Connect client host := s.proxy.Listen u := url.URL{Scheme: "ws", Host: host, Path: testPath} err := s.client.Start(u.String()) - require.Nil(t, err) + s.NoError(err) // Close all connection from server side - time.Sleep(500 * time.Millisecond) - for _, s := range s.server.connections { - err = s.connection.Close() - require.Nil(t, err) + time.Sleep(100 * time.Millisecond) + for _, c := range s.server.connections { + err = c.connection.Close() + s.NoError(err) } // Wait for disconnect to propagate result := <-serverOnDisconnected - require.True(t, result) + s.NotNil(result) result = <-clientOnDisconnected - require.True(t, result) + s.NotNil(result) start := time.Now() // Wait for reconnection result = <-reconnected elapsed := time.Since(start) - assert.True(t, result) - assert.True(t, s.client.IsConnected()) - assert.GreaterOrEqual(t, elapsed.Milliseconds(), s.client.timeoutConfig.RetryBackOffWaitMinimum.Milliseconds()) + s.NotNil(result) + s.True(s.client.IsConnected()) + s.GreaterOrEqual(elapsed.Milliseconds(), s.client.timeoutConfig.RetryBackOffWaitMinimum.Milliseconds()) // Cleanup s.client.SetDisconnectedHandler(func(err error) { - assert.Nil(t, err) - clientOnDisconnected <- true + s.NoError(err) + clientOnDisconnected <- struct{}{} }) s.client.Stop() result = <-clientOnDisconnected - require.True(t, result) + s.NotNil(result) s.server.Stop() } func (s *NetworkTestSuite) TestClientPongTimeout() { - t := s.T() // Set timeouts for test // Will attempt to send ping after 1 second, and server expects ping within 1.4 seconds - // Server will close connection + // server will close connection s.client.timeoutConfig.PongWait = 2 * time.Second s.client.timeoutConfig.PingPeriod = (s.client.timeoutConfig.PongWait * 5) / 10 s.client.timeoutConfig.RetryBackOffWaitMinimum = 1 * time.Second s.client.timeoutConfig.RetryBackOffWaitMinimum = 0 // remove randomness s.server.timeoutConfig.PingWait = (s.client.timeoutConfig.PongWait * 7) / 10 // Setup - serverOnDisconnected := make(chan bool, 1) - clientOnDisconnected := make(chan bool, 1) - reconnected := make(chan bool, 1) + serverOnDisconnected := make(chan struct{}, 1) + clientOnDisconnected := make(chan struct{}, 1) + reconnected := make(chan struct{}, 1) s.server.SetNewClientHandler(func(ws Channel) { - assert.NotNil(t, ws) + s.NotNil(ws) }) s.server.SetDisconnectedClientHandler(func(ws Channel) { - serverOnDisconnected <- true + serverOnDisconnected <- struct{}{} }) s.server.SetMessageHandler(func(ws Channel, data []byte) error { - assert.Fail(t, "unexpected message received") + s.Fail("unexpected message received") return fmt.Errorf("unexpected message received") }) go s.server.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) // Test client s.client.SetDisconnectedHandler(func(err error) { defer func() { - clientOnDisconnected <- true + clientOnDisconnected <- struct{}{} }() - require.Error(t, err) - closeError, ok := err.(*websocket.CloseError) - require.True(t, ok) - assert.Equal(t, websocket.CloseAbnormalClosure, closeError.Code) + s.Error(err) + var closeError *websocket.CloseError + ok := s.ErrorAs(err, &closeError) + s.True(ok) + s.Equal(websocket.CloseAbnormalClosure, closeError.Code) }) s.client.SetReconnectedHandler(func() { time.Sleep(50 * time.Millisecond) // Make sure we reconnected after backoff - reconnected <- true + reconnected <- struct{}{} }) host := s.proxy.Listen u := url.URL{Scheme: "ws", Host: host, Path: testPath} // Attempt connection err := s.client.Start(u.String()) - require.NoError(t, err) + s.NoError(err) // Slow upstream network -> ping won't get through and server-side close will be triggered _, err = s.proxy.AddToxic("readTimeout", "timeout", "upstream", 1, toxiproxy.Attributes{ "timeout": 5000, // 5 seconds }) - require.NoError(t, err) + s.NoError(err) // Attempt to send message result := <-clientOnDisconnected - require.True(t, result) + s.NotNil(result) result = <-serverOnDisconnected - require.True(t, result) + s.NotNil(result) // Reconnect time starts _ = s.proxy.RemoveToxic("readTimeout") startTimeout := time.Now() result = <-reconnected - require.True(t, result) + s.NotNil(result) elapsed := time.Since(startTimeout) - assert.GreaterOrEqual(t, elapsed.Milliseconds(), s.client.timeoutConfig.RetryBackOffWaitMinimum.Milliseconds()) + s.GreaterOrEqual(elapsed.Milliseconds(), s.client.timeoutConfig.RetryBackOffWaitMinimum.Milliseconds()) // Cleanup s.client.SetDisconnectedHandler(nil) s.client.Stop() @@ -270,7 +265,6 @@ func (s *NetworkTestSuite) TestClientPongTimeout() { } func (s *NetworkTestSuite) TestClientReadTimeout() { - t := s.T() // Set timeouts for test s.client.timeoutConfig.PongWait = 2 * time.Second s.client.timeoutConfig.PingPeriod = (s.client.timeoutConfig.PongWait * 7) / 10 @@ -278,63 +272,125 @@ func (s *NetworkTestSuite) TestClientReadTimeout() { s.client.timeoutConfig.RetryBackOffRandomRange = 0 // remove randomness s.server.timeoutConfig.PingWait = s.client.timeoutConfig.PongWait // Setup - serverOnDisconnected := make(chan bool, 1) - clientOnDisconnected := make(chan bool, 1) - reconnected := make(chan bool, 1) + serverOnDisconnected := make(chan struct{}, 1) + clientOnDisconnected := make(chan struct{}, 1) + reconnected := make(chan struct{}, 1) + s.server.SetNewClientHandler(func(ws Channel) { + s.NotNil(ws) + }) + s.server.SetDisconnectedClientHandler(func(ws Channel) { + serverOnDisconnected <- struct{}{} + }) + s.server.SetMessageHandler(func(ws Channel, data []byte) error { + s.Fail("unexpected message received") + return fmt.Errorf("unexpected message received") + }) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) + + // Test client + s.client.SetDisconnectedHandler(func(err error) { + defer func() { + clientOnDisconnected <- struct{}{} + }() + s.Error(err) + errMsg := err.Error() + s.Contains(errMsg, "timeout") + }) + s.client.SetReconnectedHandler(func() { + time.Sleep(50 * time.Millisecond) // Make sure we reconnected after backoff + reconnected <- struct{}{} + }) + host := s.proxy.Listen + u := url.URL{Scheme: "ws", Host: host, Path: testPath} + + // Attempt connection + err := s.client.Start(u.String()) + s.NoError(err) + // Slow down network. Ping will be received but pong won't go through + _, err = s.proxy.AddToxic("writeTimeout", "timeout", "downstream", 1, toxiproxy.Attributes{ + "timeout": 5000, // 5 seconds + }) + s.NoError(err) + // Attempt to send message + result := <-serverOnDisconnected + s.NotNil(result) + result = <-clientOnDisconnected + s.NotNil(result) + // Reconnect time starts + s.proxy.RemoveToxic("writeTimeout") + startTimeout := time.Now() + result = <-reconnected + s.NotNil(result) + elapsed := time.Since(startTimeout) + s.GreaterOrEqual(elapsed.Milliseconds(), s.client.timeoutConfig.RetryBackOffWaitMinimum.Milliseconds()) + // Cleanup + s.client.SetDisconnectedHandler(nil) + s.client.Stop() + s.server.Stop() +} + +func (s *NetworkTestSuite) TestServerReadTimeout() { + // Set timeouts for test + s.client.timeoutConfig.PongWait = 2 * time.Second + s.client.timeoutConfig.PingPeriod = 3 * time.Second + s.client.timeoutConfig.RetryBackOffWaitMinimum = 1 * time.Second + s.client.timeoutConfig.RetryBackOffRandomRange = 0 // remove randomness + s.server.timeoutConfig.PingWait = s.client.timeoutConfig.PongWait + // Setup + serverOnDisconnected := make(chan struct{}, 1) + clientOnDisconnected := make(chan struct{}, 1) + reconnected := make(chan struct{}, 1) s.server.SetNewClientHandler(func(ws Channel) { - assert.NotNil(t, ws) + s.NotNil(ws) }) s.server.SetDisconnectedClientHandler(func(ws Channel) { - serverOnDisconnected <- true + serverOnDisconnected <- struct{}{} }) s.server.SetMessageHandler(func(ws Channel, data []byte) error { - assert.Fail(t, "unexpected message received") + s.Fail("unexpected message received") return fmt.Errorf("unexpected message received") }) go s.server.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) // Test client s.client.SetDisconnectedHandler(func(err error) { defer func() { - clientOnDisconnected <- true + clientOnDisconnected <- struct{}{} }() - require.Error(t, err) + s.Error(err) errMsg := err.Error() - c := strings.Contains(errMsg, "timeout") - if !c { - fmt.Println(errMsg) - } - assert.True(t, c) + s.Contains(errMsg, "timeout") }) s.client.SetReconnectedHandler(func() { time.Sleep(50 * time.Millisecond) // Make sure we reconnected after backoff - reconnected <- true + reconnected <- struct{}{} }) host := s.proxy.Listen u := url.URL{Scheme: "ws", Host: host, Path: testPath} // Attempt connection err := s.client.Start(u.String()) - require.NoError(t, err) + s.NoError(err) + // Send me // Slow down network. Ping will be received but pong won't go through _, err = s.proxy.AddToxic("writeTimeout", "timeout", "downstream", 1, toxiproxy.Attributes{ "timeout": 5000, // 5 seconds }) - require.NoError(t, err) + s.NoError(err) // Attempt to send message - require.NoError(t, err) result := <-serverOnDisconnected - require.True(t, result) + s.NotNil(result) result = <-clientOnDisconnected - require.True(t, result) + s.NotNil(result) // Reconnect time starts s.proxy.RemoveToxic("writeTimeout") startTimeout := time.Now() result = <-reconnected - require.True(t, result) + s.NotNil(result) elapsed := time.Since(startTimeout) - assert.GreaterOrEqual(t, elapsed.Milliseconds(), s.client.timeoutConfig.RetryBackOffWaitMinimum.Milliseconds()) + s.GreaterOrEqual(elapsed.Milliseconds(), s.client.timeoutConfig.RetryBackOffWaitMinimum.Milliseconds()) // Cleanup s.client.SetDisconnectedHandler(nil) s.client.Stop() diff --git a/ws/server.go b/ws/server.go new file mode 100644 index 00000000..593a7c47 --- /dev/null +++ b/ws/server.go @@ -0,0 +1,466 @@ +package ws + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "path" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" +) + +// ---------------------- SERVER ---------------------- + +type CheckClientHandler func(id string, r *http.Request) bool + +// Server defines a websocket server, which passively listens for incoming connections on ws or wss protocol. +// The offered API are of asynchronous nature, and each incoming connection/message is handled using callbacks. +// +// To create a new ws server, use: +// +// server := NewServer() +// +// If you need a server with TLS support, pass the following option: +// +// server := NewServer(WithServerTLSConfig("cert.pem", "privateKey.pem", nil)) +// +// To support client basic authentication, use: +// +// server.SetBasicAuthHandler(func (user, pass) bool { +// ok := authenticate(user, pass) // ... check for user and pass correctness +// return ok +// }) +// +// To specify supported sub-protocols, use: +// +// server.AddSupportedSubprotocol("ocpp1.6") +// +// If you need to set a specific timeout configuration, refer to the SetTimeoutConfig method. +// +// Using Start and Stop you can respectively start and stop listening for incoming client websocket connections. +// +// To be notified of new and terminated connections, +// refer to SetNewClientHandler and SetDisconnectedClientHandler functions. +// +// To receive incoming messages, you will need to set your own handler using SetMessageHandler. +// To write data on the open socket, simply call the Write function. +type Server interface { + // Starts and runs the websocket server on a specific port and URL. + // After start, incoming connections and messages are handled automatically, so no explicit read operation is required. + // + // The functions blocks forever, hence it is suggested to invoke it in a goroutine, if the caller thread needs to perform other work, e.g.: + // go server.Start(8887, "/ws/{id}") + // doStuffOnMainThread() + // ... + // + // To stop a running server, call the Stop function. + Start(port int, listenPath string) + // Shuts down a running websocket server. + // All open channels will be forcefully closed, and the previously called Start function will return. + Stop() + // Closes a specific websocket connection. + StopConnection(id string, closeError websocket.CloseError) error + // Errors returns a channel for error messages. If it doesn't exist it es created. + // The channel is closed by the server when stopped. + Errors() <-chan error + // Sets a callback function for all incoming messages. + // The callbacks accept a Channel and the received data. + // It is up to the callback receiver, to check the identifier of the channel, to determine the source of the message. + SetMessageHandler(handler MessageHandler) + // SetNewClientHandler sets a callback function for all new incoming client connections. + // It is recommended to store a reference to the Channel in the received entity, so that the Channel may be recognized later on. + // + // The callback is invoked after a connection was established and upgraded successfully. + // If custom checks need to be run beforehand, refer to SetCheckClientHandler. + SetNewClientHandler(handler ConnectedHandler) + // Sets a callback function for all client disconnection events. + // Once a client is disconnected, it is not possible to read/write on the respective Channel any longer. + SetDisconnectedClientHandler(handler func(ws Channel)) + // Set custom timeout configuration parameters. If not passed, a default ServerTimeoutConfig struct will be used. + // + // This function must be called before starting the server, otherwise it may lead to unexpected behavior. + SetTimeoutConfig(config ServerTimeoutConfig) + // Write sends a message on a specific Channel, identifier by the webSocketId parameter. + // If the passed ID is invalid, an error is returned. + // + // The data is queued and will be sent asynchronously in the background. + Write(webSocketId string, data []byte) error + // AddSupportedSubprotocol adds support for a specified subprotocol. + // This is recommended in order to communicate the capabilities to the client during the handshake. + // If left empty, any subprotocol will be accepted. + // + // Duplicates will be removed automatically. + AddSupportedSubprotocol(subProto string) + // SetBasicAuthHandler enables HTTP Basic Authentication and requires clients to pass credentials. + // The handler function is called whenever a new client attempts to connect, to check for credentials correctness. + // The handler must return true if the credentials were correct, false otherwise. + SetBasicAuthHandler(handler func(username string, password string) bool) + // SetCheckOriginHandler sets a handler for incoming websocket connections, allowing to perform + // custom cross-origin checks. + // + // By default, if the Origin header is present in the request, and the Origin host is not equal + // to the Host request header, the websocket handshake fails. + SetCheckOriginHandler(handler func(r *http.Request) bool) + // SetCheckClientHandler sets a handler for validate incoming websocket connections, allowing to perform + // custom client connection checks. + // The handler is executed before any connection upgrade and allows optionally returning a custom + // configuration for the web socket that will be created. + // + // Changes to the http request at runtime may lead to undefined behavior. + SetCheckClientHandler(handler CheckClientHandler) + // Addr gives the address on which the server is listening, useful if, for + // example, the port is system-defined (set to 0). + Addr() *net.TCPAddr + // GetChannel retrieves an active Channel connection by its unique identifier. + // If a connection with the given ID exists, it returns the corresponding webSocket instance. + // If no connection is found with the specified ID, it returns nil and a false flag. + GetChannel(websocketId string) (Channel, bool) +} + +// Default implementation of a Websocket server. +// +// Use the NewServer function to create a new server. +type server struct { + connections map[string]*webSocket + httpServer *http.Server + messageHandler func(ws Channel, data []byte) error + checkClientHandler CheckClientHandler + newClientHandler func(ws Channel) + disconnectedHandler func(ws Channel) + basicAuthHandler func(username string, password string) bool + tlsCertificatePath string + tlsCertificateKey string + timeoutConfig ServerTimeoutConfig + upgrader websocket.Upgrader + errC chan error + connMutex sync.RWMutex + addr *net.TCPAddr + httpHandler *mux.Router +} + +// ServerOpt is a function that can be used to set options on a server during creation. +type ServerOpt func(s *server) + +// WithServerTLSConfig sets the TLS configuration for the server. +// If the passed tlsConfig is nil, the client will not use TLS. +func WithServerTLSConfig(certificatePath string, certificateKey string, tlsConfig *tls.Config) ServerOpt { + return func(s *server) { + s.tlsCertificatePath = certificatePath + s.tlsCertificateKey = certificateKey + if tlsConfig != nil { + s.httpServer.TLSConfig = tlsConfig + } + } +} + +// NewServer Creates a new websocket server. +// +// Additional options may be added using the AddOption function. +// +// By default, the websockets are not secure, and the server will not perform any client certificate verification. +// +// To add TLS support to the server, a valid server certificate path and key must be passed. +// To also add support for client certificate verification, a valid TLSConfig needs to be configured. +// For example: +// +// tlsConfig := &tls.Config{ +// ClientAuth: tls.RequireAndVerifyClientCert, +// ClientCAs: clientCAs, +// } +// server := ws.NewServer(ws.WithServerTLSConfig("cert.pem", "privateKey.pem", tlsConfig)) +// +// When TLS is correctly configured, the server will automatically use it for all created websocket channels. +func NewServer(opts ...ServerOpt) Server { + router := mux.NewRouter() + s := &server{ + httpServer: &http.Server{}, + timeoutConfig: NewServerTimeoutConfig(), + upgrader: websocket.Upgrader{Subprotocols: []string{}}, + httpHandler: router, + } + for _, o := range opts { + o(s) + } + return s +} + +func (s *server) SetMessageHandler(handler MessageHandler) { + s.messageHandler = handler +} + +func (s *server) SetCheckClientHandler(handler CheckClientHandler) { + s.checkClientHandler = handler +} + +func (s *server) SetNewClientHandler(handler ConnectedHandler) { + s.newClientHandler = handler +} + +func (s *server) SetDisconnectedClientHandler(handler func(ws Channel)) { + s.disconnectedHandler = handler +} + +func (s *server) SetTimeoutConfig(config ServerTimeoutConfig) { + s.timeoutConfig = config +} + +func (s *server) AddSupportedSubprotocol(subProto string) { + for _, sub := range s.upgrader.Subprotocols { + if sub == subProto { + // Don't add duplicates + return + } + } + s.upgrader.Subprotocols = append(s.upgrader.Subprotocols, subProto) +} + +func (s *server) SetBasicAuthHandler(handler func(username string, password string) bool) { + s.basicAuthHandler = handler +} + +func (s *server) SetCheckOriginHandler(handler func(r *http.Request) bool) { + s.upgrader.CheckOrigin = handler +} + +func (s *server) error(err error) { + log.Error(err) + if s.errC != nil { + s.errC <- err + } +} + +func (s *server) Errors() <-chan error { + if s.errC == nil { + s.errC = make(chan error, 1) + } + return s.errC +} + +func (s *server) Addr() *net.TCPAddr { + return s.addr +} + +func (s *server) AddHttpHandler(listenPath string, handler func(w http.ResponseWriter, r *http.Request)) { + s.httpHandler.HandleFunc(listenPath, handler) +} + +func (s *server) Start(port int, listenPath string) { + s.connMutex.Lock() + s.connections = make(map[string]*webSocket) + s.connMutex.Unlock() + + if s.httpServer == nil { + s.httpServer = &http.Server{} + } + + addr := fmt.Sprintf(":%v", port) + s.httpServer.Addr = addr + + s.AddHttpHandler(listenPath, func(w http.ResponseWriter, r *http.Request) { + s.wsHandler(w, r) + }) + s.httpServer.Handler = s.httpHandler + + ln, err := net.Listen("tcp", addr) + if err != nil { + s.error(fmt.Errorf("failed to listen: %w", err)) + return + } + + s.addr = ln.Addr().(*net.TCPAddr) + + defer ln.Close() + + log.Infof("listening on tcp network %v", addr) + s.httpServer.RegisterOnShutdown(s.stopConnections) + if s.tlsCertificatePath != "" && s.tlsCertificateKey != "" { + err = s.httpServer.ServeTLS(ln, s.tlsCertificatePath, s.tlsCertificateKey) + } else { + err = s.httpServer.Serve(ln) + } + + if !errors.Is(err, http.ErrServerClosed) { + s.error(fmt.Errorf("failed to listen: %w", err)) + } +} + +func (s *server) Stop() { + log.Info("stopping websocket server") + err := s.httpServer.Shutdown(context.TODO()) + if err != nil { + s.error(fmt.Errorf("shutdown failed: %w", err)) + } + + if s.errC != nil { + close(s.errC) + s.errC = nil + } +} + +func (s *server) StopConnection(id string, closeError websocket.CloseError) error { + s.connMutex.RLock() + w, ok := s.connections[id] + s.connMutex.RUnlock() + + if !ok { + return fmt.Errorf("couldn't stop websocket connection. No connection with id %s is open", id) + } + log.Debugf("sending stop signal for websocket %s", w.ID()) + return w.Close(closeError) +} + +func (s *server) GetChannel(websocketId string) (Channel, bool) { + s.connMutex.RLock() + defer s.connMutex.RUnlock() + c, ok := s.connections[websocketId] + return c, ok +} + +func (s *server) stopConnections() { + s.connMutex.RLock() + defer s.connMutex.RUnlock() + for _, conn := range s.connections { + _ = conn.Close(websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}) + } +} + +func (s *server) Write(webSocketId string, data []byte) error { + s.connMutex.RLock() + defer s.connMutex.RUnlock() + w, ok := s.connections[webSocketId] + if !ok { + return fmt.Errorf("couldn't write to websocket. No socket with id %v is open", webSocketId) + } + log.Debugf("queuing data for websocket %s", webSocketId) + return w.Write(data) +} + +func (s *server) wsHandler(w http.ResponseWriter, r *http.Request) { + responseHeader := http.Header{} + url := r.URL + id := path.Base(url.Path) + log.Debugf("handling new connection for %s from %s", id, r.RemoteAddr) + // Negotiate sub-protocol + clientSubProtocols := websocket.Subprotocols(r) + negotiatedSubProtocol := "" +out: + for _, requestedProto := range clientSubProtocols { + if len(s.upgrader.Subprotocols) == 0 { + // All subProtocols are accepted, pick first + negotiatedSubProtocol = requestedProto + break + } + // Check if requested suprotocol is supported by server + for _, supportedProto := range s.upgrader.Subprotocols { + if requestedProto == supportedProto { + negotiatedSubProtocol = requestedProto + break out + } + } + } + if negotiatedSubProtocol != "" { + responseHeader.Add("Sec-WebSocket-Protocol", negotiatedSubProtocol) + } + // Handle client authentication + if s.basicAuthHandler != nil { + username, password, ok := r.BasicAuth() + if ok { + ok = s.basicAuthHandler(username, password) + } + if !ok { + s.error(fmt.Errorf("basic auth failed: credentials invalid")) + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + } + // Custom client checks + if s.checkClientHandler != nil { + ok := s.checkClientHandler(id, r) + if !ok { + s.error(fmt.Errorf("client validation: invalid client")) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + } + + // Upgrade websocket + conn, err := s.upgrader.Upgrade(w, r, responseHeader) + if err != nil { + s.error(fmt.Errorf("upgrade failed: %w", err)) + return + } + + log.Debugf("upgraded websocket connection for %s from %s", id, conn.RemoteAddr().String()) + // If unsupported sub-protocol, terminate the connection immediately + if negotiatedSubProtocol == "" { + s.error(fmt.Errorf("unsupported subprotocols %v for new client %v (%v)", clientSubProtocols, id, r.RemoteAddr)) + _ = conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseProtocolError, "invalid or unsupported subprotocol"), + time.Now().Add(s.timeoutConfig.WriteWait)) + _ = conn.Close() + return + } + // Check whether client exists + s.connMutex.Lock() + // There is already a connection with the same ID. Close the new one immediately with a PolicyViolation. + if _, exists := s.connections[id]; exists { + s.connMutex.Unlock() + s.error(fmt.Errorf("client %s already exists, closing duplicate client", id)) + _ = conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID already exists"), + time.Now().Add(s.timeoutConfig.WriteWait)) + _ = conn.Close() + return + } + // Create web socket for client, state is automatically set to connected + ws := newWebSocket( + id, + conn, + r.TLS, + NewDefaultWebSocketConfig( + s.timeoutConfig.WriteWait, + s.timeoutConfig.PingWait, + s.timeoutConfig.PingPeriod, + s.timeoutConfig.PongWait), + s.handleMessage, + s.handleDisconnect, + func(_ Channel, err error) { + s.error(err) + }, + ) + // Add new client + s.connections[ws.id] = ws + s.connMutex.Unlock() + // Start reader and write routine + ws.run() + if s.newClientHandler != nil { + var channel Channel = ws + s.newClientHandler(channel) + } +} + +// --------- Internal callbacks webSocket -> server --------- +func (s *server) handleMessage(w Channel, data []byte) error { + if s.messageHandler != nil { + return s.messageHandler(w, data) + } + return fmt.Errorf("no message handler set") +} + +func (s *server) handleDisconnect(w Channel, _ error) { + // server never attempts to auto-reconnect to client. Resources are simply freed up + s.connMutex.Lock() + delete(s.connections, w.ID()) + s.connMutex.Unlock() + log.Infof("closed connection to %s", w.ID()) + if s.disconnectedHandler != nil { + s.disconnectedHandler(w) + } +} diff --git a/ws/websocket.go b/ws/websocket.go index 3cbe25a6..0ec621ba 100644 --- a/ws/websocket.go +++ b/ws/websocket.go @@ -1,24 +1,16 @@ // The package is a wrapper around gorilla websockets, // aimed at simplifying the creation and usage of a websocket client/server. // -// Check the Client and Server structure to get started. +// Check the client and server structure to get started. package ws import ( - "context" "crypto/tls" - "encoding/base64" "fmt" - "io" - "math/rand" "net" - "net/http" - "net/url" - "path" "sync" "time" - "github.com/gorilla/mux" "github.com/gorilla/websocket" "github.com/lorenzodonini/ocpp-go/logging" ) @@ -63,33 +55,41 @@ func SetLogger(logger logging.Logger) { log = logger } -// Config contains optional configuration parameters for a websocket server. +// ServerTimeoutConfig contains optional configuration parameters for a websocket server. // Setting the parameter allows to define custom timeout intervals for websocket network operations. // // To set a custom configuration, refer to the server's SetTimeoutConfig method. // If no configuration is passed, a default configuration is generated via the NewServerTimeoutConfig function. type ServerTimeoutConfig struct { - WriteWait time.Duration - PingWait time.Duration + WriteWait time.Duration // The timeout for network write operations. After a timeout, the connection is closed. + PingWait time.Duration // The timeout for waiting for a ping from the client. After a timeout, the connection is closed. + PingPeriod time.Duration // The interval for sending ping messages to a client. If set to 0, no pings are sent. + PongWait time.Duration // The timeout for waiting for a pong from the server. After a timeout, the connection is closed. Needs to be set, if server is configured to send ping messages. } // NewServerTimeoutConfig creates a default timeout configuration for a websocket endpoint. +// In the default configuration, server-side ping messages are disabled. // // You may change fields arbitrarily and pass the struct to a SetTimeoutConfig method. func NewServerTimeoutConfig() ServerTimeoutConfig { - return ServerTimeoutConfig{WriteWait: defaultWriteWait, PingWait: defaultPingWait} + return ServerTimeoutConfig{ + WriteWait: defaultWriteWait, + PingWait: defaultPingWait, + PingPeriod: 0, + PongWait: 0, + } } -// Config contains optional configuration parameters for a websocket client. +// ClientTimeoutConfig contains optional configuration parameters for a websocket client. // Setting the parameter allows to define custom timeout intervals for websocket network operations. // // To set a custom configuration, refer to the client's SetTimeoutConfig method. // If no configuration is passed, a default configuration is generated via the NewClientTimeoutConfig function. type ClientTimeoutConfig struct { - WriteWait time.Duration - HandshakeTimeout time.Duration - PongWait time.Duration - PingPeriod time.Duration + WriteWait time.Duration // The timeout for network write operations. After a timeout, the connection is closed. + HandshakeTimeout time.Duration // The timeout for the initial handshake to complete. + PongWait time.Duration // The timeout for waiting for a pong from the server. After a timeout, the connection is closed. Needs to be set, if client is configured to send ping messages. + PingPeriod time.Duration // The interval for sending ping messages to a server. If set to 0, no pings are sent. RetryBackOffRepeatTimes int RetryBackOffRandomRange int RetryBackOffWaitMinimum time.Duration @@ -110,1052 +110,425 @@ func NewClientTimeoutConfig() ClientTimeoutConfig { } } -// Channel represents a bi-directional communication channel, which provides at least a unique ID. -type Channel interface { - ID() string - RemoteAddr() net.Addr - TLSConnectionState() *tls.ConnectionState +// Wraps a time.Ticker instance to provide a nullable ticker option. +// If no real ticker is instantiated, the struct will run/return no-ops. +type optTicker struct { + c chan time.Time + ticker *time.Ticker } -// WebSocket is a wrapper for a single websocket channel. -// The connection itself is provided by the gorilla websocket package. -// -// Don't use a websocket directly, but refer to WsServer and WsClient. -type WebSocket struct { - connection *websocket.Conn - id string - outQueue chan []byte - closeC chan websocket.CloseError // used to gracefully close a websocket connection. - forceCloseC chan error // used by the readPump to notify a forcefully closed connection to the writePump. - pingMessage chan []byte - tlsConnectionState *tls.ConnectionState -} - -// Retrieves the unique Identifier of the websocket (typically, the URL suffix). -func (websocket *WebSocket) ID() string { - return websocket.id -} - -// Returns the address of the remote peer. -func (websocket *WebSocket) RemoteAddr() net.Addr { - return websocket.connection.RemoteAddr() +func newOptTicker(pingCfg *PingConfig) optTicker { + if pingCfg != nil && pingCfg.PingPeriod > 0 { + // Create regular ticker + return optTicker{ + ticker: time.NewTicker(pingCfg.PingPeriod), + } + } + // Ticker shall be dummy, as it doesn't trigger any actual events + return optTicker{ + c: make(chan time.Time, 1), + } } -// Returns the TLS connection state of the connection, if any. -func (websocket *WebSocket) TLSConnectionState() *tls.ConnectionState { - return websocket.tlsConnectionState +func (o optTicker) T() <-chan time.Time { + if o.ticker != nil { + return o.ticker.C + } + return o.c } -// ConnectionError is a websocket -type HttpConnectionError struct { - Message string - HttpStatus string - HttpCode int - Details string +func (o optTicker) Stop() { + if o.ticker != nil { + o.ticker.Stop() + } } -func (e HttpConnectionError) Error() string { - return fmt.Sprintf("%v, http status: %v", e.Message, e.HttpStatus) +// Channel represents a bi-directional IP-based communication channel, which provides at least a unique ID. +type Channel interface { + // ID returns the unique identifier of the client, which identifies this unique channel. + ID() string + // RemoteAddr returns the remote IP network address of the connected peer. + RemoteAddr() net.Addr + // TLSConnectionState returns information about the active TLS connection, if any. + TLSConnectionState() *tls.ConnectionState + // IsConnected returns true if the connection to the peer is active, false if it was closed already. + IsConnected() bool } -// ---------------------- SERVER ---------------------- - -type CheckClientHandler func(id string, r *http.Request) bool - -// WsServer defines a websocket server, which passively listens for incoming connections on ws or wss protocol. -// The offered API are of asynchronous nature, and each incoming connection/message is handled using callbacks. -// -// To create a new ws server, use: -// -// server := NewServer() -// -// If you need a TLS ws server instead, use: -// -// server := NewTLSServer("cert.pem", "privateKey.pem") -// -// To support client basic authentication, use: -// -// server.SetBasicAuthHandler(func (user, pass) bool { -// ok := authenticate(user, pass) // ... check for user and pass correctness -// return ok -// }) -// -// To specify supported sub-protocols, use: -// -// server.AddSupportedSubprotocol("ocpp1.6") -// -// If you need to set a specific timeout configuration, refer to the SetTimeoutConfig method. -// -// Using Start and Stop you can respectively start and stop listening for incoming client websocket connections. -// -// To be notified of new and terminated connections, -// refer to SetNewClientHandler and SetDisconnectedClientHandler functions. -// -// To receive incoming messages, you will need to set your own handler using SetMessageHandler. -// To write data on the open socket, simply call the Write function. -type WsServer interface { - // Starts and runs the websocket server on a specific port and URL. - // After start, incoming connections and messages are handled automatically, so no explicit read operation is required. - // - // The functions blocks forever, hence it is suggested to invoke it in a goroutine, if the caller thread needs to perform other work, e.g.: - // go server.Start(8887, "/ws/{id}") - // doStuffOnMainThread() - // ... - // - // To stop a running server, call the Stop function. - Start(port int, listenPath string) - // Shuts down a running websocket server. - // All open channels will be forcefully closed, and the previously called Start function will return. - Stop() - // Closes a specific websocket connection. - StopConnection(id string, closeError websocket.CloseError) error - // Errors returns a channel for error messages. If it doesn't exist it es created. - // The channel is closed by the server when stopped. - Errors() <-chan error - // Sets a callback function for all incoming messages. - // The callbacks accept a Channel and the received data. - // It is up to the callback receiver, to check the identifier of the channel, to determine the source of the message. - SetMessageHandler(handler func(ws Channel, data []byte) error) - // Sets a callback function for all new incoming client connections. - // It is recommended to store a reference to the Channel in the received entity, so that the Channel may be recognized later on. - SetNewClientHandler(handler func(ws Channel)) - // Sets a callback function for all client disconnection events. - // Once a client is disconnected, it is not possible to read/write on the respective Channel any longer. - SetDisconnectedClientHandler(handler func(ws Channel)) - // Set custom timeout configuration parameters. If not passed, a default ServerTimeoutConfig struct will be used. - // - // This function must be called before starting the server, otherwise it may lead to unexpected behavior. - SetTimeoutConfig(config ServerTimeoutConfig) - // Sends a message on a specific Channel, identifier by the webSocketId parameter. - // If the passed ID is invalid, an error is returned. - // - // The data is queued and will be sent asynchronously in the background. - Write(webSocketId string, data []byte) error - // Adds support for a specified subprotocol. - // This is recommended in order to communicate the capabilities to the client during the handshake. - // If left empty, any subprotocol will be accepted. +// WebSocketConfig is a utility config struct for a single webSocket. +// By default, it inherits values from respective the ClientTimeoutConfig or ServerTimeoutConfig. +// However, during creation, some fields may be overridden and customized on a websocket-basis. +type WebSocketConfig struct { + // The timeout for network write operations. + // After a timeout, the connection is closed. + WriteWait time.Duration + // The timeout for waiting for a message from the connected peer. + // After a timeout, the connection is closed. + // If ReadWait is zero, the websocket will not time out on read operations. // - // Duplicates will be removed automatically. - AddSupportedSubprotocol(subProto string) - // SetBasicAuthHandler enables HTTP Basic Authentication and requires clients to pass credentials. - // The handler function is called whenever a new client attempts to connect, to check for credentials correctness. - // The handler must return true if the credentials were correct, false otherwise. - SetBasicAuthHandler(handler func(username string, password string) bool) - // SetCheckOriginHandler sets a handler for incoming websocket connections, allowing to perform - // custom cross-origin checks. + // Depending on the configuration, the websocket will either wait for incoming pings + // or send pings to the connected peer. // - // By default, if the Origin header is present in the request, and the Origin host is not equal - // to the Host request header, the websocket handshake fails. - SetCheckOriginHandler(handler func(r *http.Request) bool) - // SetCheckClientHandler sets a handler for validate incoming websocket connections, allowing to perform - // custom client connection checks. - SetCheckClientHandler(handler func(id string, r *http.Request) bool) - // Addr gives the address on which the server is listening, useful if, for - // example, the port is system-defined (set to 0). - Addr() *net.TCPAddr - // Connections retrieves a WebSocket connection by its unique identifier. - // If a connection with the given ID exists, it returns the corresponding WebSocket instance. - // If no connection is found with the specified ID, it returns nil. - Connections(websocketId string) *WebSocket -} - -// Default implementation of a Websocket server. -// -// Use the NewServer or NewTLSServer functions to create a new server. -type Server struct { - connections map[string]*WebSocket - httpServer *http.Server - messageHandler func(ws Channel, data []byte) error - checkClientHandler func(id string, r *http.Request) bool - newClientHandler func(ws Channel) - disconnectedHandler func(ws Channel) - basicAuthHandler func(username string, password string) bool - tlsCertificatePath string - tlsCertificateKey string - timeoutConfig ServerTimeoutConfig - upgrader websocket.Upgrader - errC chan error - connMutex sync.RWMutex - addr *net.TCPAddr - httpHandler *mux.Router -} - -// Creates a new simple websocket server (the websockets are not secured). -func NewServer() *Server { - router := mux.NewRouter() - return &Server{ - httpServer: &http.Server{}, - timeoutConfig: NewServerTimeoutConfig(), - upgrader: websocket.Upgrader{Subprotocols: []string{}}, - httpHandler: router, + // If PingConfig is set (i.e. the websocket is configured to send ping messages), + // the ReadWait value should be omitted. + // If provided, the websocket will accept ping messages, but the read timeout + // configuration from the PingConfig will be prioritized. + ReadWait time.Duration + // Optional configuration for ping operations. If omitted, the websocket will not send any pings. + PingConfig *PingConfig + // Optional logger for the websocket. If omitted, the global logger is used. + Logger logging.Logger +} + +// PingConfig contains optional configuration parameters for websockets sending ping operations. +type PingConfig struct { + PingPeriod time.Duration // The interval for sending ping messages to the connected peer. + PongWait time.Duration // The timeout for waiting for a pong from the connected peer. After a timeout, the connection is closed. +} + +// NewDefaultWebSocketConfig creates a new websocket config struct with the passed values. +// If sendPing is set, the websocket will be configured to send out periodic pings. +// +// No custom configuration functions are run. Overrides need to be applied externally. +func NewDefaultWebSocketConfig( + writeWait time.Duration, + readWait time.Duration, + pingPeriod time.Duration, + pongWait time.Duration) WebSocketConfig { + var pingCfg *PingConfig + if pingPeriod > 0 { + pingCfg = &PingConfig{ + PingPeriod: pingPeriod, + PongWait: pongWait, + } } -} - -// NewTLSServer creates a new secure websocket server. All created websocket channels will use TLS. -// -// You need to pass a filepath to the server TLS certificate and key. -// -// It is recommended to pass a valid TLSConfig for the server to use. -// For example to require client certificate verification: -// -// tlsConfig := &tls.Config{ -// ClientAuth: tls.RequireAndVerifyClientCert, -// ClientCAs: clientCAs, -// } -// -// If no tlsConfig parameter is passed, the server will by default -// not perform any client certificate verification. -func NewTLSServer(certificatePath string, certificateKey string, tlsConfig *tls.Config) *Server { - router := mux.NewRouter() - return &Server{ - tlsCertificatePath: certificatePath, - tlsCertificateKey: certificateKey, - httpServer: &http.Server{ - TLSConfig: tlsConfig, - }, - timeoutConfig: NewServerTimeoutConfig(), - upgrader: websocket.Upgrader{Subprotocols: []string{}}, - httpHandler: router, + return WebSocketConfig{ + WriteWait: writeWait, + ReadWait: readWait, + PingConfig: pingCfg, + Logger: log, } } -func (server *Server) SetMessageHandler(handler func(ws Channel, data []byte) error) { - server.messageHandler = handler -} - -func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) bool) { - server.checkClientHandler = handler -} - -func (server *Server) SetNewClientHandler(handler func(ws Channel)) { - server.newClientHandler = handler -} +type MessageHandler func(c Channel, data []byte) error +type ConnectedHandler func(c Channel) +type DisconnectedHandler func(c Channel, err error) +type ErrorHandler func(c Channel, err error) -func (server *Server) SetDisconnectedClientHandler(handler func(ws Channel)) { - server.disconnectedHandler = handler +type message struct { + typ int + data []byte } -func (server *Server) SetTimeoutConfig(config ServerTimeoutConfig) { - server.timeoutConfig = config +// webSocket is a wrapper for a single websocket channel. +// The connection itself is provided by the gorilla websocket package. +// +// Don't use a websocket directly, but refer to Server and Client. +type webSocket struct { + connection *websocket.Conn + mutex sync.RWMutex + id string + outQueue chan message + pingC chan []byte + closeC chan websocket.CloseError // used to gracefully close a websocket connection. + forceCloseC chan error // used by the readPump to notify a forcefully closed connection to the writePump. + tlsConnectionState *tls.ConnectionState + cfg WebSocketConfig + log logging.Logger + onClosed DisconnectedHandler + onError ErrorHandler + onMessage MessageHandler } -func (server *Server) AddSupportedSubprotocol(subProto string) { - for _, sub := range server.upgrader.Subprotocols { - if sub == subProto { - // Don't add duplicates - return - } +func newWebSocket(id string, conn *websocket.Conn, tlsState *tls.ConnectionState, cfg WebSocketConfig, onMessage MessageHandler, onClosed DisconnectedHandler, onError ErrorHandler) *webSocket { + if conn == nil { + panic("cannot create websocket with nil connection") } - server.upgrader.Subprotocols = append(server.upgrader.Subprotocols, subProto) -} - -func (server *Server) SetBasicAuthHandler(handler func(username string, password string) bool) { - server.basicAuthHandler = handler -} - -func (server *Server) SetCheckOriginHandler(handler func(r *http.Request) bool) { - server.upgrader.CheckOrigin = handler -} - -func (server *Server) error(err error) { - log.Error(err) - if server.errC != nil { - server.errC <- err + w := &webSocket{ + id: id, + connection: conn, + mutex: sync.RWMutex{}, + tlsConnectionState: tlsState, + outQueue: make(chan message, 2), + pingC: make(chan []byte, 1), + closeC: make(chan websocket.CloseError, 1), + forceCloseC: make(chan error, 1), + onClosed: onClosed, + onError: onError, + onMessage: onMessage, } + w.updateConfig(cfg) + return w } -func (server *Server) Errors() <-chan error { - if server.errC == nil { - server.errC = make(chan error, 1) - } - return server.errC +// Retrieves the unique Identifier of the websocket (typically, the URL suffix). +func (w *webSocket) ID() string { + return w.id } -func (server *Server) Addr() *net.TCPAddr { - return server.addr +// Returns the address of the remote peer. +func (w *webSocket) RemoteAddr() net.Addr { + return w.connection.RemoteAddr() } -func (server *Server) Connections(websocketId string) *WebSocket { - server.connMutex.RLock() - defer server.connMutex.RUnlock() - return server.connections[websocketId] +// Returns the TLS connection state of the connection, if any. +func (w *webSocket) TLSConnectionState() *tls.ConnectionState { + return w.tlsConnectionState } -func (server *Server) AddHttpHandler(listenPath string, handler func(w http.ResponseWriter, r *http.Request)) { - server.httpHandler.HandleFunc(listenPath, handler) +func (w *webSocket) IsConnected() bool { + w.mutex.RLock() + defer w.mutex.RUnlock() + return w.connection != nil } -func (server *Server) Start(port int, listenPath string) { - server.connMutex.Lock() - server.connections = make(map[string]*WebSocket) - server.connMutex.Unlock() - - if server.httpServer == nil { - server.httpServer = &http.Server{} - } - - addr := fmt.Sprintf(":%v", port) - server.httpServer.Addr = addr - - server.AddHttpHandler(listenPath, func(w http.ResponseWriter, r *http.Request) { - server.wsHandler(w, r) - }) - server.httpServer.Handler = server.httpHandler - - ln, err := net.Listen("tcp", addr) - if err != nil { - server.error(fmt.Errorf("failed to listen: %w", err)) - return - } - - server.addr = ln.Addr().(*net.TCPAddr) - - defer ln.Close() - - log.Infof("listening on tcp network %v", addr) - server.httpServer.RegisterOnShutdown(server.stopConnections) - if server.tlsCertificatePath != "" && server.tlsCertificateKey != "" { - err = server.httpServer.ServeTLS(ln, server.tlsCertificatePath, server.tlsCertificateKey) - } else { - err = server.httpServer.Serve(ln) - } - - if err != http.ErrServerClosed { - server.error(fmt.Errorf("failed to listen: %w", err)) - } +func (w *webSocket) Write(data []byte) error { + return w.WriteManual(websocket.TextMessage, data) } -func (server *Server) Stop() { - log.Info("stopping websocket server") - err := server.httpServer.Shutdown(context.TODO()) - if err != nil { - server.error(fmt.Errorf("shutdown failed: %w", err)) +func (w *webSocket) WriteManual(messageTyp int, data []byte) error { + msg := message{ + typ: messageTyp, + data: data, } - - if server.errC != nil { - close(server.errC) - server.errC = nil + w.mutex.RLock() + defer w.mutex.RUnlock() + if w.connection == nil { + return fmt.Errorf("cannot write to closed connection %s", w.id) } + w.outQueue <- msg + return nil } -func (server *Server) StopConnection(id string, closeError websocket.CloseError) error { - server.connMutex.RLock() - ws, ok := server.connections[id] - server.connMutex.RUnlock() - - if !ok { - return fmt.Errorf("couldn't stop websocket connection. No connection with id %s is open", id) +func (w *webSocket) Close(closeError websocket.CloseError) error { + w.mutex.RLock() + defer w.mutex.RUnlock() + if w.connection == nil { + return fmt.Errorf("cannot close already closed connection %s", w.id) } - log.Debugf("sending stop signal for websocket %s", ws.ID()) - ws.closeC <- closeError + w.closeC <- closeError return nil } -func (server *Server) stopConnections() { - server.connMutex.RLock() - defer server.connMutex.RUnlock() - for _, conn := range server.connections { - conn.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""} +func (w *webSocket) updateConfig(cfg WebSocketConfig) { + w.mutex.Lock() + defer w.mutex.Unlock() + w.cfg = cfg + // Update logger + if cfg.Logger != nil { + w.log = cfg.Logger + } else { + w.log = log } + // Update ping pong logic + w.initPingPong() } -func (server *Server) Write(webSocketId string, data []byte) error { - server.connMutex.RLock() - defer server.connMutex.RUnlock() - ws, ok := server.connections[webSocketId] - if !ok { - return fmt.Errorf("couldn't write to websocket. No socket with id %v is open", webSocketId) +func (w *webSocket) getReadTimeout() time.Time { + var wait time.Duration + // Prefer ping config, then read wait, then no timeout + if w.cfg.PingConfig != nil && w.cfg.PingConfig.PongWait > 0 { + wait = w.cfg.PingConfig.PongWait + } else if w.cfg.ReadWait > 0 { + wait = w.cfg.ReadWait + } else { + // No timeout configured + return time.Time{} } - log.Debugf("queuing data for websocket %s", webSocketId) - ws.outQueue <- data - return nil + return time.Now().Add(wait) } -func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) { - responseHeader := http.Header{} - url := r.URL - id := path.Base(url.Path) - log.Debugf("handling new connection for %s from %s", id, r.RemoteAddr) - // Negotiate sub-protocol - clientSubprotocols := websocket.Subprotocols(r) - negotiatedSuprotocol := "" -out: - for _, requestedProto := range clientSubprotocols { - if len(server.upgrader.Subprotocols) == 0 { - // All subProtocols are accepted, pick first - negotiatedSuprotocol = requestedProto - break - } - // Check if requested suprotocol is supported by server - for _, supportedProto := range server.upgrader.Subprotocols { - if requestedProto == supportedProto { - negotiatedSuprotocol = requestedProto - break out - } - } - } - if negotiatedSuprotocol != "" { - responseHeader.Add("Sec-WebSocket-Protocol", negotiatedSuprotocol) - } - // Handle client authentication - if server.basicAuthHandler != nil { - username, password, ok := r.BasicAuth() - if ok { - ok = server.basicAuthHandler(username, password) - } - if !ok { - server.error(fmt.Errorf("basic auth failed: credentials invalid")) - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } +func (w *webSocket) initPingPong() { + conn := w.connection + if w.cfg.ReadWait > 0 { + // Expect pings, reply with pongs + conn.SetPingHandler(w.onPing) + } else { + conn.SetPingHandler(nil) } - - if server.checkClientHandler != nil { - ok := server.checkClientHandler(id, r) - if !ok { - server.error(fmt.Errorf("client validation: invalid client")) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } + if w.cfg.PingConfig != nil { + // Optionally send pings, expect pongs + conn.SetPongHandler(w.onPong) + } else { + conn.SetPongHandler(nil) } +} - // Upgrade websocket - conn, err := server.upgrader.Upgrade(w, r, responseHeader) - if err != nil { - server.error(fmt.Errorf("upgrade failed: %w", err)) - return - } +func (w *webSocket) onPing(appData string) error { + conn := w.connection + w.log.Debugf("ping received from %s: %s", w.id, appData) + // Schedule pong message via dedicated channel + w.pingC <- []byte(appData) + w.log.Debugf("pong scheduled for %s", w.id) + // Reset read interval after receiving a ping + return conn.SetReadDeadline(w.getReadTimeout()) +} - // The id of the charge point is the final path element - ws := WebSocket{ - connection: conn, - id: id, - outQueue: make(chan []byte, 1), - closeC: make(chan websocket.CloseError, 1), - forceCloseC: make(chan error, 1), - pingMessage: make(chan []byte, 1), - tlsConnectionState: r.TLS, - } - log.Debugf("upgraded websocket connection for %s from %s", id, conn.RemoteAddr().String()) - // If unsupported subprotocol, terminate the connection immediately - if negotiatedSuprotocol == "" { - server.error(fmt.Errorf("unsupported subprotocols %v for new client %v (%v)", clientSubprotocols, id, r.RemoteAddr)) - _ = conn.WriteControl(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseProtocolError, "invalid or unsupported subprotocol"), - time.Now().Add(server.timeoutConfig.WriteWait)) - _ = conn.Close() - return - } - // Check whether client exists - server.connMutex.Lock() - // There is already a connection with the same ID. Close the new one immediately with a PolicyViolation. - if _, exists := server.connections[id]; exists { - server.connMutex.Unlock() - server.error(fmt.Errorf("client %s already exists, closing duplicate client", id)) - _ = conn.WriteControl(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID already exists"), - time.Now().Add(server.timeoutConfig.WriteWait)) - _ = conn.Close() - return - } - // Add new client - server.connections[ws.id] = &ws - server.connMutex.Unlock() - // Read and write routines are started in separate goroutines and function will return immediately - go server.writePump(&ws) - go server.readPump(&ws) - if server.newClientHandler != nil { - var channel Channel = &ws - server.newClientHandler(channel) - } +func (w *webSocket) onPong(appData string) error { + conn := w.connection + w.log.Debugf("pong received from %s: %s", w.id, appData) + // Reset read interval after receiving a pong + return conn.SetReadDeadline(w.getReadTimeout()) } -func (server *Server) getReadTimeout() time.Time { - if server.timeoutConfig.PingWait == 0 { - return time.Time{} +func (w *webSocket) cleanup(err error) { + w.mutex.Lock() + // Properly close the connection + if e := w.connection.Close(); e != nil { + log.Errorf("failed to close connection for %s: %v", w.id, e) } - return time.Now().Add(server.timeoutConfig.PingWait) + w.connection = nil + close(w.outQueue) + close(w.pingC) + close(w.closeC) + close(w.forceCloseC) + w.mutex.Unlock() + // Invoke callback to notify the websocket was closed. + // If err is not nil, the disconnect is considered forced (i.e. not user-initiated). + w.onClosed(w, err) } -func (server *Server) readPump(ws *WebSocket) { - conn := ws.connection +func (w *webSocket) run() { + go w.readPump() + go w.writePump() +} - conn.SetPingHandler(func(appData string) error { - log.Debugf("ping received from %s", ws.ID()) - ws.pingMessage <- []byte(appData) - err := conn.SetReadDeadline(server.getReadTimeout()) - return err - }) - _ = conn.SetReadDeadline(server.getReadTimeout()) +// The readPump is a dedicated routine that awaits the next incoming message up until a deadline. +func (w *webSocket) readPump() { + w.mutex.RLock() + conn := w.connection + w.mutex.RUnlock() + if conn == nil { + err := fmt.Errorf("readPump started for %s with nil connection", w.id) + w.onError(w, err) + return + } + _ = conn.SetReadDeadline(w.getReadTimeout()) for { - _, message, err := conn.ReadMessage() + _, msg, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { - server.error(fmt.Errorf("read failed unexpectedly for %s: %w", ws.ID(), err)) + w.onError(w, fmt.Errorf("read failed unexpectedly for %s: %w", w.id, err)) + } + // Verify whether the disconnect was already dealt with + w.mutex.RLock() + if w.connection == nil { + // Connection cleaned up, read simply got notified of the close -> ignore + w.log.Debugf("readPump stopped for %s due to closed connection", w.id) + w.mutex.RUnlock() + return } - log.Debugf("handling read error for %s: %v", ws.ID(), err.Error()) // Notify writePump of error. Force close will be handled there - ws.forceCloseC <- err + w.log.Debugf("handling read error for %s: %v", w.id, err.Error()) + w.forceCloseC <- err + w.mutex.RUnlock() return } - if server.messageHandler != nil { - var channel Channel = ws - err = server.messageHandler(channel, message) - if err != nil { - server.error(fmt.Errorf("handling failed for %s: %w", ws.ID(), err)) - continue - } + // Forward message to handler. + // Errors during the handling don't interrupt the websocket routine but will be reported. + err = w.onMessage(w, msg) + if err != nil { + w.onError(w, err) } - _ = conn.SetReadDeadline(server.getReadTimeout()) + _ = conn.SetReadDeadline(w.getReadTimeout()) } } -func (server *Server) writePump(ws *WebSocket) { - conn := ws.connection +// All actions and events are handled within this centralized control flow function. +func (w *webSocket) writePump() { + conn := w.connection + ticker := newOptTicker(w.cfg.PingConfig) + + closure := func(err error) { + ticker.Stop() + w.cleanup(err) + } for { select { - case data, ok := <-ws.outQueue: - _ = conn.SetWriteDeadline(time.Now().Add(server.timeoutConfig.WriteWait)) - if !ok { - // Unexpected closed queue, should never happen - server.error(fmt.Errorf("output queue for socket %v was closed, forcefully closing", ws.id)) - // Don't invoke cleanup - return - } - // Send data - err := conn.WriteMessage(websocket.TextMessage, data) + case <-ticker.T(): + // Send periodic ping + _ = conn.SetWriteDeadline(time.Now().Add(w.cfg.WriteWait)) + err := conn.WriteMessage(websocket.PingMessage, []byte{}) if err != nil { - server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err)) + w.onError(w, fmt.Errorf("failed to send ping message for %s: %w", w.id, err)) // Invoking cleanup, as socket was forcefully closed - server.cleanupConnection(ws) + closure(err) return } - log.Debugf("written %d bytes to %s", len(data), ws.ID()) - case ping := <-ws.pingMessage: - _ = conn.SetWriteDeadline(time.Now().Add(server.timeoutConfig.WriteWait)) + log.Debugf("ping sent for %s", w.id) + case ping := <-w.pingC: + // Reply with pong message + _ = conn.SetWriteDeadline(time.Now().Add(w.cfg.WriteWait)) err := conn.WriteMessage(websocket.PongMessage, ping) if err != nil { - server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err)) + w.onError(w, fmt.Errorf("failed to send pong message %s: %w", w.id, err)) // Invoking cleanup, as socket was forcefully closed - server.cleanupConnection(ws) + closure(err) return } - log.Debugf("pong sent to %s", ws.ID()) - case closeErr := <-ws.closeC: - log.Debugf("closing connection to %s", ws.ID()) - // Closing connection gracefully - if err := conn.WriteControl( - websocket.CloseMessage, - websocket.FormatCloseMessage(closeErr.Code, closeErr.Text), - time.Now().Add(server.timeoutConfig.WriteWait), - ); err != nil { - server.error(fmt.Errorf("failed to write close message for connection %s: %w", ws.id, err)) - } - // Invoking cleanup - server.cleanupConnection(ws) - return - case closed, ok := <-ws.forceCloseC: - if !ok || closed != nil { - // Connection was forcefully closed, invoke cleanup - log.Debugf("handling forced close signal for %s", ws.ID()) - server.cleanupConnection(ws) - } - return - } - } -} - -// Frees internal resources after a websocket connection was signaled to be closed. -// From this moment onwards, no new messages may be sent. -func (server *Server) cleanupConnection(ws *WebSocket) { - _ = ws.connection.Close() - server.connMutex.Lock() - close(ws.outQueue) - close(ws.closeC) - delete(server.connections, ws.id) - server.connMutex.Unlock() - log.Infof("closed connection to %s", ws.ID()) - if server.disconnectedHandler != nil { - server.disconnectedHandler(ws) - } -} - -// ---------------------- CLIENT ---------------------- - -// WsClient defines a websocket client, needed to connect to a websocket server. -// The offered API are of asynchronous nature, and each incoming message is handled using callbacks. -// -// To create a new ws client, use: -// -// client := NewClient() -// -// If you need a TLS ws client instead, use: -// -// certPool, err := x509.SystemCertPool() -// if err != nil { -// log.Fatal(err) -// } -// // You may add more trusted certificates to the pool before creating the TLSClientConfig -// client := NewTLSClient(&tls.Config{ -// RootCAs: certPool, -// }) -// -// To add additional dial options, use: -// -// client.AddOption(func(*websocket.Dialer) { -// // Your option ... -// )} -// -// To add basic HTTP authentication, use: -// -// client.SetBasicAuth("username","password") -// -// If you need to set a specific timeout configuration, refer to the SetTimeoutConfig method. -// -// Using Start and Stop you can respectively open/close a websocket to a websocket server. -// -// To receive incoming messages, you will need to set your own handler using SetMessageHandler. -// To write data on the open socket, simply call the Write function. -type WsClient interface { - // Starts the client and attempts to connect to the server on a specified URL. - // If the connection fails, an error is returned. - // - // For example: - // err := client.Start("ws://localhost:8887/ws/1234") - // - // The function returns immediately, after the connection has been established. - // Incoming messages are passed automatically to the callback function, so no explicit read operation is required. - // - // To stop a running client, call the Stop function. - Start(url string) error - // Starts the client and attempts to connect to the server on a specified URL. - // If the connection fails, it keeps retrying with Backoff strategy from TimeoutConfig. - // - // For example: - // client.StartWithRetries("ws://localhost:8887/ws/1234") - // - // The function returns only when the connection has been established. - // Incoming messages are passed automatically to the callback function, so no explicit read operation is required. - // - // To stop a running client, call the Stop function. - StartWithRetries(url string) - // Closes the output of the websocket Channel, effectively closing the connection to the server with a normal closure. - Stop() - // Errors returns a channel for error messages. If it doesn't exist it es created. - // The channel is closed by the client when stopped. - Errors() <-chan error - // Sets a callback function for all incoming messages. - SetMessageHandler(handler func(data []byte) error) - // Set custom timeout configuration parameters. If not passed, a default ClientTimeoutConfig struct will be used. - // - // This function must be called before connecting to the server, otherwise it may lead to unexpected behavior. - SetTimeoutConfig(config ClientTimeoutConfig) - // Sets a callback function for receiving notifications about an unexpected disconnection from the server. - // The callback is invoked even if the automatic reconnection mechanism is active. - // - // If the client was stopped using the Stop function, the callback will NOT be invoked. - SetDisconnectedHandler(handler func(err error)) - // Sets a callback function for receiving notifications whenever the connection to the server is re-established. - // Connections are re-established automatically thanks to the auto-reconnection mechanism. - // - // If set, the DisconnectedHandler will always be invoked before the Reconnected callback is invoked. - SetReconnectedHandler(handler func()) - // IsConnected Returns information about the current connection status. - // If the client is currently attempting to auto-reconnect to the server, the function returns false. - IsConnected() bool - // Sends a message to the server over the websocket. - // - // The data is queued and will be sent asynchronously in the background. - Write(data []byte) error - // Adds a websocket option to the client. - AddOption(option interface{}) - // SetRequestedSubProtocol will negotiate the specified sub-protocol during the websocket handshake. - // Internally this creates a dialer option and invokes the AddOption method on the client. - // - // Duplicates generated by invoking this method multiple times will be ignored. - SetRequestedSubProtocol(subProto string) - // SetBasicAuth adds basic authentication credentials, to use when connecting to the server. - // The credentials are automatically encoded in base64. - SetBasicAuth(username string, password string) - // SetHeaderValue sets a value on the HTTP header sent when opening a websocket connection to the server. - // - // The function overwrites previous header fields with the same key. - SetHeaderValue(key string, value string) -} - -// Client is the default implementation of a Websocket client. -// -// Use the NewClient or NewTLSClient functions to create a new client. -type Client struct { - webSocket WebSocket - url url.URL - messageHandler func(data []byte) error - dialOptions []func(*websocket.Dialer) - header http.Header - timeoutConfig ClientTimeoutConfig - connected bool - onDisconnected func(err error) - onReconnected func() - mutex sync.Mutex - errC chan error - reconnectC chan struct{} // used for signaling, that a reconnection attempt should be interrupted -} - -// Creates a new simple websocket client (the channel is not secured). -// -// Additional options may be added using the AddOption function. -// -// Basic authentication can be set using the SetBasicAuth function. -// -// By default, the client will not neogtiate any subprotocol. This value needs to be set via the -// respective SetRequestedSubProtocol method. -func NewClient() *Client { - return &Client{ - dialOptions: []func(*websocket.Dialer){}, - timeoutConfig: NewClientTimeoutConfig(), - header: http.Header{}, - } -} - -// NewTLSClient creates a new secure websocket client. If supported by the server, the websocket channel will use TLS. -// -// Additional options may be added using the AddOption function. -// Basic authentication can be set using the SetBasicAuth function. -// -// To set a client certificate, you may do: -// -// certificate, _ := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) -// clientCertificates := []tls.Certificate{certificate} -// client := ws.NewTLSClient(&tls.Config{ -// RootCAs: certPool, -// Certificates: clientCertificates, -// }) -// -// You can set any other TLS option within the same constructor as well. -// For example, if you wish to test connecting to a server having a -// self-signed certificate (do not use in production!), pass: -// -// InsecureSkipVerify: true -func NewTLSClient(tlsConfig *tls.Config) *Client { - client := &Client{dialOptions: []func(*websocket.Dialer){}, timeoutConfig: NewClientTimeoutConfig(), header: http.Header{}} - client.dialOptions = append(client.dialOptions, func(dialer *websocket.Dialer) { - dialer.TLSClientConfig = tlsConfig - }) - return client -} - -func (client *Client) SetMessageHandler(handler func(data []byte) error) { - client.messageHandler = handler -} - -func (client *Client) SetTimeoutConfig(config ClientTimeoutConfig) { - client.timeoutConfig = config -} - -func (client *Client) SetDisconnectedHandler(handler func(err error)) { - client.onDisconnected = handler -} - -func (client *Client) SetReconnectedHandler(handler func()) { - client.onReconnected = handler -} - -func (client *Client) AddOption(option interface{}) { - dialOption, ok := option.(func(*websocket.Dialer)) - if ok { - client.dialOptions = append(client.dialOptions, dialOption) - } -} - -func (client *Client) SetRequestedSubProtocol(subProto string) { - opt := func(dialer *websocket.Dialer) { - alreadyExists := false - for _, proto := range dialer.Subprotocols { - if proto == subProto { - alreadyExists = true - break + log.Debugf("pong sent for %s: %s", w.id, string(ping)) + case msg, ok := <-w.outQueue: + // New data needs to be written out (also invoked for pong messages) + if !ok { + // Unexpected closed queue, should never happen. + // Don't invoke any cleanup but just exit routine. + w.onError(w, fmt.Errorf("output queue for socket %v was closed, ignoring and existing", w.id)) + return } - } - if !alreadyExists { - dialer.Subprotocols = append(dialer.Subprotocols, subProto) - } - } - client.AddOption(opt) -} - -func (client *Client) SetBasicAuth(username string, password string) { - client.header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(username+":"+password))) -} - -func (client *Client) SetHeaderValue(key string, value string) { - client.header.Set(key, value) -} - -func (client *Client) getReadTimeout() time.Time { - if client.timeoutConfig.PongWait == 0 { - return time.Time{} - } - return time.Now().Add(client.timeoutConfig.PongWait) -} - -func (client *Client) writePump() { - ticker := time.NewTicker(client.timeoutConfig.PingPeriod) - conn := client.webSocket.connection - // Closure function correctly closes the current connection - closure := func(err error) { - ticker.Stop() - client.cleanup() - // Invoke callback - if client.onDisconnected != nil { - client.onDisconnected(err) - } - } - - for { - select { - case data := <-client.webSocket.outQueue: // Send data - log.Debugf("sending data") - _ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait)) - err := conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Now().Add(w.cfg.WriteWait)) + err := conn.WriteMessage(msg.typ, msg.data) if err != nil { - client.error(fmt.Errorf("write failed: %w", err)) - closure(err) - client.handleReconnection() - return - } - log.Debugf("written %d bytes", len(data)) - case <-ticker.C: - // Send periodic ping - _ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait)) - if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { - client.error(fmt.Errorf("failed to send ping message: %w", err)) + w.onError(w, fmt.Errorf("write failed for %s: %w", w.id, err)) + // Invoking cleanup, as socket was forcefully closed closure(err) - client.handleReconnection() return } - log.Debugf("ping sent") - case closeErr := <-client.webSocket.closeC: - log.Debugf("closing connection") - // Closing connection gracefully - if err := conn.WriteControl( + log.Debugf("written %d bytes to %s", len(msg.data), w.id) + case closeErr := <-w.closeC: + // webSocket is being gracefully closed by user command + w.log.Debugf("closing connection for %s: %d - %s", w.id, closeErr.Code, closeErr.Text) + // Send explicit close message + err := conn.WriteControl( websocket.CloseMessage, websocket.FormatCloseMessage(closeErr.Code, closeErr.Text), - time.Now().Add(client.timeoutConfig.WriteWait), - ); err != nil { - client.error(fmt.Errorf("failed to write close message: %w", err)) - } - // Disconnected by user command. Not calling auto-reconnect. - // Passing nil will also not call onDisconnected. - closure(nil) - return - case closed, ok := <-client.webSocket.forceCloseC: - log.Debugf("handling forced close signal") - // Read pump sent a forceClose signal (reading failed -> aborting the connection) - if !ok || closed != nil { - closure(closed) - client.handleReconnection() - return - } - } - } -} - -func (client *Client) readPump() { - conn := client.webSocket.connection - _ = conn.SetReadDeadline(client.getReadTimeout()) - conn.SetPongHandler(func(string) error { - log.Debugf("pong received") - return conn.SetReadDeadline(client.getReadTimeout()) - }) - for { - _, message, err := conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { - client.error(fmt.Errorf("read failed: %w", err)) - } - // Notify writePump of error. Forced close will be handled there - client.webSocket.forceCloseC <- err - return - } - - log.Debugf("received %v bytes", len(message)) - if client.messageHandler != nil { - err = client.messageHandler(message) + time.Now().Add(w.cfg.WriteWait)) if err != nil { - client.error(fmt.Errorf("handle failed: %w", err)) - continue + // At this point the connection is considered to be forcefully closed, + // but we still continue with the intended flow. + w.onError(w, fmt.Errorf("failed to write close message for connection %s: %w", w.id, err)) } - } - } -} - -// Frees internal resources after a websocket connection was signaled to be closed. -// From this moment onwards, no new messages may be sent. -func (client *Client) cleanup() { - client.setConnected(false) - ws := client.webSocket - _ = ws.connection.Close() - client.mutex.Lock() - defer client.mutex.Unlock() - close(ws.outQueue) - close(ws.closeC) -} - -func (client *Client) handleReconnection() { - log.Info("started automatic reconnection handler") - delay := client.timeoutConfig.RetryBackOffWaitMinimum + time.Duration(rand.Intn(client.timeoutConfig.RetryBackOffRandomRange+1))*time.Second - reconnectionAttempts := 1 - for { - // Wait before reconnecting - select { - case <-time.After(delay): - case <-client.reconnectC: + // Invoking cleanup, but signal that this is an intended operation, + // preventing automatic reconnection attempts. + closure(nil) return - } - - log.Info("reconnecting... attempt", reconnectionAttempts) - err := client.Start(client.url.String()) - if err == nil { - // Re-connection was successful - log.Info("reconnected successfully to server") - if client.onReconnected != nil { - client.onReconnected() + case closed, _ := <-w.forceCloseC: + if closed == nil { + closed = fmt.Errorf("websocket read channel closed abruptly") } + // webSocket is being forcefully closed, triggered by readPump encountering a failed read. + log.Debugf("handling forced close signal for %s, caused by: %v", w.id, closed.Error()) + // Connection was forcefully closed, invoke cleanup + closure(closed) return } - client.error(fmt.Errorf("reconnection failed: %w", err)) - - if reconnectionAttempts < client.timeoutConfig.RetryBackOffRepeatTimes { - // Re-connection failed, double the delay - delay *= 2 - delay += time.Duration(rand.Intn(client.timeoutConfig.RetryBackOffRandomRange+1)) * time.Second - } - reconnectionAttempts += 1 - } -} - -func (client *Client) setConnected(connected bool) { - client.mutex.Lock() - defer client.mutex.Unlock() - client.connected = connected -} - -func (client *Client) IsConnected() bool { - client.mutex.Lock() - defer client.mutex.Unlock() - return client.connected -} - -func (client *Client) Write(data []byte) error { - if !client.IsConnected() { - return fmt.Errorf("client is currently not connected, cannot send data") - } - log.Debugf("queuing data for server") - client.webSocket.outQueue <- data - return nil -} - -func (client *Client) StartWithRetries(urlStr string) { - err := client.Start(urlStr) - if err != nil { - log.Info("Connection error:", err) - client.handleReconnection() - } -} - -func (client *Client) Start(urlStr string) error { - url, err := url.Parse(urlStr) - client.url = *url - if err != nil { - return err - } - - dialer := websocket.Dialer{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - HandshakeTimeout: client.timeoutConfig.HandshakeTimeout, - Subprotocols: []string{}, - } - for _, option := range client.dialOptions { - option(&dialer) - } - // Connect - log.Info("connecting to server") - ws, resp, err := dialer.Dial(urlStr, client.header) - if err != nil { - if resp != nil { - httpError := HttpConnectionError{Message: err.Error(), HttpStatus: resp.Status, HttpCode: resp.StatusCode} - // Parse http response details - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - if body != nil { - httpError.Details = string(body) - } - err = httpError - } - return err } - - // The id of the charge point is the final path element - id := path.Base(url.Path) - - client.webSocket = WebSocket{ - connection: ws, - id: id, - outQueue: make(chan []byte, 1), - closeC: make(chan websocket.CloseError, 1), - forceCloseC: make(chan error, 1), - tlsConnectionState: resp.TLS, - } - log.Infof("connected to server as %s", id) - client.reconnectC = make(chan struct{}) - client.setConnected(true) - // Start reader and write routine - go client.writePump() - go client.readPump() - return nil -} - -func (client *Client) Stop() { - log.Infof("closing connection to server") - client.mutex.Lock() - if client.connected { - client.connected = false - // Send signal for gracefully shutting down the connection - select { - case client.webSocket.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}: - default: - } - } - client.mutex.Unlock() - // Notify reconnection goroutine to stop (if any) - if client.reconnectC != nil { - close(client.reconnectC) - } - if client.errC != nil { - close(client.errC) - client.errC = nil - } - // Wait for connection to actually close } -func (client *Client) error(err error) { - log.Error(err) - if client.errC != nil { - client.errC <- err - } +// HttpConnectionError is a websocket-specific error propagated to the upper +// layers when opening a websocket fails. +type HttpConnectionError struct { + Message string + HttpStatus string + HttpCode int + Details string } -func (client *Client) Errors() <-chan error { - if client.errC == nil { - client.errC = make(chan error, 1) - } - return client.errC +func (e HttpConnectionError) Error() string { + return fmt.Sprintf("%v, http status: %v", e.Message, e.HttpStatus) } func init() { diff --git a/ws/websocket_test.go b/ws/websocket_test.go index 399f7259..f8c8bcec 100644 --- a/ws/websocket_test.go +++ b/ws/websocket_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "errors" "fmt" "math/big" "net" @@ -17,11 +18,14 @@ import ( "os" "path" "strings" + "sync" "testing" "time" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) @@ -34,319 +38,447 @@ const ( defaultSubProtocol = "ocpp1.6" ) -func newWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) *Server { +func newWebsocketServer(t *testing.T, onMessage func(data []byte) ([]byte, error)) *server { wsServer := NewServer() - wsServer.SetMessageHandler(func(ws Channel, data []byte) error { + innerS, ok := wsServer.(*server) + require.True(t, ok) + innerS.SetMessageHandler(func(ws Channel, data []byte) error { assert.NotNil(t, ws) assert.NotNil(t, data) if onMessage != nil { response, err := onMessage(data) assert.Nil(t, err) if response != nil { - err = wsServer.Write(ws.ID(), data) + err = innerS.Write(ws.ID(), data) assert.Nil(t, err) } } return nil }) - return wsServer + return innerS } -func newWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) *Client { +func newWebsocketClient(t *testing.T, onMessage func(data []byte) ([]byte, error)) *client { wsClient := NewClient() - wsClient.SetRequestedSubProtocol(defaultSubProtocol) - wsClient.SetMessageHandler(func(data []byte) error { + innerC, ok := wsClient.(*client) + require.True(t, ok) + innerC.SetRequestedSubProtocol(defaultSubProtocol) + innerC.SetMessageHandler(func(data []byte) error { assert.NotNil(t, data) if onMessage != nil { response, err := onMessage(data) assert.Nil(t, err) if response != nil { - err = wsClient.Write(data) + err = innerC.Write(data) assert.Nil(t, err) } } return nil }) - return wsClient + return innerC } -func TestWebsocketSetConnected(t *testing.T) { - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - return nil, nil - }) - assert.False(t, wsClient.IsConnected()) - wsClient.setConnected(true) - assert.True(t, wsClient.IsConnected()) - wsClient.setConnected(false) - assert.False(t, wsClient.IsConnected()) +type WebSocketSuite struct { + suite.Suite + client *client + server *server } -func TestWebsocketGetReadTimeout(t *testing.T) { - wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { - return nil, nil +func (s *WebSocketSuite) SetupTest() { + s.server = newWebsocketServer(s.T(), nil) + s.client = newWebsocketClient(s.T(), nil) +} + +func (s *WebSocketSuite) TearDownTest() { + if s.client != nil { + s.client.Stop() + } + if s.server != nil { + s.server.Stop() + } +} + +func (s *WebSocketSuite) TestPingTicker() { + defaultPeriod := 1 * time.Millisecond + testTable := []struct { + name string + active bool + pingPeriod time.Duration + expectedTickerNil bool + }{ + { + "real ticker", + true, + defaultPeriod, + false, + }, + { + "dummy ticker", + false, + defaultPeriod, + true, + }, + { + "dummy ticker due to invalid period", + true, + 0, + true, + }, + } + for _, tc := range testTable { + var pc *PingConfig + if tc.active { + pc = &PingConfig{ + PingPeriod: tc.pingPeriod, + } + } + t := newOptTicker(pc) + if tc.expectedTickerNil { + s.Nil(t.ticker, tc.name) + s.NotNil(t.c, tc.name) + } else { + s.NotNil(t.ticker, tc.name) + s.Nil(t.c, tc.name) + } + // Test retrieving channel + c := t.T() + s.NotNil(c) + // Test waiting for tick + select { + case <-c: + if tc.expectedTickerNil { + s.Fail("unexpected tick from nil ticker", tc.name) + } + case <-time.After(2 * defaultPeriod): + if !tc.expectedTickerNil { + s.Fail("unexpected timeout from real ticker", tc.name) + } + } + // Test waiting for tick after stop + t.Stop() + select { + case <-c: + s.Fail("unexpected tick from stopped ticker", tc.name) + case <-time.After(2 * defaultPeriod): + break + } + } +} + +func (s *WebSocketSuite) TestWebsocketConnectionState() { + s.False(s.client.IsConnected()) + closeC := make(chan struct{}, 1) + s.client.SetMessageHandler(func(data []byte) error { + s.Fail("unexpected message") + return nil }) - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - return nil, nil + s.client.SetDisconnectedHandler(func(err error) { + closeC <- struct{}{} }) - // Test server timeout for default settings + // Simulate connection + go s.server.Start(serverPort, serverPath) + time.Sleep(50 * time.Millisecond) + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "ws", Host: host, Path: testPath} + err := s.client.Start(u.String()) + s.NoError(err) + // Check connection state on internal web socket + ws := s.client.webSocket + s.NotNil(ws) + s.True(ws.IsConnected()) + // Close connection + err = ws.Close(websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}) + s.NoError(err) + select { + case <-closeC: + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for connection to close") + } + s.False(ws.IsConnected()) +} + +func (s *WebSocketSuite) TestWebsocketGetReadTimeout() { + // Create default timeout settings and handlers serverTimeoutConfig := NewServerTimeoutConfig() - wsServer.SetTimeoutConfig(serverTimeoutConfig) + s.server.SetTimeoutConfig(serverTimeoutConfig) + ctrlC := make(chan struct{}, 1) + s.server.SetNewClientHandler(func(ws Channel) { + ctrlC <- struct{}{} + }) + // Simulate connection to initialize a websocket + go s.server.Start(serverPort, serverPath) + time.Sleep(50 * time.Millisecond) + host := fmt.Sprintf("localhost:%v", serverPort) + u := url.URL{Scheme: "ws", Host: host, Path: testPath} + err := s.client.Start(u.String()) + s.NoError(err) + // Wait for connection to be established + select { + case <-ctrlC: + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for connection to establish") + } + // Test server timeout for default settings + serverW, ok := s.server.connections["testws"] + s.True(ok) now := time.Now() - timeout := wsServer.getReadTimeout() - assert.GreaterOrEqual(t, timeout.Unix(), now.Add(serverTimeoutConfig.PingWait).Unix()) + timeout := serverW.getReadTimeout() + s.GreaterOrEqual(timeout.Unix(), now.Add(s.server.timeoutConfig.PingWait).Unix()) // Test server timeout for zero setting - serverTimeoutConfig.PingWait = 0 - wsServer.SetTimeoutConfig(serverTimeoutConfig) - timeout = wsServer.getReadTimeout() - assert.Equal(t, time.Time{}, timeout) + cfg := serverW.cfg + cfg.ReadWait = 0 + serverW.updateConfig(cfg) + timeout = serverW.getReadTimeout() + s.Equal(time.Time{}, timeout) // Test client timeout for default settings - clientTimeoutConfig := NewClientTimeoutConfig() - wsClient.SetTimeoutConfig(clientTimeoutConfig) + clientW := s.client.webSocket + s.NotNil(clientW) now = time.Now() - timeout = wsClient.getReadTimeout() - assert.GreaterOrEqual(t, timeout.Unix(), now.Add(clientTimeoutConfig.PongWait).Unix()) + timeout = clientW.getReadTimeout() + s.GreaterOrEqual(timeout.Unix(), now.Add(s.client.timeoutConfig.PongWait).Unix()) // Test client timeout for zero setting - clientTimeoutConfig.PongWait = 0 - wsClient.SetTimeoutConfig(clientTimeoutConfig) - timeout = wsClient.getReadTimeout() - assert.Equal(t, time.Time{}, timeout) + cfg = clientW.cfg + cfg.PingConfig.PongWait = 0 + cfg.ReadWait = 0 + clientW.updateConfig(cfg) + timeout = clientW.getReadTimeout() + s.Equal(time.Time{}, timeout) } -func TestWebsocketEcho(t *testing.T) { - message := []byte("Hello WebSocket!") - triggerC := make(chan bool, 1) - done := make(chan bool, 1) - wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { - assert.True(t, bytes.Equal(message, data)) - // Message received, notifying flow routine - triggerC <- true +func (s *WebSocketSuite) TestWebsocketEcho() { + msg := []byte("Hello webSocket!") + triggerC := make(chan struct{}, 1) + done := make(chan struct{}, 1) + s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) { + s.True(bytes.Equal(msg, data)) + // Echo reply received, notifying flow routine + triggerC <- struct{}{} return data, nil }) - wsServer.SetNewClientHandler(func(ws Channel) { + s.server.SetNewClientHandler(func(ws Channel) { tlsState := ws.TLSConnectionState() - assert.Nil(t, tlsState) - }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { - // Connection closed, completing test - done <- true + s.Nil(tlsState) }) - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - assert.True(t, bytes.Equal(message, data)) + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { + s.True(bytes.Equal(msg, data)) // Echo response received, notifying flow routine - triggerC <- true + done <- struct{}{} return nil, nil }) // Start server - go wsServer.Start(serverPort, serverPath) + go s.server.Start(serverPort, serverPath) // Start flow routine go func() { - // Wait for messages to be exchanged, then close connection + // Wait for messages to be exchanged in a dedicate routine. + // Will reply to client. sig := <-triggerC - assert.True(t, sig) - err := wsServer.Write(path.Base(testPath), message) - require.Nil(t, err) + s.NotNil(sig) + err := s.server.Write(path.Base(testPath), msg) + s.NoError(err) sig = <-triggerC - assert.True(t, sig) - wsClient.Stop() + s.NotNil(sig) }() - time.Sleep(200 * time.Millisecond) - - // Test message + time.Sleep(100 * time.Millisecond) + // Test connection host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err := wsClient.Start(u.String()) - require.NoError(t, err) - require.True(t, wsClient.IsConnected()) - err = wsClient.Write(message) - require.NoError(t, err) + err := s.client.Start(u.String()) + s.NoError(err) + s.True(s.client.IsConnected()) + // Test message + err = s.client.Write(msg) + s.NoError(err) // Wait for echo result - result := <-done - assert.True(t, result) - // Cleanup - wsServer.Stop() + select { + case result := <-done: + s.NotNil(result) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for echo result") + } } -func TestWebsocketBootRetries(t *testing.T) { - verifyConnection := func(client *Client, connected bool) { +func (s *WebSocketSuite) TestWebsocketBootRetries() { + verifyConnection := func(client *client, connected bool) { maxAttempts := 20 for i := 0; i <= maxAttempts; i++ { if client.IsConnected() != connected { - time.Sleep(time.Duration(2) * time.Second) + time.Sleep(200 * time.Millisecond) continue } } - assert.Equal(t, connected, client.IsConnected()) + s.Equal(connected, client.IsConnected()) } - wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { + s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) { return data, nil }) - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { return nil, nil }) + // Reduce timeout to make test faster + s.client.timeoutConfig.RetryBackOffWaitMinimum = 1 * time.Second + s.client.timeoutConfig.RetryBackOffRandomRange = 2 go func() { // Start websocket client host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - wsClient.StartWithRetries(u.String()) + s.client.StartWithRetries(u.String()) }() + // Initial connection attempt fails, as server isn't listening yet + s.False(s.client.IsConnected()) - assert.Equal(t, wsClient.IsConnected(), false) - - time.Sleep(time.Duration(3) * time.Second) - - go wsServer.Start(serverPort, serverPath) - verifyConnection(wsClient, true) + time.Sleep(500 * time.Millisecond) - wsServer.Stop() - verifyConnection(wsClient, false) + go s.server.Start(serverPort, serverPath) + verifyConnection(s.client, true) - wsServer.Stop() - wsClient.Stop() + s.server.Stop() + verifyConnection(s.client, false) } -func TestTLSWebsocketEcho(t *testing.T) { - message := []byte("Hello Secure WebSocket!") - triggerC := make(chan bool, 1) - done := make(chan bool, 1) - // Use NewTLSServer() when in different package - wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { - assert.True(t, bytes.Equal(message, data)) +func (s *WebSocketSuite) TestTLSWebsocketEcho() { + msg := []byte("Hello Secure webSocket!") + triggerC := make(chan struct{}, 1) + done := make(chan struct{}, 1) + // Use NewServer(WithServerTLSConfig(...)) when in different package + s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) { + s.True(bytes.Equal(msg, data)) // Message received, notifying flow routine - triggerC <- true + triggerC <- struct{}{} return data, nil }) - wsServer.SetNewClientHandler(func(ws Channel) { + s.server.SetNewClientHandler(func(ws Channel) { tlsState := ws.TLSConnectionState() - assert.NotNil(t, tlsState) + s.NotNil(tlsState) }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { + s.server.SetDisconnectedClientHandler(func(ws Channel) { // Connection closed, completing test - done <- true + done <- struct{}{} }) // Create self-signed TLS certificate + // TODO: use FiloSottile's lib for this certFilename := "/tmp/cert.pem" keyFilename := "/tmp/key.pem" err := createTLSCertificate(certFilename, keyFilename, "localhost", nil, nil) - require.Nil(t, err) + s.NoError(err) defer os.Remove(certFilename) defer os.Remove(keyFilename) // Set self-signed TLS certificate - wsServer.tlsCertificatePath = certFilename - wsServer.tlsCertificateKey = keyFilename + s.server.tlsCertificatePath = certFilename + s.server.tlsCertificateKey = keyFilename // Create TLS client - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - assert.True(t, bytes.Equal(message, data)) + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { + s.True(bytes.Equal(msg, data)) // Echo response received, notifying flow routine - triggerC <- true + done <- struct{}{} return nil, nil }) - wsClient.AddOption(func(dialer *websocket.Dialer) { + s.client.AddOption(func(dialer *websocket.Dialer) { certPool := x509.NewCertPool() data, err := os.ReadFile(certFilename) - assert.Nil(t, err) + s.NoError(err) ok := certPool.AppendCertsFromPEM(data) - assert.True(t, ok) + s.True(ok) dialer.TLSClientConfig = &tls.Config{ RootCAs: certPool, } }) // Start server - go wsServer.Start(serverPort, serverPath) + go s.server.Start(serverPort, serverPath) // Start flow routine go func() { // Wait for messages to be exchanged, then close connection sig := <-triggerC - assert.True(t, sig) - err := wsServer.Write(path.Base(testPath), message) - require.NoError(t, err) + s.NotNil(sig) + err = s.server.Write(path.Base(testPath), msg) + s.NoError(err) sig = <-triggerC - assert.True(t, sig) - wsClient.Stop() + s.NotNil(sig) }() - time.Sleep(200 * time.Millisecond) + time.Sleep(100 * time.Millisecond) - // Test message + // Test connection host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "wss", Host: host, Path: testPath} - err = wsClient.Start(u.String()) - require.NoError(t, err) - require.True(t, wsClient.IsConnected()) - err = wsClient.Write(message) - require.NoError(t, err) + err = s.client.Start(u.String()) + s.NoError(err) + s.True(s.client.IsConnected()) + // Test message + err = s.client.Write(msg) + s.NoError(err) // Wait for echo result - result := <-done - assert.True(t, result) - // Cleanup - wsServer.Stop() + select { + case result := <-done: + s.NotNil(result) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for echo result") + } } -func TestServerStartErrors(t *testing.T) { - triggerC := make(chan bool, 1) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { - triggerC <- true +func (s *WebSocketSuite) TestServerStartErrors() { + triggerC := make(chan struct{}, 1) + s.server = newWebsocketServer(s.T(), nil) + s.server.SetNewClientHandler(func(ws Channel) { + triggerC <- struct{}{} }) // Make sure http server is initialized on start - wsServer.httpServer = nil + s.server.httpServer = nil // Listen for errors go func() { - err, ok := <-wsServer.Errors() - assert.True(t, ok) - assert.Error(t, err) - triggerC <- true + err, ok := <-s.server.Errors() + s.True(ok) + s.Error(err) + triggerC <- struct{}{} }() time.Sleep(100 * time.Millisecond) - go wsServer.Start(serverPort, serverPath) + go s.server.Start(serverPort, serverPath) time.Sleep(100 * time.Millisecond) // Starting server again throws error - wsServer.Start(serverPort, serverPath) + s.server.Start(serverPort, serverPath) r := <-triggerC - require.True(t, r) - wsServer.Stop() + s.NotNil(r) } -func TestClientDuplicateConnection(t *testing.T) { - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { +func (s *WebSocketSuite) TestClientDuplicateConnection() { + s.server = newWebsocketServer(s.T(), nil) + s.server.SetNewClientHandler(func(ws Channel) { }) // Start server - go wsServer.Start(serverPort, serverPath) + go s.server.Start(serverPort, serverPath) time.Sleep(100 * time.Millisecond) // Connect client 1 - wsClient1 := newWebsocketClient(t, func(data []byte) ([]byte, error) { + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { return nil, nil }) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err := wsClient1.Start(u.String()) - require.NoError(t, err) + err := s.client.Start(u.String()) + s.NoError(err) // Try to connect client 2 disconnectC := make(chan struct{}) - wsClient2 := newWebsocketClient(t, func(data []byte) ([]byte, error) { + wsClient2 := newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { return nil, nil }) wsClient2.SetDisconnectedHandler(func(err error) { - require.IsType(t, &websocket.CloseError{}, err) - wsErr, _ := err.(*websocket.CloseError) - assert.Equal(t, websocket.ClosePolicyViolation, wsErr.Code) - assert.Equal(t, "a connection with this ID already exists", wsErr.Text) + s.IsType(&websocket.CloseError{}, err) + var wsErr *websocket.CloseError + ok := errors.As(err, &wsErr) + s.True(ok) + s.Equal(websocket.ClosePolicyViolation, wsErr.Code) + s.Equal("a connection with this ID already exists", wsErr.Text) wsClient2.SetDisconnectedHandler(nil) disconnectC <- struct{}{} }) err = wsClient2.Start(u.String()) - require.NoError(t, err) + s.NoError(err) // Expect connection to be closed immediately _, ok := <-disconnectC - assert.True(t, ok) - // Cleanup - wsClient1.Stop() - wsServer.Stop() + s.True(ok) } -func TestServerStopConnection(t *testing.T) { +func (s *WebSocketSuite) TestServerStopConnection() { triggerC := make(chan struct{}, 1) disconnectedClientC := make(chan struct{}, 1) disconnectedServerC := make(chan struct{}, 1) @@ -354,376 +486,404 @@ func TestServerStopConnection(t *testing.T) { Code: websocket.CloseGoingAway, Text: "CloseClientConnection", } - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { + wsID := "testws" + s.server = newWebsocketServer(s.T(), nil) + s.server.SetNewClientHandler(func(ws Channel) { triggerC <- struct{}{} }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { + s.server.SetDisconnectedClientHandler(func(ws Channel) { disconnectedServerC <- struct{}{} }) - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { return nil, nil }) - wsClient.SetDisconnectedHandler(func(err error) { - require.IsType(t, &closeError, err) - closeErr, _ := err.(*websocket.CloseError) - assert.Equal(t, closeError.Code, closeErr.Code) - assert.Equal(t, closeError.Text, closeErr.Text) + s.client.SetDisconnectedHandler(func(err error) { + s.IsType(&closeError, err) + var closeErr *websocket.CloseError + ok := errors.As(err, &closeErr) + s.True(ok) + s.Equal(closeError.Code, closeErr.Code) + s.Equal(closeError.Text, closeErr.Text) disconnectedClientC <- struct{}{} }) // Start server - go wsServer.Start(serverPort, serverPath) + go s.server.Start(serverPort, serverPath) time.Sleep(100 * time.Millisecond) + var c Channel + var ok bool + c, ok = s.server.GetChannel(wsID) + s.False(ok) + s.Nil(c) // Connect client host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err := wsClient.Start(u.String()) - require.NoError(t, err) + err := s.client.Start(u.String()) + s.NoError(err) // Wait for client to connect - _, ok := <-triggerC - require.True(t, ok) + _, ok = <-triggerC + s.True(ok) + // Verify channel + c, ok = s.server.GetChannel(wsID) + s.True(ok) + s.NotNil(c) + s.Equal(wsID, c.ID()) + s.True(c.IsConnected()) // Close connection and wait for client to be closed - err = wsServer.StopConnection(path.Base(testPath), closeError) - require.NoError(t, err) + err = s.server.StopConnection(path.Base(testPath), closeError) + s.NoError(err) _, ok = <-disconnectedClientC - require.True(t, ok) + s.True(ok) _, ok = <-disconnectedServerC - require.True(t, ok) - assert.False(t, wsClient.IsConnected()) + s.True(ok) + s.False(s.client.IsConnected()) time.Sleep(100 * time.Millisecond) - assert.Empty(t, wsServer.connections) - // Client will attempt to reconnect under the hood, but test finishes before this can happen - // Cleanup - wsClient.Stop() - wsServer.Stop() + s.Empty(s.server.connections) + // client will attempt to reconnect under the hood, but test finishes before this can happen } -func TestWebsocketServerStopAllConnections(t *testing.T) { +func (s *WebSocketSuite) TestWebsocketServerStopAllConnections() { triggerC := make(chan struct{}, 1) numClients := 5 - disconnectedClientC := make(chan struct{}, numClients) disconnectedServerC := make(chan struct{}, 1) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { + s.server = newWebsocketServer(s.T(), nil) + s.server.SetNewClientHandler(func(ws Channel) { triggerC <- struct{}{} }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { + s.server.SetDisconnectedClientHandler(func(ws Channel) { disconnectedServerC <- struct{}{} }) // Start server - go wsServer.Start(serverPort, serverPath) + go s.server.Start(serverPort, serverPath) time.Sleep(100 * time.Millisecond) // Connect clients - clients := []WsClient{} + clients := []Client{} + wg := sync.WaitGroup{} host := fmt.Sprintf("localhost:%v", serverPort) for i := 0; i < numClients; i++ { - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { + wsClient := newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { return nil, nil }) wsClient.SetDisconnectedHandler(func(err error) { - require.IsType(t, &websocket.CloseError{}, err) - closeErr, _ := err.(*websocket.CloseError) - assert.Equal(t, websocket.CloseNormalClosure, closeErr.Code) - assert.Equal(t, "", closeErr.Text) - disconnectedClientC <- struct{}{} + s.IsType(&websocket.CloseError{}, err) + var closeErr *websocket.CloseError + ok := errors.As(err, &closeErr) + s.True(ok) + s.Equal(websocket.CloseNormalClosure, closeErr.Code) + s.Equal("", closeErr.Text) + wg.Done() }) u := url.URL{Scheme: "ws", Host: host, Path: fmt.Sprintf("%v-%v", testPath, i)} err := wsClient.Start(u.String()) - require.NoError(t, err) + s.NoError(err) clients = append(clients, wsClient) // Wait for client to connect _, ok := <-triggerC - require.True(t, ok) + s.True(ok) + wg.Add(1) } // Stop server and wait for clients to disconnect - wsServer.Stop() - for disconnects := 0; disconnects < numClients; disconnects++ { - _, ok := <-disconnectedClientC - require.True(t, ok) - _, ok = <-disconnectedServerC - require.True(t, ok) + s.server.Stop() + waitC := make(chan struct{}, 1) + go func() { + wg.Wait() + waitC <- struct{}{} + }() + select { + case <-waitC: + break + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for clients to disconnect") } - // Check disconnection status + // Double-check disconnection status for _, c := range clients { - assert.False(t, c.IsConnected()) - // Client will attempt to reconnect under the hood, but test finishes before this can happen + s.False(c.IsConnected()) + // client will attempt to reconnect under the hood, but test finishes before this can happen c.Stop() } time.Sleep(100 * time.Millisecond) - assert.Empty(t, wsServer.connections) + s.Empty(s.server.connections) } -func TestWebsocketClientConnectionBreak(t *testing.T) { - newClient := make(chan bool) - disconnected := make(chan bool) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { - newClient <- true +func (s *WebSocketSuite) TestWebsocketClientConnectionBreak() { + newClient := make(chan struct{}) + disconnected := make(chan struct{}) + s.server = newWebsocketServer(s.T(), nil) + s.server.SetNewClientHandler(func(ws Channel) { + newClient <- struct{}{} }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { - disconnected <- true + s.server.SetDisconnectedClientHandler(func(ws Channel) { + disconnected <- struct{}{} }) - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Test - wsClient := newWebsocketClient(t, nil) + s.client = newWebsocketClient(s.T(), nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - // Wait for connection to be established, then break the connection + // Wait for connection to be established, then break the connection asynchronously go func() { - timer := time.NewTimer(1 * time.Second) - <-timer.C - err := wsClient.webSocket.connection.Close() - assert.Nil(t, err) + <-time.After(200 * time.Millisecond) + err := s.client.webSocket.connection.Close() + s.NoError(err) }() - err := wsClient.Start(u.String()) - assert.Nil(t, err) + // Connect and wait + err := s.client.Start(u.String()) + s.NoError(err) result := <-newClient - assert.True(t, result) - result = <-disconnected - assert.True(t, result) - // Cleanup - wsServer.Stop() + s.NotNil(result) + // Wait for internal disconnect + select { + case result = <-disconnected: + s.NotNil(result) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for client disconnect") + } } -func TestWebsocketServerConnectionBreak(t *testing.T) { - disconnected := make(chan bool) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { - assert.NotNil(t, ws) - conn := wsServer.connections[ws.ID()] - assert.NotNil(t, conn) +func (s *WebSocketSuite) TestWebsocketServerConnectionBreak() { + disconnected := make(chan struct{}, 1) + s.server = newWebsocketServer(s.T(), nil) + s.server.SetNewClientHandler(func(ws Channel) { + s.NotNil(ws) + conn := s.server.connections[ws.ID()] + s.NotNil(conn) // Simulate connection closed as soon client is connected err := conn.connection.Close() - assert.Nil(t, err) + s.NoError(err) }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { - disconnected <- true + s.server.SetDisconnectedClientHandler(func(ws Channel) { + disconnected <- struct{}{} }) - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Test - wsClient := newWebsocketClient(t, nil) + s.client = newWebsocketClient(s.T(), nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err := wsClient.Start(u.String()) - assert.Nil(t, err) - result := <-disconnected - assert.True(t, result) - // Cleanup - wsServer.Stop() + err := s.client.Start(u.String()) + s.NoError(err) + + select { + case result := <-disconnected: + s.NotNil(result) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for server disconnect") + } } -func TestValidBasicAuth(t *testing.T) { +func (s *WebSocketSuite) TestValidBasicAuth() { + var ok bool authUsername := "testUsername" authPassword := "testPassword" // Create self-signed TLS certificate + // TODO: replace with FiloSottile's lib certFilename := "/tmp/cert.pem" keyFilename := "/tmp/key.pem" err := createTLSCertificate(certFilename, keyFilename, "localhost", nil, nil) - require.Nil(t, err) + s.NoError(err) defer os.Remove(certFilename) defer os.Remove(keyFilename) // Create TLS server with self-signed certificate - wsServer := NewTLSServer(certFilename, keyFilename, nil) + tlsServer := NewServer(WithServerTLSConfig(certFilename, keyFilename, nil)) + s.server, ok = tlsServer.(*server) + s.True(ok) // Add basic auth handler - wsServer.SetBasicAuthHandler(func(username string, password string) bool { - require.Equal(t, authUsername, username) - require.Equal(t, authPassword, password) + s.server.SetBasicAuthHandler(func(username string, password string) bool { + s.Equal(authUsername, username) + s.Equal(authPassword, password) return true }) - connected := make(chan bool) - wsServer.SetNewClientHandler(func(ws Channel) { - connected <- true + connected := make(chan struct{}, 1) + s.server.SetNewClientHandler(func(ws Channel) { + connected <- struct{}{} }) // Run server - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Create TLS client certPool := x509.NewCertPool() data, err := os.ReadFile(certFilename) - require.Nil(t, err) - ok := certPool.AppendCertsFromPEM(data) - require.True(t, ok) - wsClient := NewTLSClient(&tls.Config{ + s.NoError(err) + ok = certPool.AppendCertsFromPEM(data) + s.True(ok) + tlsClient := NewClient(WithClientTLSConfig(&tls.Config{ RootCAs: certPool, - }) - wsClient.SetRequestedSubProtocol(defaultSubProtocol) + })) + s.client, ok = tlsClient.(*client) + s.True(ok) + s.client.SetRequestedSubProtocol(defaultSubProtocol) // Add basic auth - wsClient.SetBasicAuth(authUsername, authPassword) + s.client.SetBasicAuth(authUsername, authPassword) // Test connection host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "wss", Host: host, Path: testPath} - err = wsClient.Start(u.String()) - require.Nil(t, err) + err = s.client.Start(u.String()) + s.NoError(err) result := <-connected - assert.True(t, result) - // Cleanup - wsClient.Stop() - wsServer.Stop() + s.NotNil(result) + s.True(s.client.IsConnected()) } -func TestInvalidBasicAuth(t *testing.T) { +func (s *WebSocketSuite) TestInvalidBasicAuth() { + var ok bool authUsername := "testUsername" authPassword := "testPassword" // Create self-signed TLS certificate + // TODO: replace with FiloSottile's lib certFilename := "/tmp/cert.pem" keyFilename := "/tmp/key.pem" err := createTLSCertificate(certFilename, keyFilename, "localhost", nil, nil) - require.Nil(t, err) + s.NoError(err) defer os.Remove(certFilename) defer os.Remove(keyFilename) // Create TLS server with self-signed certificate - wsServer := NewTLSServer(certFilename, keyFilename, nil) + tlsServer := NewServer(WithServerTLSConfig(certFilename, keyFilename, nil)) + s.server, ok = tlsServer.(*server) + s.True(ok) // Add basic auth handler - wsServer.SetBasicAuthHandler(func(username string, password string) bool { + s.server.SetBasicAuthHandler(func(username string, password string) bool { validCredentials := authUsername == username && authPassword == password - require.False(t, validCredentials) + s.False(validCredentials) return validCredentials }) - wsServer.SetNewClientHandler(func(ws Channel) { + s.server.SetNewClientHandler(func(ws Channel) { // Should never reach this - t.Fail() + s.Fail("no new connection should be received from client!") }) // Run server - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Create TLS client certPool := x509.NewCertPool() data, err := os.ReadFile(certFilename) - require.Nil(t, err) - ok := certPool.AppendCertsFromPEM(data) - require.True(t, ok) - wsClient := NewTLSClient(&tls.Config{ + s.NoError(err) + ok = certPool.AppendCertsFromPEM(data) + s.True(ok) + wsClient := NewClient(WithClientTLSConfig(&tls.Config{ RootCAs: certPool, - }) + })) // Test connection without bssic auth -> error expected host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "wss", Host: host, Path: testPath} err = wsClient.Start(u.String()) // Assert HTTP error - assert.Error(t, err) - httpErr, ok := err.(HttpConnectionError) - require.True(t, ok) - assert.Equal(t, http.StatusUnauthorized, httpErr.HttpCode) - assert.Equal(t, "401 Unauthorized", httpErr.HttpStatus) - assert.Equal(t, "websocket: bad handshake", httpErr.Message) - assert.True(t, strings.Contains(err.Error(), "http status:")) + s.Error(err) + var httpErr HttpConnectionError + ok = errors.As(err, &httpErr) + s.True(ok) + s.Equal(http.StatusUnauthorized, httpErr.HttpCode) + s.Equal("401 Unauthorized", httpErr.HttpStatus) + s.Equal("websocket: bad handshake", httpErr.Message) + s.True(strings.Contains(err.Error(), "http status:")) // Add basic auth wsClient.SetBasicAuth(authUsername, "invalidPassword") // Test connection err = wsClient.Start(u.String()) - assert.NotNil(t, err) - httpError, ok := err.(HttpConnectionError) - require.True(t, ok) - require.NotNil(t, httpError) - assert.Equal(t, http.StatusUnauthorized, httpError.HttpCode) - // Cleanup - wsServer.Stop() + s.Error(err) + var httpError HttpConnectionError + ok = errors.As(err, &httpError) + s.True(ok) + s.Equal(http.StatusUnauthorized, httpError.HttpCode) } -func TestInvalidOriginHeader(t *testing.T) { - wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { - assert.Fail(t, "no message should be received from client!") +func (s *WebSocketSuite) TestInvalidOriginHeader() { + s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) { + s.Fail("no message should be received from client!") return nil, nil }) - wsServer.SetNewClientHandler(func(ws Channel) { - assert.Fail(t, "no new connection should be received from client!") + s.server.SetNewClientHandler(func(ws Channel) { + s.Fail("no new connection should be received from client!") }) - go wsServer.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Test message - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - assert.Fail(t, "no message should be received from server!") + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { + s.Fail("no message should be received from server!") return nil, nil }) // Set invalid origin header - wsClient.SetHeaderValue("Origin", "example.org") + s.client.SetHeaderValue("Origin", "example.org") host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} // Attempt to connect and expect cross-origin error - err := wsClient.Start(u.String()) - require.Error(t, err) - httpErr, ok := err.(HttpConnectionError) - require.True(t, ok) - assert.Equal(t, http.StatusForbidden, httpErr.HttpCode) - assert.Equal(t, http.StatusForbidden, httpErr.HttpCode) - assert.Equal(t, "websocket: bad handshake", httpErr.Message) - // Cleanup - wsServer.Stop() + err := s.client.Start(u.String()) + s.Error(err) + var httpErr HttpConnectionError + ok := errors.As(err, &httpErr) + s.True(ok) + s.Equal(http.StatusForbidden, httpErr.HttpCode) + s.Equal("websocket: bad handshake", httpErr.Message) } -func TestCustomOriginHeaderHandler(t *testing.T) { +func (s *WebSocketSuite) TestCustomOriginHeaderHandler() { origin := "example.org" - connected := make(chan bool) - wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { - assert.Fail(t, "no message should be received from client!") + connected := make(chan struct{}) + s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) { + s.Fail("no message should be received from client!") return nil, nil }) - wsServer.SetNewClientHandler(func(ws Channel) { - connected <- true + s.server.SetNewClientHandler(func(ws Channel) { + connected <- struct{}{} }) - wsServer.SetCheckOriginHandler(func(r *http.Request) bool { + s.server.SetCheckOriginHandler(func(r *http.Request) bool { return r.Header.Get("Origin") == origin }) - go wsServer.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Test message - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - assert.Fail(t, "no message should be received from server!") + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { + s.Fail("no message should be received from server!") return nil, nil }) // Set invalid origin header (not example.org) - wsClient.SetHeaderValue("Origin", "localhost") + s.client.SetHeaderValue("Origin", "localhost") host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} // Attempt to connect and expect cross-origin error - err := wsClient.Start(u.String()) - require.Error(t, err) - httpErr, ok := err.(HttpConnectionError) - require.True(t, ok) - assert.Equal(t, http.StatusForbidden, httpErr.HttpCode) - assert.Equal(t, http.StatusForbidden, httpErr.HttpCode) - assert.Equal(t, "websocket: bad handshake", httpErr.Message) + err := s.client.Start(u.String()) + s.Error(err) + var httpErr HttpConnectionError + ok := errors.As(err, &httpErr) + s.True(ok) + s.Equal(http.StatusForbidden, httpErr.HttpCode) + s.Equal("websocket: bad handshake", httpErr.Message) // Re-attempt with correct header - wsClient.SetHeaderValue("Origin", "example.org") - err = wsClient.Start(u.String()) - require.NoError(t, err) + s.client.SetHeaderValue("Origin", "example.org") + err = s.client.Start(u.String()) + s.NoError(err) result := <-connected - assert.True(t, result) - // Cleanup - wsServer.Stop() + s.NotNil(result) } -func TestCustomCheckClientHandler(t *testing.T) { +func (s *WebSocketSuite) TestCustomCheckClientHandler() { invalidTestPath := "/ws/invalid-testws" id := path.Base(testPath) - connected := make(chan bool) - wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) { - assert.Fail(t, "no message should be received from client!") + connected := make(chan struct{}) + s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) { + s.Fail("no message should be received from client!") return nil, nil }) - wsServer.SetNewClientHandler(func(ws Channel) { - connected <- true + s.server.SetNewClientHandler(func(ws Channel) { + connected <- struct{}{} }) - wsServer.SetCheckClientHandler(func(clientId string, r *http.Request) bool { + s.server.SetCheckClientHandler(func(clientId string, r *http.Request) bool { return id == clientId }) - go wsServer.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Test message - wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) { - assert.Fail(t, "no message should be received from server!") + s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) { + s.Fail("no message should be received from server!") return nil, nil }) @@ -731,384 +891,400 @@ func TestCustomCheckClientHandler(t *testing.T) { // Set invalid client (not /ws/testws) u := url.URL{Scheme: "ws", Host: host, Path: invalidTestPath} // Attempt to connect and expect invalid client id error - err := wsClient.Start(u.String()) - require.Error(t, err) - httpErr, ok := err.(HttpConnectionError) - require.True(t, ok) - assert.Equal(t, http.StatusUnauthorized, httpErr.HttpCode) - assert.Equal(t, "websocket: bad handshake", httpErr.Message) + err := s.client.Start(u.String()) + s.Error(err) + var httpErr HttpConnectionError + ok := errors.As(err, &httpErr) + s.True(ok) + s.Equal(http.StatusUnauthorized, httpErr.HttpCode) + s.Equal("websocket: bad handshake", httpErr.Message) // Re-attempt with correct client id u = url.URL{Scheme: "ws", Host: host, Path: testPath} - err = wsClient.Start(u.String()) - require.NoError(t, err) - result := <-connected - assert.True(t, result) - // Cleanup - wsServer.Stop() + err = s.client.Start(u.String()) + s.NoError(err) + select { + case result := <-connected: + s.NotNil(result) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for client to connect") + } } -func TestValidClientTLSCertificate(t *testing.T) { +func (s *WebSocketSuite) TestValidClientTLSCertificate() { + var ok bool // Create self-signed TLS certificate clientCertFilename := "/tmp/client.pem" clientKeyFilename := "/tmp/client_key.pem" + // TODO: replace with FiloSottile's lib err := createTLSCertificate(clientCertFilename, clientKeyFilename, "localhost", nil, nil) + s.NoError(err) defer os.Remove(clientCertFilename) defer os.Remove(clientKeyFilename) - require.Nil(t, err) serverCertFilename := "/tmp/cert.pem" serverKeyFilename := "/tmp/key.pem" err = createTLSCertificate(serverCertFilename, serverKeyFilename, "localhost", nil, nil) - require.Nil(t, err) + s.NoError(err) defer os.Remove(serverCertFilename) defer os.Remove(serverKeyFilename) // Create TLS server with self-signed certificate certPool := x509.NewCertPool() data, err := os.ReadFile(clientCertFilename) - require.Nil(t, err) - ok := certPool.AppendCertsFromPEM(data) - require.True(t, ok) - wsServer := NewTLSServer(serverCertFilename, serverKeyFilename, &tls.Config{ + s.NoError(err) + ok = certPool.AppendCertsFromPEM(data) + s.True(ok) + tlsServer := NewServer(WithServerTLSConfig(serverCertFilename, serverKeyFilename, &tls.Config{ ClientCAs: certPool, ClientAuth: tls.RequireAndVerifyClientCert, - }) + })) + s.server, ok = tlsServer.(*server) + s.True(ok) // Add basic auth handler - connected := make(chan bool) - wsServer.SetNewClientHandler(func(ws Channel) { - connected <- true + connected := make(chan struct{}) + s.server.SetNewClientHandler(func(ws Channel) { + connected <- struct{}{} }) // Run server - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Create TLS client certPool = x509.NewCertPool() data, err = os.ReadFile(serverCertFilename) - require.Nil(t, err) + s.NoError(err) ok = certPool.AppendCertsFromPEM(data) - require.True(t, ok) + s.True(ok) loadedCert, err := tls.LoadX509KeyPair(clientCertFilename, clientKeyFilename) - require.Nil(t, err) - wsClient := NewTLSClient(&tls.Config{ + s.NoError(err) + tlsClient := NewClient(WithClientTLSConfig(&tls.Config{ RootCAs: certPool, Certificates: []tls.Certificate{loadedCert}, - }) - wsClient.SetRequestedSubProtocol(defaultSubProtocol) + })) + s.client, ok = tlsClient.(*client) + s.True(ok) + s.client.SetRequestedSubProtocol(defaultSubProtocol) // Test connection host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "wss", Host: host, Path: testPath} - err = wsClient.Start(u.String()) - assert.Nil(t, err) - result := <-connected - assert.True(t, result) - // Cleanup - wsServer.Stop() + err = s.client.Start(u.String()) + s.NoError(err) + select { + case result := <-connected: + s.NotNil(result) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for client to connect") + } } -func TestInvalidClientTLSCertificate(t *testing.T) { +func (s *WebSocketSuite) TestInvalidClientTLSCertificate() { + var ok bool // Create self-signed TLS certificate clientCertFilename := "/tmp/client.pem" clientKeyFilename := "/tmp/client_key.pem" + // TODO: replace with FiloSottile's lib err := createTLSCertificate(clientCertFilename, clientKeyFilename, "localhost", nil, nil) + s.NoError(err) defer os.Remove(clientCertFilename) defer os.Remove(clientKeyFilename) - require.Nil(t, err) serverCertFilename := "/tmp/cert.pem" serverKeyFilename := "/tmp/key.pem" err = createTLSCertificate(serverCertFilename, serverKeyFilename, "localhost", nil, nil) - require.Nil(t, err) + s.NoError(err) defer os.Remove(serverCertFilename) defer os.Remove(serverKeyFilename) // Create TLS server with self-signed certificate certPool := x509.NewCertPool() data, err := os.ReadFile(serverCertFilename) - require.Nil(t, err) - ok := certPool.AppendCertsFromPEM(data) - require.True(t, ok) - wsServer := NewTLSServer(serverCertFilename, serverKeyFilename, &tls.Config{ + s.NoError(err) + ok = certPool.AppendCertsFromPEM(data) + s.True(ok) + tlsServer := NewServer(WithServerTLSConfig(serverCertFilename, serverKeyFilename, &tls.Config{ ClientCAs: certPool, // Contains server certificate as allowed client CA ClientAuth: tls.RequireAndVerifyClientCert, // Requires client certificate signed by allowed CA (server) - }) + })) + s.server, ok = tlsServer.(*server) + s.True(ok) // Add basic auth handler - wsServer.SetNewClientHandler(func(ws Channel) { + s.server.SetNewClientHandler(func(ws Channel) { // Should never reach this - t.Fail() + s.Fail("no new connection should be received from client!") }) // Run server - go wsServer.Start(serverPort, serverPath) - time.Sleep(200 * time.Millisecond) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Create TLS client certPool = x509.NewCertPool() data, err = os.ReadFile(serverCertFilename) - require.Nil(t, err) + s.NoError(err) ok = certPool.AppendCertsFromPEM(data) - require.True(t, ok) + s.True(ok) loadedCert, err := tls.LoadX509KeyPair(clientCertFilename, clientKeyFilename) - require.Nil(t, err) - wsClient := NewTLSClient(&tls.Config{ + s.NoError(err) + tlsClient := NewClient(WithClientTLSConfig(&tls.Config{ RootCAs: certPool, // Contains server certificate as allowed server CA Certificates: []tls.Certificate{loadedCert}, // Contains self-signed client certificate. Will be rejected by server - }) - wsClient.SetRequestedSubProtocol(defaultSubProtocol) + })) + s.client, ok = tlsClient.(*client) + s.True(ok) + s.client.SetRequestedSubProtocol(defaultSubProtocol) // Test connection host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "wss", Host: host, Path: testPath} - err = wsClient.Start(u.String()) - assert.NotNil(t, err) - netError, ok := err.(net.Error) - require.True(t, ok) - assert.Equal(t, "remote error: tls: unknown certificate authority", netError.Error()) // tls.alertUnknownCA = 48 - // Cleanup - wsServer.Stop() + err = s.client.Start(u.String()) + s.Error(err) + var netError net.Error + ok = errors.As(err, &netError) + s.True(ok) + s.Equal("remote error: tls: unknown certificate authority", netError.Error()) // tls.alertUnknownCA = 48 } -func TestUnsupportedSubProtocol(t *testing.T) { - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { +func (s *WebSocketSuite) TestUnsupportedSubProtocol() { + s.server.SetNewClientHandler(func(ws Channel) { }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { + s.server.SetDisconnectedClientHandler(func(ws Channel) { }) - wsServer.AddSupportedSubprotocol(defaultSubProtocol) - assert.Len(t, wsServer.upgrader.Subprotocols, 1) - // Test duplicate subprotocol - wsServer.AddSupportedSubprotocol(defaultSubProtocol) - assert.Len(t, wsServer.upgrader.Subprotocols, 1) + s.server.AddSupportedSubprotocol(defaultSubProtocol) + s.Len(s.server.upgrader.Subprotocols, 1) + // Test duplicate sub-protocol + s.server.AddSupportedSubprotocol(defaultSubProtocol) + s.Len(s.server.upgrader.Subprotocols, 1) // Start server - go wsServer.Start(serverPort, serverPath) - time.Sleep(1 * time.Second) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Setup client disconnectC := make(chan struct{}) - wsClient := newWebsocketClient(t, nil) - wsClient.SetDisconnectedHandler(func(err error) { - require.IsType(t, &websocket.CloseError{}, err) - wsErr, _ := err.(*websocket.CloseError) - assert.Equal(t, websocket.CloseProtocolError, wsErr.Code) - assert.Equal(t, "invalid or unsupported subprotocol", wsErr.Text) - wsClient.SetDisconnectedHandler(nil) - disconnectC <- struct{}{} - }) - // Set invalid subprotocol - wsClient.AddOption(func(dialer *websocket.Dialer) { + s.client.SetDisconnectedHandler(func(err error) { + var wsErr *websocket.CloseError + ok := s.ErrorAs(err, &wsErr) + s.True(ok) + s.Equal(websocket.CloseProtocolError, wsErr.Code) + s.Equal("invalid or unsupported subprotocol", wsErr.Text) + s.client.SetDisconnectedHandler(nil) + close(disconnectC) + }) + // Set invalid sub-protocol + s.client.AddOption(func(dialer *websocket.Dialer) { dialer.Subprotocols = []string{"unsupportedSubProto"} }) // Test host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err := wsClient.Start(u.String()) - assert.NoError(t, err) + err := s.client.Start(u.String()) + s.NoError(err) // Expect connection to be closed directly after start - _, ok := <-disconnectC - assert.True(t, ok) - // Cleanup - wsServer.Stop() + select { + case _, ok := <-disconnectC: + s.False(ok) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for client disconnect") + } } -func TestSetServerTimeoutConfig(t *testing.T) { - disconnected := make(chan bool) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { +func (s *WebSocketSuite) TestSetServerTimeoutConfig() { + disconnected := make(chan struct{}) + s.server.SetNewClientHandler(func(ws Channel) { }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { + s.server.SetDisconnectedClientHandler(func(ws Channel) { // TODO: check for error with upcoming API - disconnected <- true + close(disconnected) }) // Setting server timeout config := NewServerTimeoutConfig() - pingWait := 2 * time.Second - writeWait := 2 * time.Second + pingWait := 400 * time.Millisecond + writeWait := 500 * time.Millisecond config.PingWait = pingWait config.WriteWait = writeWait - wsServer.SetTimeoutConfig(config) + s.server.SetTimeoutConfig(config) // Start server - go wsServer.Start(serverPort, serverPath) - time.Sleep(500 * time.Millisecond) - assert.Equal(t, wsServer.timeoutConfig.PingWait, pingWait) - assert.Equal(t, wsServer.timeoutConfig.WriteWait, writeWait) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) + s.Equal(s.server.timeoutConfig.PingWait, pingWait) + s.Equal(s.server.timeoutConfig.WriteWait, writeWait) // Run test - wsClient := newWebsocketClient(t, nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err := wsClient.Start(u.String()) - assert.NoError(t, err) - result := <-disconnected - assert.True(t, result) - // Cleanup - wsClient.Stop() - wsServer.Stop() + err := s.client.Start(u.String()) + s.NoError(err) + select { + case _, ok := <-disconnected: + s.False(ok) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for client disconnect") + } } -func TestSetClientTimeoutConfig(t *testing.T) { - disconnected := make(chan bool) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { +func (s *WebSocketSuite) TestSetClientTimeoutConfig() { + disconnected := make(chan struct{}) + s.server.SetNewClientHandler(func(ws Channel) { }) - wsServer.SetDisconnectedClientHandler(func(ws Channel) { + s.server.SetDisconnectedClientHandler(func(ws Channel) { // TODO: check for error with upcoming API - disconnected <- true + close(disconnected) }) // Start server - go wsServer.Start(serverPort, serverPath) - time.Sleep(200 * time.Millisecond) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Run test - wsClient := newWebsocketClient(t, nil) host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} // Set client timeout config := NewClientTimeoutConfig() handshakeTimeout := 1 * time.Nanosecond // Very low timeout, handshake will fail - pongWait := 2 * time.Second - writeWait := 2 * time.Second - pingPeriod := 5 * time.Second + writeWait := 1 * time.Second + // Ping period > pong wait, this is a nonsensical config that will trigger a pong timeout + pingPeriod := 3 * time.Second + pongWait := 500 * time.Millisecond config.PongWait = pongWait config.HandshakeTimeout = handshakeTimeout config.WriteWait = writeWait config.PingPeriod = pingPeriod - wsClient.SetTimeoutConfig(config) + s.client.SetTimeoutConfig(config) // Start client and expect handshake error - err := wsClient.Start(u.String()) - opError, ok := err.(*net.OpError) - require.True(t, ok) - assert.Equal(t, "dial", opError.Op) - assert.True(t, opError.Timeout()) - assert.Error(t, opError.Err, "i/o timeout") + err := s.client.Start(u.String()) + var opError *net.OpError + ok := s.ErrorAs(err, &opError) + s.True(ok) + s.Equal("dial", opError.Op) + s.True(opError.Timeout()) + s.Error(opError.Err, "i/o timeout") + // Reset handshake to reasonable value config.HandshakeTimeout = defaultHandshakeTimeout - wsClient.SetTimeoutConfig(config) + s.client.SetTimeoutConfig(config) // Start client - err = wsClient.Start(u.String()) - require.NoError(t, err) - assert.Equal(t, wsClient.timeoutConfig.PongWait, pongWait) - assert.Equal(t, wsClient.timeoutConfig.WriteWait, writeWait) - assert.Equal(t, wsClient.timeoutConfig.PingPeriod, pingPeriod) - result := <-disconnected - assert.True(t, result) - // Cleanup - wsClient.Stop() - wsServer.Stop() + err = s.client.Start(u.String()) + s.NoError(err) + s.Equal(s.client.timeoutConfig.PongWait, pongWait) + s.Equal(s.client.timeoutConfig.WriteWait, writeWait) + s.Equal(s.client.timeoutConfig.PingPeriod, pingPeriod) + select { + case _, closed := <-disconnected: + s.False(closed) + case <-time.After(1 * time.Second): + s.Fail("timeout waiting for client disconnect") + } } -func TestServerErrors(t *testing.T) { - triggerC := make(chan bool, 1) - finishC := make(chan bool, 1) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { - triggerC <- true +func (s *WebSocketSuite) TestServerErrors() { + triggerC := make(chan struct{}, 1) + finishC := make(chan struct{}, 1) + defer close(finishC) + s.server.SetNewClientHandler(func(ws Channel) { + triggerC <- struct{}{} }) // Intercept errors asynchronously - assert.Nil(t, wsServer.errC) + s.Nil(s.server.errC) go func() { for { select { - case err, ok := <-wsServer.Errors(): - triggerC <- true + case err, ok := <-s.server.Errors(): + triggerC <- struct{}{} if ok { - assert.Error(t, err) + s.Error(err) } case <-finishC: return } } }() - wsServer.SetMessageHandler(func(ws Channel, data []byte) error { + s.server.SetMessageHandler(func(ws Channel, data []byte) error { return fmt.Errorf("this is a dummy error") }) // Will trigger an out-of-bound error time.Sleep(50 * time.Millisecond) - wsServer.Stop() + s.server.Stop() r := <-triggerC - assert.True(t, r) + s.NotNil(r) // Start server for real - wsServer.httpServer = &http.Server{} - go wsServer.Start(serverPort, serverPath) - time.Sleep(200 * time.Millisecond) - // Create and connect client - wsClient := newWebsocketClient(t, nil) + s.server.httpServer = &http.Server{} + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) + // Connect client host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err := wsClient.Start(u.String()) - require.NoError(t, err) + err := s.client.Start(u.String()) + s.NoError(err) // Wait for new client callback r = <-triggerC - require.True(t, r) + s.NotNil(r) // Send a dummy message and expect error on server side - err = wsClient.Write([]byte("dummy message")) - require.NoError(t, err) + err = s.client.Write([]byte("dummy message")) + s.NoError(err) r = <-triggerC - assert.True(t, r) + s.NotNil(r) // Send message to non-existing client - err = wsServer.Write("fakeId", []byte("dummy response")) - require.Error(t, err) + err = s.server.Write("fakeId", []byte("dummy response")) + s.Error(err) // Send unexpected close message and wait for error to be thrown - err = wsClient.webSocket.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) - assert.NoError(t, err) - <-triggerC + err = s.client.webSocket.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) + s.NoError(err) + r = <-triggerC + s.NotNil(r) // Stop and wait for errors channel cleanup - wsServer.Stop() + s.server.Stop() r = <-triggerC - assert.True(t, r) - close(finishC) + s.NotNil(r) } -func TestClientErrors(t *testing.T) { - triggerC := make(chan bool, 1) - finishC := make(chan bool, 1) - wsServer := newWebsocketServer(t, nil) - wsServer.SetNewClientHandler(func(ws Channel) { - triggerC <- true +func (s *WebSocketSuite) TestClientErrors() { + triggerC := make(chan struct{}, 1) + finishC := make(chan struct{}, 1) + defer close(finishC) + s.server.SetNewClientHandler(func(ws Channel) { + triggerC <- struct{}{} }) - wsClient := newWebsocketClient(t, nil) - wsClient.SetMessageHandler(func(data []byte) error { + s.client.SetMessageHandler(func(data []byte) error { return fmt.Errorf("this is a dummy error") }) // Intercept errors asynchronously - assert.Nil(t, wsClient.errC) + s.Nil(s.client.errC) go func() { for { select { - case err, ok := <-wsClient.Errors(): - triggerC <- true + case err, ok := <-s.client.Errors(): + triggerC <- struct{}{} if ok { - assert.Error(t, err) + s.Error(err) } case <-finishC: return } } }() - go wsServer.Start(serverPort, serverPath) - time.Sleep(200 * time.Millisecond) + go s.server.Start(serverPort, serverPath) + time.Sleep(100 * time.Millisecond) // Attempt to write a message without being connected - err := wsClient.Write([]byte("dummy message")) - require.Error(t, err) + err := s.client.Write([]byte("dummy message")) + s.Error(err) // Connect client host := fmt.Sprintf("localhost:%v", serverPort) u := url.URL{Scheme: "ws", Host: host, Path: testPath} - err = wsClient.Start(u.String()) - require.NoError(t, err) + err = s.client.Start(u.String()) + s.NoError(err) // Wait for new client callback r := <-triggerC - require.True(t, r) + s.NotNil(r) // Send a dummy message and expect error on client side - err = wsServer.Write(path.Base(testPath), []byte("dummy message")) - require.NotNil(t, t, err) + err = s.server.Write(path.Base(testPath), []byte("dummy message")) + s.NoError(err) r = <-triggerC - assert.True(t, r) + s.NotNil(r) // Send unexpected close message and wait for error to be thrown - conn := wsServer.connections[path.Base(testPath)] - require.NotNil(t, conn) - err = conn.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) - assert.NoError(t, err) + conn := s.server.connections[path.Base(testPath)] + s.NotNil(conn) + err = conn.WriteManual(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) + //err = conn.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "")) + s.NoError(err) r = <-triggerC - require.True(t, r) + s.NotNil(r) // Stop server and client and wait for errors channel cleanup - wsServer.Stop() - wsClient.Stop() + s.server.Stop() + s.client.Stop() r = <-triggerC - require.True(t, r) - close(finishC) + s.NotNil(r) } // Utility functions @@ -1238,3 +1414,7 @@ func createTLSCertificate(certificateFilename string, keyFilename string, cn str } return nil } + +func TestWebSockets(t *testing.T) { + suite.Run(t, new(WebSocketSuite)) +}