Skip to content

Commit 4682bb1

Browse files
simplify ReadJSON mock. refactor controller a bit
1 parent 8496af3 commit 4682bb1

File tree

2 files changed

+16
-22
lines changed

2 files changed

+16
-22
lines changed

engine/access/rest/websockets/controller.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func NewWebSocketController(
3535
logger: logger.With().Str("component", "websocket-controller").Logger(),
3636
config: config,
3737
conn: conn,
38-
communicationChannel: make(chan interface{}, 10), //TODO: should it be buffered chan?
38+
communicationChannel: make(chan interface{}), //TODO: should it be buffered chan?
3939
dataProviders: concurrentmap.New[uuid.UUID, dp.DataProvider](),
4040
dataProvidersFactory: factory,
4141
shutdownOnce: sync.Once{},
@@ -46,6 +46,7 @@ func NewWebSocketController(
4646
func (c *Controller) HandleConnection(ctx context.Context) {
4747
//TODO: configure the connection with ping-pong and deadlines
4848
//TODO: spin up a response limit tracker routine
49+
defer c.shutdownConnection()
4950
go c.readMessages(ctx)
5051
c.writeMessages(ctx)
5152
}
@@ -54,8 +55,6 @@ func (c *Controller) HandleConnection(ctx context.Context) {
5455
// The communication channel is filled by data providers. Besides, the response limit tracker is involved in
5556
// write message regulation
5657
func (c *Controller) writeMessages(ctx context.Context) {
57-
defer c.shutdownConnection()
58-
5958
for {
6059
select {
6160
case <-ctx.Done():
@@ -86,8 +85,6 @@ func (c *Controller) writeMessages(ctx context.Context) {
8685
// readMessages continuously reads messages from a client WebSocket connection,
8786
// processes each message, and handles actions based on the message type.
8887
func (c *Controller) readMessages(ctx context.Context) {
89-
defer c.shutdownConnection()
90-
9188
for {
9289
msg, err := c.readMessage()
9390
if err != nil {
@@ -188,7 +185,12 @@ func (c *Controller) handleSubscribe(ctx context.Context, msg models.SubscribeMe
188185
Topic: dp.Topic(),
189186
ID: dp.ID().String(),
190187
}
191-
c.communicationChannel <- response
188+
189+
select {
190+
case <-ctx.Done():
191+
return
192+
case c.communicationChannel <- response:
193+
}
192194

193195
dp.Run(ctx)
194196
}
@@ -216,8 +218,6 @@ func (c *Controller) handleListSubscriptions(ctx context.Context, msg models.Lis
216218
func (c *Controller) shutdownConnection() {
217219
c.shutdownOnce.Do(func() {
218220
defer func() {
219-
close(c.communicationChannel)
220-
221221
if err := c.conn.Close(); err != nil {
222222
c.logger.Warn().Err(err).Msg("error closing connection")
223223
}

engine/access/rest/websockets/controller_test.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (s *WsControllerSuite) TestSubscribeRequest() {
5353
Run(func(args mock.Arguments) {}).
5454
Once()
5555

56-
requestMessage := models.SubscribeMessageRequest{
56+
subscribeRequest := models.SubscribeMessageRequest{
5757
BaseMessageRequest: models.BaseMessageRequest{Action: "subscribe"},
5858
Topic: "blocks",
5959
Arguments: nil,
@@ -63,11 +63,11 @@ func (s *WsControllerSuite) TestSubscribeRequest() {
6363
conn.
6464
On("ReadJSON", mock.Anything).
6565
Run(func(args mock.Arguments) {
66-
reqMsg, ok := args.Get(0).(*json.RawMessage)
66+
requestMsg, ok := args.Get(0).(*json.RawMessage)
6767
require.True(t, ok)
68-
msg, err := json.Marshal(requestMessage)
68+
subscribeRequestMessage, err := json.Marshal(subscribeRequest)
6969
require.NoError(t, err)
70-
*reqMsg = msg
70+
*requestMsg = subscribeRequestMessage
7171
}).
7272
Return(nil).
7373
Once()
@@ -90,11 +90,8 @@ func (s *WsControllerSuite) TestSubscribeRequest() {
9090
conn.
9191
On("ReadJSON", mock.Anything).
9292
Return(func(interface{}) error {
93-
_, ok := <-done
94-
if !ok {
95-
return websocket.ErrCloseSent
96-
}
97-
return nil
93+
<-done
94+
return websocket.ErrCloseSent
9895
})
9996

10097
controller.HandleConnection(context.Background())
@@ -231,11 +228,8 @@ func (s *WsControllerSuite) expectSubscriptionRequest(conn *connmock.WebsocketCo
231228
conn.
232229
On("ReadJSON", mock.Anything).
233230
Return(func(msg interface{}) error {
234-
_, ok := <-done
235-
if !ok {
236-
return websocket.ErrCloseSent
237-
}
238-
return nil
231+
<-done
232+
return websocket.ErrCloseSent
239233
})
240234
}
241235

0 commit comments

Comments
 (0)