Skip to content

Commit b255d66

Browse files
author
Jenita
committed
feat: no error return on connection close
Signed-off-by: Jenita <[email protected]>
1 parent 829e00e commit b255d66

File tree

13 files changed

+248
-272
lines changed

13 files changed

+248
-272
lines changed

protocol/blockfetch/blockfetch.go

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -121,24 +121,17 @@ func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch {
121121
return b
122122
}
123123

124-
func (b *BlockFetch) IsDone() bool {
125-
if b.Client != nil && b.Client.IsDone() {
126-
return true
127-
}
128-
if b.Server != nil && b.Server.IsDone() {
129-
return true
130-
}
131-
return false
132-
}
133-
134124
func (b *BlockFetch) HandleConnectionError(err error) error {
135125
if err == nil {
136126
return nil
137127
}
128+
// Check if protocol is done or if it's a normal connection closure
129+
if b.Client.IsDone() || b.Server.IsDone() {
130+
return nil
131+
}
132+
138133
if errors.Is(err, io.EOF) || isConnectionReset(err) {
139-
if b.IsDone() {
140-
return nil
141-
}
134+
return err
142135
}
143136
return err
144137
}

protocol/blockfetch/blockfetch_test.go

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,33 @@ func (a testAddr) String() string { return "test-addr" }
3838
// testConn implements net.Conn for testing with buffered writes
3939
type testConn struct {
4040
writeChan chan []byte
41+
closed bool
42+
closeChan chan struct{}
43+
}
44+
45+
func newTestConn() *testConn {
46+
return &testConn{
47+
writeChan: make(chan []byte, 100),
48+
closeChan: make(chan struct{}),
49+
}
4150
}
4251

4352
func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil }
4453
func (c *testConn) Write(b []byte) (n int, err error) {
45-
c.writeChan <- b
46-
return len(b), nil
54+
select {
55+
case c.writeChan <- b:
56+
return len(b), nil
57+
case <-c.closeChan:
58+
return 0, io.EOF
59+
}
60+
}
61+
func (c *testConn) Close() error {
62+
if !c.closed {
63+
close(c.closeChan)
64+
c.closed = true
65+
}
66+
return nil
4767
}
48-
func (c *testConn) Close() error { return nil }
4968
func (c *testConn) LocalAddr() net.Addr { return testAddr{} }
5069
func (c *testConn) RemoteAddr() net.Addr { return testAddr{} }
5170
func (c *testConn) SetDeadline(t time.Time) error { return nil }
@@ -54,6 +73,12 @@ func (c *testConn) SetWriteDeadline(t time.Time) error { return nil }
5473

5574
func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions {
5675
mux := muxer.New(conn)
76+
go mux.Start()
77+
go func() {
78+
<-conn.(*testConn).closeChan
79+
mux.Stop()
80+
}()
81+
5782
return protocol.ProtocolOptions{
5883
ConnectionId: connection.ConnectionId{
5984
LocalAddr: testAddr{},
@@ -65,7 +90,8 @@ func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions {
6590
}
6691

6792
func TestNewBlockFetch(t *testing.T) {
68-
conn := &testConn{writeChan: make(chan []byte, 1)}
93+
conn := newTestConn()
94+
defer conn.Close()
6995
cfg := NewConfig()
7096
bf := New(getTestProtocolOptions(conn), &cfg)
7197
assert.NotNil(t, bf.Client)
@@ -92,10 +118,17 @@ func TestConfigOptions(t *testing.T) {
92118
}
93119

94120
func TestConnectionErrorHandling(t *testing.T) {
95-
conn := &testConn{writeChan: make(chan []byte, 1)}
121+
conn := newTestConn()
122+
defer conn.Close()
96123
cfg := NewConfig()
97124
bf := New(getTestProtocolOptions(conn), &cfg)
98125

126+
// Start protocols
127+
bf.Client.Start()
128+
defer bf.Client.Stop()
129+
bf.Server.Start()
130+
defer bf.Server.Stop()
131+
99132
t.Run("Non-EOF error when not done", func(t *testing.T) {
100133
err := bf.HandleConnectionError(errors.New("test error"))
101134
assert.Error(t, err)
@@ -110,10 +143,23 @@ func TestConnectionErrorHandling(t *testing.T) {
110143
err := bf.HandleConnectionError(errors.New("connection reset by peer"))
111144
assert.Error(t, err)
112145
})
146+
147+
t.Run("EOF error when done", func(t *testing.T) {
148+
// Send done message to properly transition to done state
149+
err := bf.Client.SendMessage(NewMsgClientDone())
150+
assert.NoError(t, err)
151+
152+
// Wait for state transition
153+
time.Sleep(100 * time.Millisecond)
154+
155+
err = bf.HandleConnectionError(io.EOF)
156+
assert.NoError(t, err, "expected no error when protocol is in done state")
157+
})
113158
}
114159

115160
func TestCallbackRegistration(t *testing.T) {
116-
conn := &testConn{writeChan: make(chan []byte, 1)}
161+
conn := newTestConn()
162+
defer conn.Close()
117163

118164
t.Run("Block callback registration", func(t *testing.T) {
119165
blockFunc := func(ctx CallbackContext, slot uint, block ledger.Block) error {
@@ -137,16 +183,17 @@ func TestCallbackRegistration(t *testing.T) {
137183
}
138184

139185
func TestClientMessageSending(t *testing.T) {
140-
conn := &testConn{writeChan: make(chan []byte, 1)}
186+
conn := newTestConn()
187+
defer conn.Close()
141188
cfg := NewConfig()
142189
client := NewClient(getTestProtocolOptions(conn), &cfg)
143190

144191
t.Run("Client can send messages", func(t *testing.T) {
145-
// Start the client protocol
146192
client.Start()
193+
defer client.Stop()
147194

148195
// Send a done message
149-
err := client.Protocol.SendMessage(NewMsgClientDone())
196+
err := client.SendMessage(NewMsgClientDone())
150197
assert.NoError(t, err)
151198

152199
// Verify message was written to connection
@@ -158,14 +205,3 @@ func TestClientMessageSending(t *testing.T) {
158205
}
159206
})
160207
}
161-
162-
func TestServerMessageHandling(t *testing.T) {
163-
conn := &testConn{writeChan: make(chan []byte, 1)}
164-
cfg := NewConfig()
165-
server := NewServer(getTestProtocolOptions(conn), &cfg)
166-
167-
t.Run("Server can be started", func(t *testing.T) {
168-
server.Start()
169-
170-
})
171-
}

protocol/blockfetch/client.go

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ type Client struct {
3535
blockUseCallback bool
3636
onceStart sync.Once
3737
onceStop sync.Once
38-
currentState protocol.State
39-
stateMutex sync.Mutex
4038
}
4139

4240
func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
@@ -48,7 +46,6 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
4846
config: cfg,
4947
blockChan: make(chan ledger.Block),
5048
startBatchResultChan: make(chan error),
51-
currentState: StateIdle,
5249
}
5350
c.callbackContext = CallbackContext{
5451
Client: c,
@@ -85,18 +82,6 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
8582
return c
8683
}
8784

88-
func (c *Client) IsDone() bool {
89-
c.stateMutex.Lock()
90-
defer c.stateMutex.Unlock()
91-
return c.currentState.Id == StateDone.Id
92-
}
93-
94-
func (c *Client) setState(newState protocol.State) {
95-
c.stateMutex.Lock()
96-
defer c.stateMutex.Unlock()
97-
c.currentState = newState
98-
}
99-
10085
func (c *Client) Start() {
10186
c.onceStart.Do(func() {
10287
c.Protocol.Logger().
@@ -124,12 +109,12 @@ func (c *Client) Stop() error {
124109
"protocol", ProtocolName,
125110
"connection_id", c.callbackContext.ConnectionId.String(),
126111
)
127-
msg := NewMsgClientDone()
128-
if sendErr := c.SendMessage(msg); sendErr != nil {
129-
err = sendErr
130-
return
112+
if !c.IsDone() {
113+
msg := NewMsgClientDone()
114+
if err = c.SendMessage(msg); err != nil {
115+
return
116+
}
131117
}
132-
c.setState(StateDone)
133118
})
134119
return err
135120
}
@@ -216,7 +201,7 @@ func (c *Client) messageHandler(msg protocol.Message) error {
216201
case MessageTypeBatchDone:
217202
err = c.handleBatchDone()
218203
case MessageTypeClientDone:
219-
c.setState(StateDone)
204+
return nil
220205
default:
221206
err = fmt.Errorf(
222207
"%s: received unexpected message type %d",

protocol/blockfetch/server.go

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package blockfetch
1717
import (
1818
"errors"
1919
"fmt"
20-
"sync"
2120

2221
"github.com/blinklabs-io/gouroboros/cbor"
2322
"github.com/blinklabs-io/gouroboros/protocol"
@@ -28,16 +27,13 @@ type Server struct {
2827
config *Config
2928
callbackContext CallbackContext
3029
protoOptions protocol.ProtocolOptions
31-
currentState protocol.State
32-
stateMutex sync.Mutex
3330
}
3431

3532
func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
3633
s := &Server{
3734
config: cfg,
3835
// Save this for re-use later
3936
protoOptions: protoOptions,
40-
currentState: StateIdle,
4137
}
4238
s.callbackContext = CallbackContext{
4339
Server: s,
@@ -47,18 +43,6 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
4743
return s
4844
}
4945

50-
func (s *Server) IsDone() bool {
51-
s.stateMutex.Lock()
52-
defer s.stateMutex.Unlock()
53-
return s.currentState.Id == StateDone.Id
54-
}
55-
56-
func (s *Server) setState(newState protocol.State) {
57-
s.stateMutex.Lock()
58-
defer s.stateMutex.Unlock()
59-
s.currentState = newState
60-
}
61-
6246
func (s *Server) initProtocol() {
6347
protoConfig := protocol.ProtocolConfig{
6448
Name: ProtocolName,
@@ -142,8 +126,8 @@ func (s *Server) messageHandler(msg protocol.Message) error {
142126
case MessageTypeRequestRange:
143127
err = s.handleRequestRange(msg)
144128
case MessageTypeClientDone:
145-
s.setState(StateDone)
146-
err = s.handleClientDone()
129+
// State handled automatically by base protocol
130+
return s.handleClientDone()
147131
default:
148132
err = fmt.Errorf(
149133
"%s: received unexpected message type %d",

protocol/chainsync/chainsync.go

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,10 @@ var PipelineIsNotEmpty = func(context any, msg protocol.Message) bool {
195195

196196
// ChainSync is a wrapper object that holds the client and server instances
197197
type ChainSync struct {
198-
Client *Client
199-
Server *Server
200-
stateMutex sync.Mutex
201-
currentState protocol.State
198+
Client *Client
199+
Server *Server
202200
}
203201

204-
// Config is used to configure the ChainSync protocol instance
205202
type Config struct {
206203
RollBackwardFunc RollBackwardFunc
207204
RollForwardFunc RollForwardFunc
@@ -335,26 +332,20 @@ func WithRecvQueueSize(size int) ChainSyncOptionFunc {
335332
}
336333
}
337334

338-
// HandleConnectionError handles connection errors and determines if they should be ignored
339335
func (c *ChainSync) HandleConnectionError(err error) error {
340336
if err == nil {
341337
return nil
342338
}
339+
if c.Client.IsDone() || c.Server.IsDone() {
340+
return nil
341+
}
342+
343343
if errors.Is(err, io.EOF) || isConnectionReset(err) {
344-
if c.IsDone() {
345-
return nil
346-
}
344+
return err
347345
}
348346
return err
349347
}
350348

351-
// IsDone returns true if the protocol is in the Done state
352-
func (c *ChainSync) IsDone() bool {
353-
c.stateMutex.Lock()
354-
defer c.stateMutex.Unlock()
355-
return c.currentState.Id == stateDone.Id
356-
}
357-
358349
func isConnectionReset(err error) bool {
359350
return strings.Contains(err.Error(), "connection reset") ||
360351
strings.Contains(err.Error(), "broken pipe")

0 commit comments

Comments
 (0)