Skip to content

Commit a574fb6

Browse files
Add unit test for websocket controller
* Add unit test for websocket controller * Add mock for websocket connection * Add mock for data provider * Add mock for data provider factory * Add mock for websocket connection The WebSocket Controller interacts with: 1. Data Provider: Supplies data to the controller. 2. WebSocket Connection: Handles communication with the client. To properly test the controller's logic, we mock these interactions. Since the controller runs two parallel routines (reader and writer), the tests also ensure both can shut down cleanly. A done channel is used in the tests to coordinate this process.
1 parent a3676ba commit a574fb6

File tree

13 files changed

+530
-163
lines changed

13 files changed

+530
-163
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ generate-mocks: install-mock-generators
214214
mockery --name 'Storage' --dir=module/executiondatasync/tracker --case=underscore --output="module/executiondatasync/tracker/mock" --outpkg="mocktracker"
215215
mockery --name 'ScriptExecutor' --dir=module/execution --case=underscore --output="module/execution/mock" --outpkg="mock"
216216
mockery --name 'StorageSnapshot' --dir=fvm/storage/snapshot --case=underscore --output="fvm/storage/snapshot/mock" --outpkg="mock"
217+
mockery --name 'DataProvider' --dir=engine/access/rest/websockets/data_provider --case=underscore --output="engine/access/rest/websockets/data_provider/mock" --outpkg="mock"
218+
mockery --name 'Factory' --dir=engine/access/rest/websockets/data_provider --case=underscore --output="engine/access/rest/websockets/data_provider/mock" --outpkg="mock"
219+
mockery --name 'WebsocketConnection' --dir=engine/access/rest/websockets --case=underscore --output="engine/access/rest/websockets/mock" --outpkg="mock"
217220

218221
#temporarily make insecure/ a non-module to allow mockery to create mocks
219222
mv insecure/go.mod insecure/go2.mod

engine/access/rest/router/router.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
flowhttp "github.com/onflow/flow-go/engine/access/rest/http"
1515
"github.com/onflow/flow-go/engine/access/rest/http/models"
1616
"github.com/onflow/flow-go/engine/access/rest/websockets"
17+
"github.com/onflow/flow-go/engine/access/rest/websockets/data_provider"
1718
legacyws "github.com/onflow/flow-go/engine/access/rest/websockets/legacy"
1819
"github.com/onflow/flow-go/engine/access/state_stream"
1920
"github.com/onflow/flow-go/engine/access/state_stream/backend"
@@ -93,7 +94,8 @@ func (b *RouterBuilder) AddWebsocketsRoute(
9394
streamConfig backend.Config,
9495
maxRequestSize int64,
9596
) *RouterBuilder {
96-
handler := websockets.NewWebSocketHandler(b.logger, config, chain, streamApi, streamConfig, maxRequestSize)
97+
factory := data_provider.NewDataProviderFactory(b.logger, streamApi, streamConfig)
98+
handler := websockets.NewWebSocketHandler(b.logger, config, chain, factory, maxRequestSize)
9799
b.v1SubRouter.
98100
Methods(http.MethodGet).
99101
Path("/ws").
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package websockets
2+
3+
import (
4+
"github.com/gorilla/websocket"
5+
)
6+
7+
type WebsocketConnection interface {
8+
ReadJSON(v interface{}) error
9+
WriteJSON(v interface{}) error
10+
Close() error
11+
}
12+
13+
type GorillaWebsocketConnection struct {
14+
conn *websocket.Conn
15+
}
16+
17+
func NewGorillaWebsocketConnection(conn *websocket.Conn) *GorillaWebsocketConnection {
18+
return &GorillaWebsocketConnection{
19+
conn: conn,
20+
}
21+
}
22+
23+
var _ WebsocketConnection = (*GorillaWebsocketConnection)(nil)
24+
25+
func (m *GorillaWebsocketConnection) ReadJSON(v interface{}) error {
26+
return m.conn.ReadJSON(v)
27+
}
28+
29+
func (m *GorillaWebsocketConnection) WriteJSON(v interface{}) error {
30+
return m.conn.WriteJSON(v)
31+
}
32+
33+
func (m *GorillaWebsocketConnection) SetCloseHandler(handler func(code int, text string) error) {
34+
m.conn.SetCloseHandler(handler)
35+
}
36+
37+
func (m *GorillaWebsocketConnection) Close() error {
38+
return m.conn.Close()
39+
}

engine/access/rest/websockets/controller.go

Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,102 +3,119 @@ package websockets
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
8+
"sync"
79

810
"github.com/google/uuid"
911
"github.com/gorilla/websocket"
1012
"github.com/rs/zerolog"
1113

1214
dp "github.com/onflow/flow-go/engine/access/rest/websockets/data_provider"
1315
"github.com/onflow/flow-go/engine/access/rest/websockets/models"
14-
"github.com/onflow/flow-go/engine/access/state_stream"
15-
"github.com/onflow/flow-go/engine/access/state_stream/backend"
1616
"github.com/onflow/flow-go/utils/concurrentmap"
1717
)
1818

19+
var ErrEmptyMessage = errors.New("empty message")
20+
1921
type Controller struct {
2022
logger zerolog.Logger
2123
config Config
22-
conn *websocket.Conn
24+
conn WebsocketConnection
2325
communicationChannel chan interface{}
2426
dataProviders *concurrentmap.Map[uuid.UUID, dp.DataProvider]
25-
dataProvidersFactory *dp.Factory
27+
dataProvidersFactory dp.Factory
28+
shutdownOnce sync.Once
2629
}
2730

2831
func NewWebSocketController(
2932
logger zerolog.Logger,
3033
config Config,
31-
streamApi state_stream.API,
32-
streamConfig backend.Config,
33-
conn *websocket.Conn,
34+
factory dp.Factory,
35+
conn WebsocketConnection,
3436
) *Controller {
3537
return &Controller{
3638
logger: logger.With().Str("component", "websocket-controller").Logger(),
3739
config: config,
3840
conn: conn,
39-
communicationChannel: make(chan interface{}), //TODO: should it be buffered chan?
41+
communicationChannel: make(chan interface{}, 10), //TODO: should it be buffered chan?
4042
dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](),
41-
dataProvidersFactory: dp.NewDataProviderFactory(logger, streamApi, streamConfig),
43+
dataProvidersFactory: factory,
44+
shutdownOnce: sync.Once{},
4245
}
4346
}
4447

4548
// HandleConnection manages the WebSocket connection, adding context and error handling.
4649
func (c *Controller) HandleConnection(ctx context.Context) {
4750
//TODO: configure the connection with ping-pong and deadlines
4851
//TODO: spin up a response limit tracker routine
49-
go c.readMessagesFromClient(ctx)
50-
c.writeMessagesToClient(ctx)
52+
go c.readMessages(ctx)
53+
c.writeMessages(ctx)
5154
}
5255

53-
// writeMessagesToClient reads a messages from communication channel and passes them on to a client WebSocket connection.
56+
// writeMessages reads a messages from communication channel and passes them on to a client WebSocket connection.
5457
// The communication channel is filled by data providers. Besides, the response limit tracker is involved in
5558
// write message regulation
56-
func (c *Controller) writeMessagesToClient(ctx context.Context) {
57-
//TODO: can it run forever? maybe we should cancel the ctx in the reader routine
59+
func (c *Controller) writeMessages(ctx context.Context) {
60+
defer c.shutdownConnection()
61+
5862
for {
5963
select {
6064
case <-ctx.Done():
6165
return
62-
case msg := <-c.communicationChannel:
63-
// TODO: handle 'response per second' limits
66+
case msg, ok := <-c.communicationChannel:
67+
if !ok {
68+
return
69+
}
70+
c.logger.Debug().Msgf("read message from communication channel: %s", msg)
6471

72+
// TODO: handle 'response per second' limits
6573
err := c.conn.WriteJSON(msg)
6674
if err != nil {
75+
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) ||
76+
websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
77+
return
78+
}
79+
6780
c.logger.Error().Err(err).Msg("error writing to connection")
81+
return
6882
}
83+
84+
c.logger.Debug().Msg("written message to client")
6985
}
7086
}
7187
}
7288

73-
// readMessagesFromClient continuously reads messages from a client WebSocket connection,
89+
// readMessages continuously reads messages from a client WebSocket connection,
7490
// processes each message, and handles actions based on the message type.
75-
func (c *Controller) readMessagesFromClient(ctx context.Context) {
91+
func (c *Controller) readMessages(ctx context.Context) {
7692
defer c.shutdownConnection()
7793

7894
for {
79-
select {
80-
case <-ctx.Done():
81-
c.logger.Info().Msg("context canceled, stopping read message loop")
82-
return
83-
default:
84-
msg, err := c.readMessage()
85-
if err != nil {
86-
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) {
87-
return
88-
}
89-
c.logger.Warn().Err(err).Msg("error reading message from client")
95+
msg, err := c.readMessage()
96+
if err != nil {
97+
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) ||
98+
websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
9099
return
100+
} else if errors.Is(err, ErrEmptyMessage) {
101+
continue
91102
}
92103

93-
baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg)
94-
if err != nil {
95-
c.logger.Debug().Err(err).Msg("error parsing and validating client message")
96-
return
97-
}
104+
c.logger.Debug().Err(err).Msg("error reading message from client")
105+
continue
106+
}
98107

99-
if err := c.handleAction(ctx, validatedMsg); err != nil {
100-
c.logger.Warn().Err(err).Str("action", baseMsg.Action).Msg("error handling action")
101-
}
108+
baseMsg, validatedMsg, err := c.parseAndValidateMessage(msg)
109+
if err != nil {
110+
c.logger.Debug().Err(err).Msg("error parsing and validating client message")
111+
//TODO: write error to error channel
112+
continue
113+
}
114+
115+
if err := c.handleAction(ctx, validatedMsg); err != nil {
116+
c.logger.Debug().Err(err).Str("action", baseMsg.Action).Msg("error handling action")
117+
//TODO: write error to error channel
118+
continue
102119
}
103120
}
104121
}
@@ -108,6 +125,11 @@ func (c *Controller) readMessage() (json.RawMessage, error) {
108125
if err := c.conn.ReadJSON(&message); err != nil {
109126
return nil, fmt.Errorf("error reading JSON from client: %w", err)
110127
}
128+
129+
if message == nil {
130+
return nil, ErrEmptyMessage
131+
}
132+
111133
return message, nil
112134
}
113135

@@ -166,10 +188,18 @@ func (c *Controller) handleAction(ctx context.Context, message interface{}) erro
166188
func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMessageRequest) {
167189
dp := c.dataProvidersFactory.NewDataProvider(c.communicationChannel, msg.Topic)
168190
c.dataProviders.Add(dp.ID(), dp)
169-
dp.Run(ctx)
170191

171-
//TODO: return OK response to client
172-
c.communicationChannel <- msg
192+
// firstly, we want to write OK response to client and only after that we can start providing actual data
193+
response := models.SubscribeMessageResponse{
194+
BaseMessageResponse: models.BaseMessageResponse{
195+
Success: true,
196+
},
197+
Topic: dp.Topic(),
198+
ID: dp.ID().String(),
199+
}
200+
c.communicationChannel <- response
201+
202+
dp.Run(ctx)
173203
}
174204

175205
func (c *Controller) handleUnsubscribe(_ context.Context, msg models.UnsubscribeMessageRequest) {
@@ -193,20 +223,24 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis
193223
}
194224

195225
func (c *Controller) shutdownConnection() {
196-
defer close(c.communicationChannel)
197-
defer func(conn *websocket.Conn) {
198-
if err := c.conn.Close(); err != nil {
199-
c.logger.Error().Err(err).Msg("error closing connection")
226+
c.shutdownOnce.Do(func() {
227+
defer close(c.communicationChannel)
228+
defer func(conn WebsocketConnection) {
229+
if err := c.conn.Close(); err != nil {
230+
c.logger.Warn().Err(err).Msg("error closing connection")
231+
}
232+
}(c.conn)
233+
234+
c.logger.Debug().Msg("shutting down connection")
235+
236+
err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error {
237+
dp.Close()
238+
return nil
239+
})
240+
if err != nil {
241+
c.logger.Error().Err(err).Msg("error closing data provider")
200242
}
201-
}(c.conn)
202243

203-
err := c.dataProviders.ForEach(func(_ uuid.UUID, dp dp.DataProvider) error {
204-
dp.Close()
205-
return nil
244+
c.dataProviders.Clear()
206245
})
207-
if err != nil {
208-
c.logger.Error().Err(err).Msg("error closing data provider")
209-
}
210-
211-
c.dataProviders.Clear()
212246
}

0 commit comments

Comments
 (0)