Skip to content
30 changes: 30 additions & 0 deletions protocol/blockfetch/blockfetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
package blockfetch

import (
"errors"
"io"
"strings"
"time"

"github.com/blinklabs-io/gouroboros/connection"
Expand Down Expand Up @@ -118,6 +121,33 @@ func New(protoOptions protocol.ProtocolOptions, cfg *Config) *BlockFetch {
return b
}

func (b *BlockFetch) IsDone() bool {
if b.Client != nil && b.Client.IsDone() {
return true
}
if b.Server != nil && b.Server.IsDone() {
return true
}
return false
}

func (b *BlockFetch) HandleConnectionError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, io.EOF) || isConnectionReset(err) {
if b.IsDone() {
return nil
}
}
return err
}

func isConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset") ||
strings.Contains(err.Error(), "broken pipe")
}

type BlockFetchOptionFunc func(*Config)

func NewConfig(options ...BlockFetchOptionFunc) Config {
Expand Down
171 changes: 171 additions & 0 deletions protocol/blockfetch/blockfetch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright 2025 Blink Labs Software
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package blockfetch

import (
"errors"
"io"
"log/slog"
"net"
"testing"
"time"

"github.com/blinklabs-io/gouroboros/connection"
"github.com/blinklabs-io/gouroboros/ledger"
"github.com/blinklabs-io/gouroboros/muxer"
"github.com/blinklabs-io/gouroboros/protocol"
"github.com/blinklabs-io/gouroboros/protocol/common"
"github.com/stretchr/testify/assert"
)

// testAddr implements net.Addr for testing
type testAddr struct{}

func (a testAddr) Network() string { return "test" }
func (a testAddr) String() string { return "test-addr" }

// testConn implements net.Conn for testing with buffered writes
type testConn struct {
writeChan chan []byte
}

func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil }
func (c *testConn) Write(b []byte) (n int, err error) {
c.writeChan <- b
return len(b), nil
}
func (c *testConn) Close() error { return nil }
func (c *testConn) LocalAddr() net.Addr { return testAddr{} }
func (c *testConn) RemoteAddr() net.Addr { return testAddr{} }
func (c *testConn) SetDeadline(t time.Time) error { return nil }
func (c *testConn) SetReadDeadline(t time.Time) error { return nil }
func (c *testConn) SetWriteDeadline(t time.Time) error { return nil }

func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions {
mux := muxer.New(conn)
return protocol.ProtocolOptions{
ConnectionId: connection.ConnectionId{
LocalAddr: testAddr{},
RemoteAddr: testAddr{},
},
Muxer: mux,
Logger: slog.Default(),
}
}

func TestNewBlockFetch(t *testing.T) {
conn := &testConn{writeChan: make(chan []byte, 1)}
cfg := NewConfig()
bf := New(getTestProtocolOptions(conn), &cfg)
assert.NotNil(t, bf.Client)
assert.NotNil(t, bf.Server)
}

func TestConfigOptions(t *testing.T) {
t.Run("Default config", func(t *testing.T) {
cfg := NewConfig()
assert.Equal(t, 5*time.Second, cfg.BatchStartTimeout)
assert.Equal(t, 60*time.Second, cfg.BlockTimeout)
})

t.Run("Custom config", func(t *testing.T) {
cfg := NewConfig(
WithBatchStartTimeout(10*time.Second),
WithBlockTimeout(30*time.Second),
WithRecvQueueSize(100),
)
assert.Equal(t, 10*time.Second, cfg.BatchStartTimeout)
assert.Equal(t, 30*time.Second, cfg.BlockTimeout)
assert.Equal(t, 100, cfg.RecvQueueSize)
})
}

func TestConnectionErrorHandling(t *testing.T) {
conn := &testConn{writeChan: make(chan []byte, 1)}
cfg := NewConfig()
bf := New(getTestProtocolOptions(conn), &cfg)

t.Run("Non-EOF error when not done", func(t *testing.T) {
err := bf.HandleConnectionError(errors.New("test error"))
assert.Error(t, err)
})

t.Run("EOF error when not done", func(t *testing.T) {
err := bf.HandleConnectionError(io.EOF)
assert.Error(t, err)
})

t.Run("Connection reset error", func(t *testing.T) {
err := bf.HandleConnectionError(errors.New("connection reset by peer"))
assert.Error(t, err)
})
}

func TestCallbackRegistration(t *testing.T) {
conn := &testConn{writeChan: make(chan []byte, 1)}

t.Run("Block callback registration", func(t *testing.T) {
blockFunc := func(ctx CallbackContext, slot uint, block ledger.Block) error {
return nil
}
cfg := NewConfig(WithBlockFunc(blockFunc))
client := NewClient(getTestProtocolOptions(conn), &cfg)
assert.NotNil(t, client)
assert.NotNil(t, cfg.BlockFunc)
})

t.Run("RequestRange callback registration", func(t *testing.T) {
requestRangeFunc := func(ctx CallbackContext, start, end common.Point) error {
return nil
}
cfg := NewConfig(WithRequestRangeFunc(requestRangeFunc))
server := NewServer(getTestProtocolOptions(conn), &cfg)
assert.NotNil(t, server)
assert.NotNil(t, cfg.RequestRangeFunc)
})
}

func TestClientMessageSending(t *testing.T) {
conn := &testConn{writeChan: make(chan []byte, 1)}
cfg := NewConfig()
client := NewClient(getTestProtocolOptions(conn), &cfg)

t.Run("Client can send messages", func(t *testing.T) {
// Start the client protocol
client.Start()

// Send a done message
err := client.Protocol.SendMessage(NewMsgClientDone())
assert.NoError(t, err)

// Verify message was written to connection
select {
case <-conn.writeChan:
// Message was sent successfully
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout waiting for message send")
}
})
}

func TestServerMessageHandling(t *testing.T) {
conn := &testConn{writeChan: make(chan []byte, 1)}
cfg := NewConfig()
server := NewServer(getTestProtocolOptions(conn), &cfg)

t.Run("Server can be started", func(t *testing.T) {
server.Start()

})
}
23 changes: 22 additions & 1 deletion protocol/blockfetch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Client struct {
blockUseCallback bool
onceStart sync.Once
onceStop sync.Once
currentState protocol.State
stateMutex sync.Mutex
}

func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
Expand All @@ -46,6 +48,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
config: cfg,
blockChan: make(chan ledger.Block),
startBatchResultChan: make(chan error),
currentState: StateIdle,
}
c.callbackContext = CallbackContext{
Client: c,
Expand Down Expand Up @@ -82,6 +85,18 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
return c
}

func (c *Client) IsDone() bool {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
return c.currentState.Id == StateDone.Id
}

func (c *Client) setState(newState protocol.State) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
c.currentState = newState
}

func (c *Client) Start() {
c.onceStart.Do(func() {
c.Protocol.Logger().
Expand Down Expand Up @@ -110,7 +125,11 @@ func (c *Client) Stop() error {
"connection_id", c.callbackContext.ConnectionId.String(),
)
msg := NewMsgClientDone()
err = c.SendMessage(msg)
if sendErr := c.SendMessage(msg); sendErr != nil {
err = sendErr
return
}
c.setState(StateDone)
})
return err
}
Expand Down Expand Up @@ -196,6 +215,8 @@ func (c *Client) messageHandler(msg protocol.Message) error {
err = c.handleBlock(msg)
case MessageTypeBatchDone:
err = c.handleBatchDone()
case MessageTypeClientDone:
c.setState(StateDone)
default:
err = fmt.Errorf(
"%s: received unexpected message type %d",
Expand Down
17 changes: 17 additions & 0 deletions protocol/blockfetch/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package blockfetch
import (
"errors"
"fmt"
"sync"

"github.com/blinklabs-io/gouroboros/cbor"
"github.com/blinklabs-io/gouroboros/protocol"
Expand All @@ -27,13 +28,16 @@ type Server struct {
config *Config
callbackContext CallbackContext
protoOptions protocol.ProtocolOptions
currentState protocol.State
stateMutex sync.Mutex
}

func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
s := &Server{
config: cfg,
// Save this for re-use later
protoOptions: protoOptions,
currentState: StateIdle,
}
s.callbackContext = CallbackContext{
Server: s,
Expand All @@ -43,6 +47,18 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server {
return s
}

func (s *Server) IsDone() bool {
s.stateMutex.Lock()
defer s.stateMutex.Unlock()
return s.currentState.Id == StateDone.Id
}

func (s *Server) setState(newState protocol.State) {
s.stateMutex.Lock()
defer s.stateMutex.Unlock()
s.currentState = newState
}

func (s *Server) initProtocol() {
protoConfig := protocol.ProtocolConfig{
Name: ProtocolName,
Expand Down Expand Up @@ -126,6 +142,7 @@ func (s *Server) messageHandler(msg protocol.Message) error {
case MessageTypeRequestRange:
err = s.handleRequestRange(msg)
case MessageTypeClientDone:
s.setState(StateDone)
err = s.handleClientDone()
default:
err = fmt.Errorf(
Expand Down
34 changes: 32 additions & 2 deletions protocol/chainsync/chainsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
package chainsync

import (
"errors"
"io"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -192,8 +195,10 @@ var PipelineIsNotEmpty = func(context any, msg protocol.Message) bool {

// ChainSync is a wrapper object that holds the client and server instances
type ChainSync struct {
Client *Client
Server *Server
Client *Client
Server *Server
stateMutex sync.Mutex
currentState protocol.State
}

// Config is used to configure the ChainSync protocol instance
Expand Down Expand Up @@ -329,3 +334,28 @@ func WithRecvQueueSize(size int) ChainSyncOptionFunc {
c.RecvQueueSize = size
}
}

// HandleConnectionError handles connection errors and determines if they should be ignored
func (c *ChainSync) HandleConnectionError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, io.EOF) || isConnectionReset(err) {
if c.IsDone() {
return nil
}
}
return err
}

// IsDone returns true if the protocol is in the Done state
func (c *ChainSync) IsDone() bool {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
return c.currentState.Id == stateDone.Id
}

func isConnectionReset(err error) bool {
return strings.Contains(err.Error(), "connection reset") ||
strings.Contains(err.Error(), "broken pipe")
}
Loading
Loading