diff --git a/connection.go b/connection.go index 387a4ba5..20fe1ace 100644 --- a/connection.go +++ b/connection.go @@ -29,7 +29,10 @@ import ( "io" "log/slog" "net" + "os" + "strings" "sync" + "syscall" "time" "github.com/blinklabs-io/gouroboros/connection" @@ -250,6 +253,100 @@ func (c *Connection) shutdown() { close(c.errorChan) } +// isConnectionReset checks if an error is a connection reset error using proper error type checking +func (c *Connection) isConnectionReset(err error) bool { + if errors.Is(err, io.EOF) { + return true + } + + // Check for connection reset errors using proper error type checking + var opErr *net.OpError + if errors.As(err, &opErr) { + if syscallErr, ok := opErr.Err.(*os.SyscallError); ok { + if errno, ok := syscallErr.Err.(syscall.Errno); ok { + // Check for connection reset (ECONNRESET) or broken pipe (EPIPE) + return errno == syscall.ECONNRESET || errno == syscall.EPIPE + } + } + // Also check for string-based errors as fallback for edge cases + errStr := opErr.Err.Error() + return strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "broken pipe") + } + + return false +} + +// checkProtocolsDone checks if the protocols are explicitly stopped by the client - treat as normal connection closure +func (c *Connection) checkProtocolsDone() bool { + // Check chain-sync protocol + if c.chainSync != nil { + if (c.chainSync.Client != nil && !c.chainSync.Client.IsDone()) || + (c.chainSync.Server != nil && !c.chainSync.Server.IsDone()) { + return false + } + } + + // Check block-fetch protocol + if c.blockFetch != nil { + if (c.blockFetch.Client != nil && !c.blockFetch.Client.IsDone()) || + (c.blockFetch.Server != nil && !c.blockFetch.Server.IsDone()) { + return false + } + } + + // Check tx-submission protocol + if c.txSubmission != nil { + if (c.txSubmission.Client != nil && !c.txSubmission.Client.IsDone()) || + (c.txSubmission.Server != nil && !c.txSubmission.Server.IsDone()) { + return false + } + } + + // Check local-state-query protocol + if c.localStateQuery != nil { + if (c.localStateQuery.Client != nil && !c.localStateQuery.Client.IsDone()) || + (c.localStateQuery.Server != nil && !c.localStateQuery.Server.IsDone()) { + return false + } + } + + // Check local-tx-monitor protocol + if c.localTxMonitor != nil { + if (c.localTxMonitor.Client != nil && !c.localTxMonitor.Client.IsDone()) || + (c.localTxMonitor.Server != nil && !c.localTxMonitor.Server.IsDone()) { + return false + } + } + + // Check local-tx-submission protocol + if c.localTxSubmission != nil { + if (c.localTxSubmission.Client != nil && !c.localTxSubmission.Client.IsDone()) || + (c.localTxSubmission.Server != nil && !c.localTxSubmission.Server.IsDone()) { + return false + } + } + + return true +} + +// handleConnectionError handles connection-level errors centrally +func (c *Connection) handleConnectionError(err error) error { + if err == nil { + return nil + } + + if c.checkProtocolsDone() { + return nil + } + + if errors.Is(err, io.EOF) || c.isConnectionReset(err) { + return err + } + + return err +} + // setupConnection establishes the muxer, configures and starts the handshake process, and initializes // the appropriate mini-protocols func (c *Connection) setupConnection() error { @@ -285,16 +382,20 @@ func (c *Connection) setupConnection() error { if !ok { return } - var connErr *muxer.ConnectionClosedError - if errors.As(err, &connErr) { - // Pass through ConnectionClosedError from muxer - c.errorChan <- err - } else { - // Wrap error message to denote it comes from the muxer - c.errorChan <- fmt.Errorf("muxer error: %w", err) + + // Use centralized connection error handling + if handledErr := c.handleConnectionError(err); handledErr != nil { + var connErr *muxer.ConnectionClosedError + if errors.As(handledErr, &connErr) { + // Pass through ConnectionClosedError from muxer + c.errorChan <- handledErr + } else { + // Wrap error message to denote it comes from the muxer + c.errorChan <- fmt.Errorf("muxer error: %w", handledErr) + } + // Close connection on muxer errors + c.Close() } - // Close connection on muxer errors - c.Close() } }() protoOptions := protocol.ProtocolOptions{ diff --git a/connection_test.go b/connection_test.go index f904b6ad..2bd2d840 100644 --- a/connection_test.go +++ b/connection_test.go @@ -15,70 +15,266 @@ package ouroboros_test import ( - "fmt" "testing" "time" ouroboros "github.com/blinklabs-io/gouroboros" - "github.com/blinklabs-io/ouroboros-mock" + "github.com/blinklabs-io/gouroboros/protocol/chainsync" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" "go.uber.org/goleak" ) -// Ensure that we don't panic when closing the Connection object after a failed Dial() call -func TestDialFailClose(t *testing.T) { +// TestErrorHandlingWithActiveProtocols tests that connection errors are propagated +// when protocols are active, and ignored when protocols are stopped +func TestErrorHandlingWithActiveProtocols(t *testing.T) { defer goleak.VerifyNone(t) - oConn, err := ouroboros.New() - if err != nil { - t.Fatalf("unexpected error when creating Connection object: %s", err) - } - err = oConn.Dial("unix", "/path/does/not/exist") - if err == nil { - t.Fatalf("did not get expected failure on Dial()") - } - // Close connection - oConn.Close() + + t.Run("ErrorsPropagatedWhenProtocolsActive", func(t *testing.T) { + // Create a mock connection that will complete handshake + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Wait for handshake to complete by checking if protocols are initialized + var chainSyncProtocol *chainsync.ChainSync + for i := 0; i < 100; i++ { + chainSyncProtocol = oConn.ChainSync() + if chainSyncProtocol != nil && chainSyncProtocol.Client != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + if chainSyncProtocol == nil || chainSyncProtocol.Client == nil { + oConn.Close() + t.Fatal("chain sync protocol not initialized") + } + + // Start the chain sync protocol to make it active + chainSyncProtocol.Client.Start() + + // Wait a bit for protocol to start + time.Sleep(100 * time.Millisecond) + + // Close the mock connection to generate a connection error + mockConn.Close() + + // We should receive a connection error since protocols are active + select { + case err := <-oConn.ErrorChan(): + if err == nil { + t.Fatal("expected connection error, got nil") + } + t.Logf("Received connection error (expected with active protocols): %s", err) + case <-time.After(2 * time.Second): + t.Error("timed out waiting for connection error - error should be propagated when protocols are active") + } + + oConn.Close() + }) + + t.Run("ErrorsIgnoredWhenProtocolsStopped", func(t *testing.T) { + // Create a mock connection + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // Wait for handshake to complete + var chainSyncProtocol *chainsync.ChainSync + for i := 0; i < 100; i++ { + chainSyncProtocol = oConn.ChainSync() + if chainSyncProtocol != nil && chainSyncProtocol.Client != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + if chainSyncProtocol == nil || chainSyncProtocol.Client == nil { + oConn.Close() + t.Fatal("chain sync protocol not initialized") + } + + // Start and then immediately stop the protocol + chainSyncProtocol.Client.Start() + time.Sleep(50 * time.Millisecond) + + // Stop the protocol explicitly + if err := chainSyncProtocol.Client.Stop(); err != nil { + t.Fatalf("failed to stop chain sync: %s", err) + } + + // Wait for protocol to be done + select { + case <-chainSyncProtocol.Client.DoneChan(): + // Protocol is stopped + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for protocol to stop") + } + + // Now close the mock connection to generate an error + mockConn.Close() + select { + case err := <-oConn.ErrorChan(): + t.Logf("Received error during shutdown: %s", err) + case <-time.After(500 * time.Millisecond): + t.Log("No connection error received (expected when protocols are stopped)") + } + + oConn.Close() + }) } -func TestDoubleClose(t *testing.T) { +// TestErrorHandlingWithMultipleProtocols tests error handling with multiple active protocols +func TestErrorHandlingWithMultipleProtocols(t *testing.T) { defer goleak.VerifyNone(t) + mockConn := ouroboros_mock.NewConnection( ouroboros_mock.ProtocolRoleClient, []ouroboros_mock.ConversationEntry{ ouroboros_mock.ConversationEntryHandshakeRequestGeneric, - ouroboros_mock.ConversationEntryHandshakeNtCResponse, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, }, ) + oConn, err := ouroboros.New( ouroboros.WithConnection(mockConn), ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), ) if err != nil { t.Fatalf("unexpected error when creating Connection object: %s", err) } - // Async error handler - go func() { - err, ok := <-oConn.ErrorChan() - if !ok { - return - } - // We can't call t.Fatalf() from a different Goroutine, so we panic instead - panic(fmt.Sprintf("unexpected Ouroboros connection error: %s", err)) - }() - // Close connection - if err := oConn.Close(); err != nil { - t.Fatalf("unexpected error when closing Connection object: %s", err) + + // Wait for handshake to complete + time.Sleep(100 * time.Millisecond) + + // Start multiple protocols + chainSync := oConn.ChainSync() + blockFetch := oConn.BlockFetch() + txSubmission := oConn.TxSubmission() + + if chainSync != nil && chainSync.Client != nil { + chainSync.Client.Start() + } + if blockFetch != nil && blockFetch.Client != nil { + blockFetch.Client.Start() + } + if txSubmission != nil && txSubmission.Client != nil { + txSubmission.Client.Start() + } + + // Wait for protocols to start + time.Sleep(100 * time.Millisecond) + + // Close connection to generate error + mockConn.Close() + + // Should receive error since protocols are active + select { + case err := <-oConn.ErrorChan(): + if err == nil { + t.Fatal("expected connection error, got nil") + } + t.Logf("Received connection error with multiple active protocols: %s", err) + case <-time.After(2 * time.Second): + t.Error("timed out waiting for connection error") } - // Close connection again - if err := oConn.Close(); err != nil { - t.Fatalf( - "unexpected error when closing Connection object again: %s", - err, + + oConn.Close() +} + +// TestBasicErrorHandling tests basic error handling scenarios +func TestBasicErrorHandling(t *testing.T) { + defer goleak.VerifyNone(t) + + t.Run("DialFailure", func(t *testing.T) { + oConn, err := ouroboros.New( + ouroboros.WithNetworkMagic(764824073), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + err = oConn.Dial("tcp", "invalid-hostname:9999") + if err == nil { + t.Fatal("expected dial error, got nil") + } + + oConn.Close() + }) + + t.Run("DoubleClose", func(t *testing.T) { + oConn, err := ouroboros.New( + ouroboros.WithNetworkMagic(764824073), ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + // First close + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error on first close: %s", err) + } + + // Second close should also work + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error on second close: %s", err) + } + }) +} + +// TestErrorChannelBehavior tests basic error channel behavior +func TestErrorChannelBehavior(t *testing.T) { + defer goleak.VerifyNone(t) + + oConn, err := ouroboros.New( + ouroboros.WithNetworkMagic(764824073), + ) + if err != nil { + t.Fatalf("unexpected error when creating Connection object: %s", err) + } + + errorChan := oConn.ErrorChan() + if errorChan == nil { + t.Fatal("error channel should not be nil") } - // Wait for connection shutdown + select { - case <-oConn.ErrorChan(): - case <-time.After(10 * time.Second): - t.Errorf("did not shutdown within timeout") + case err, ok := <-errorChan: + if ok { + t.Logf("Error channel contained: %s", err) + } else { + t.Error("Error channel should not be closed initially") + } + default: + // Expected - channel is empty but open } + + oConn.Close() } diff --git a/protocol/blockfetch/blockfetch_test.go b/protocol/blockfetch/blockfetch_test.go new file mode 100644 index 00000000..2a5cfbba --- /dev/null +++ b/protocol/blockfetch/blockfetch_test.go @@ -0,0 +1,166 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package blockfetch + +import ( + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/ledger" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/common" + "github.com/stretchr/testify/assert" +) + +// testAddr implements net.Addr for testing +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +// testConn implements net.Conn for testing with buffered writes +type testConn struct { + writeChan chan []byte + closed bool + closeChan chan struct{} +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Write(b []byte) (n int, err error) { + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} +func (c *testConn) Close() error { + if !c.closed { + close(c.closeChan) + c.closed = true + } + return nil +} +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + go mux.Start() + go func() { + <-conn.(*testConn).closeChan + mux.Stop() + }() + + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + } +} + +func TestNewBlockFetch(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + bf := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, bf.Client) + assert.NotNil(t, bf.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 5*time.Second, cfg.BatchStartTimeout) + assert.Equal(t, 60*time.Second, cfg.BlockTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithBatchStartTimeout(10*time.Second), + WithBlockTimeout(30*time.Second), + WithRecvQueueSize(100), + ) + assert.Equal(t, 10*time.Second, cfg.BatchStartTimeout) + assert.Equal(t, 30*time.Second, cfg.BlockTimeout) + assert.Equal(t, 100, cfg.RecvQueueSize) + }) +} + +func TestCallbackRegistration(t *testing.T) { + conn := newTestConn() + defer conn.Close() + + t.Run("Block callback registration", func(t *testing.T) { + blockFunc := func(ctx CallbackContext, slot uint, block ledger.Block) error { + return nil + } + cfg := NewConfig(WithBlockFunc(blockFunc)) + client := NewClient(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, client) + assert.NotNil(t, cfg.BlockFunc) + }) + + t.Run("RequestRange callback registration", func(t *testing.T) { + requestRangeFunc := func(ctx CallbackContext, start, end common.Point) error { + return nil + } + cfg := NewConfig(WithRequestRangeFunc(requestRangeFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestRangeFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + client := NewClient(getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + client.Start() + defer client.Stop() + + // Send a done message + err := client.SendMessage(NewMsgClientDone()) + assert.NoError(t, err) + + // Verify message was written to connection + select { + case <-conn.writeChan: + // Message was sent successfully + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message send") + } + }) +} diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index 23f1091a..be33fdf0 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -109,8 +109,12 @@ func (c *Client) Stop() error { "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) - msg := NewMsgClientDone() - err = c.SendMessage(msg) + if !c.IsDone() { + msg := NewMsgClientDone() + if err = c.SendMessage(msg); err != nil { + return + } + } }) return err } diff --git a/protocol/chainsync/chainsync_test.go b/protocol/chainsync/chainsync_test.go new file mode 100644 index 00000000..d660717d --- /dev/null +++ b/protocol/chainsync/chainsync_test.go @@ -0,0 +1,167 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package chainsync + +import ( + "io" + "log/slog" + "net" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/common" + "github.com/stretchr/testify/assert" +) + +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +type testConn struct { + writeChan chan []byte + closed bool + closeChan chan struct{} +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Write(b []byte) (n int, err error) { + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} +func (c *testConn) Close() error { + if !c.closed { + close(c.closeChan) + c.closed = true + } + return nil +} +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + go mux.Start() + go func() { + <-conn.(*testConn).closeChan + mux.Stop() + }() + + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + Mode: protocol.ProtocolModeNodeToClient, + } +} + +func TestNewChainSync(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + cs := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, cs.Client) + assert.NotNil(t, cs.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 5*time.Second, cfg.IntersectTimeout) + assert.Equal(t, 300*time.Second, cfg.BlockTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithIntersectTimeout(10*time.Second), + WithBlockTimeout(30*time.Second), + WithPipelineLimit(10), + WithRecvQueueSize(100), + ) + assert.Equal(t, 10*time.Second, cfg.IntersectTimeout) + assert.Equal(t, 30*time.Second, cfg.BlockTimeout) + assert.Equal(t, 10, cfg.PipelineLimit) + assert.Equal(t, 100, cfg.RecvQueueSize) + }) +} + +func TestCallbackRegistration(t *testing.T) { + conn := newTestConn() + defer conn.Close() + + t.Run("RollForward callback registration", func(t *testing.T) { + rollForwardFunc := func(ctx CallbackContext, blockType uint, blockData any, tip Tip) error { + return nil + } + cfg := NewConfig(WithRollForwardFunc(rollForwardFunc)) + client := NewClient(nil, getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, client) + assert.NotNil(t, cfg.RollForwardFunc) + }) + + t.Run("FindIntersect callback registration", func(t *testing.T) { + findIntersectFunc := func(ctx CallbackContext, points []common.Point) (common.Point, Tip, error) { + return common.Point{}, Tip{}, nil + } + cfg := NewConfig(WithFindIntersectFunc(findIntersectFunc)) + server := NewServer(nil, getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.FindIntersectFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + client := NewClient(nil, getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + client.Start() + defer client.Stop() + + // Send a done message + err := client.SendMessage(NewMsgDone()) + assert.NoError(t, err) + + // Verify message was written to connection + select { + case <-conn.writeChan: + // Message was sent successfully + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for message send") + } + }) +} diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 6a15fab2..27a8c2aa 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -147,9 +147,11 @@ func (c *Client) Stop() error { ) c.busyMutex.Lock() defer c.busyMutex.Unlock() - msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return + if !c.IsDone() { + msg := NewMsgDone() + if err = c.SendMessage(msg); err != nil { + return + } } }) return err diff --git a/protocol/protocol.go b/protocol/protocol.go index 4f1d524c..1f3ace54 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -38,6 +38,7 @@ const DefaultRecvQueueSize = 50 // Protocol implements the base functionality of an Ouroboros mini-protocol type Protocol struct { config ProtocolConfig + currentState State doneChan chan struct{} muxerSendChan chan *muxer.Segment muxerRecvChan chan *muxer.Segment @@ -51,6 +52,7 @@ type Protocol struct { stateTransitionChan chan<- protocolStateTransition onceStart sync.Once onceStop sync.Once + stateMutex sync.RWMutex } // ProtocolConfig provides the configuration for Protocol @@ -102,8 +104,9 @@ type ProtocolOptions struct { } type protocolStateTransition struct { - msg Message - errorChan chan<- error + msg Message + errorChan chan<- error + stateRespChan chan<- State } // MessageHandlerFunc represents a function that handles an incoming message @@ -126,6 +129,36 @@ func New(config ProtocolConfig) *Protocol { return p } +// CurrentState returns the current protocol state +func (p *Protocol) CurrentState() State { + p.stateMutex.RLock() + defer p.stateMutex.RUnlock() + return p.currentState +} + +// IsDone checks if the protocol is in a done/completed state +func (p *Protocol) IsDone() bool { + currentState := p.CurrentState() + // return true if current state has AgencyNone + if entry, exists := p.config.StateMap[currentState]; exists { + if entry.Agency == AgencyNone { + return true + } + } + // return true if current state is the initial state + return currentState == p.config.InitialState +} + +// GetDoneState returns the done state from the state map +func (s StateMap) GetDoneState() State { + for state, entry := range s { + if entry.Agency == AgencyNone { + return state + } + } + return State{} +} + // Start initializes the mini-protocol func (p *Protocol) Start() { p.onceStart.Do(func() { @@ -446,7 +479,6 @@ func (p *Protocol) recvLoop() { } func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { - var currentState State var transitionTimer *time.Timer setState := func(s State) { @@ -456,11 +488,18 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { } transitionTimer = nil - // Set the new state - currentState = s + // Set the new state with proper locking + p.stateMutex.Lock() + p.currentState = s + p.stateMutex.Unlock() // Mark protocol as ready to send/receive based on role and agency of the new state - switch p.config.StateMap[currentState].Agency { + stateEntry, exists := p.config.StateMap[s] + if !exists { + return + } + + switch stateEntry.Agency { case AgencyNone: return case AgencyClient: @@ -496,12 +535,11 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { } // Set timeout for state transition - if p.config.StateMap[currentState].Timeout > 0 { - transitionTimer = time.NewTimer( - p.config.StateMap[currentState].Timeout, - ) + if stateEntry.Timeout > 0 { + transitionTimer = time.NewTimer(stateEntry.Timeout) } } + getTimerChan := func() <-chan time.Time { if transitionTimer == nil { return nil @@ -509,11 +547,26 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { return transitionTimer.C } + // Initialize current state setState(p.config.InitialState) for { select { case t := <-ch: + if t.msg == nil && t.stateRespChan != nil { + // Handle state request - use the field instead of local variable + p.stateMutex.RLock() + currentState := p.currentState + p.stateMutex.RUnlock() + t.stateRespChan <- currentState + continue + } + + // Get current state for transition logic + p.stateMutex.RLock() + currentState := p.currentState + p.stateMutex.RUnlock() + nextState, err := p.nextState(currentState, t.msg) if err != nil { t.errorChan <- fmt.Errorf( @@ -538,7 +591,7 @@ func (p *Protocol) stateLoop(ch <-chan protocolStateTransition) { fmt.Errorf( "%s: timeout waiting on transition from protocol state %s", p.config.Name, - currentState, + p.CurrentState(), ), ) @@ -574,7 +627,7 @@ func (p *Protocol) nextState(currentState State, msg Message) (State, error) { func (p *Protocol) transitionState(msg Message) error { errorChan := make(chan error, 1) - p.stateTransitionChan <- protocolStateTransition{msg, errorChan} + p.stateTransitionChan <- protocolStateTransition{msg, errorChan, nil} return <-errorChan } diff --git a/protocol/txsubmission/client.go b/protocol/txsubmission/client.go index d96d7641..a956329b 100644 --- a/protocol/txsubmission/client.go +++ b/protocol/txsubmission/client.go @@ -86,6 +86,7 @@ func (c *Client) Init() { func (c *Client) messageHandler(msg protocol.Message) error { c.Protocol.Logger(). Debug(fmt.Sprintf("%s: client message for %+v", ProtocolName, c.callbackContext.ConnectionId.RemoteAddr)) + var err error switch msg.Type() { case MessageTypeRequestTxIds: diff --git a/protocol/txsubmission/txsubmission_test.go b/protocol/txsubmission/txsubmission_test.go new file mode 100644 index 00000000..eb7e5e60 --- /dev/null +++ b/protocol/txsubmission/txsubmission_test.go @@ -0,0 +1,181 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package txsubmission + +import ( + "io" + "log/slog" + "net" + "sync" + "testing" + "time" + + "github.com/blinklabs-io/gouroboros/connection" + "github.com/blinklabs-io/gouroboros/muxer" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testAddr struct{} + +func (a testAddr) Network() string { return "test" } +func (a testAddr) String() string { return "test-addr" } + +type testConn struct { + writeChan chan []byte + closed bool + closeChan chan struct{} + closeOnce sync.Once + mu sync.Mutex +} + +func newTestConn() *testConn { + return &testConn{ + writeChan: make(chan []byte, 100), + closeChan: make(chan struct{}), + } +} + +func (c *testConn) Read(b []byte) (n int, err error) { return 0, nil } +func (c *testConn) Close() error { + c.closeOnce.Do(func() { + c.mu.Lock() + defer c.mu.Unlock() + close(c.closeChan) + c.closed = true + }) + return nil +} +func (c *testConn) LocalAddr() net.Addr { return testAddr{} } +func (c *testConn) RemoteAddr() net.Addr { return testAddr{} } +func (c *testConn) SetDeadline(t time.Time) error { return nil } +func (c *testConn) SetReadDeadline(t time.Time) error { return nil } +func (c *testConn) SetWriteDeadline(t time.Time) error { return nil } + +func (c *testConn) Write(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, io.EOF + } + select { + case c.writeChan <- b: + return len(b), nil + case <-c.closeChan: + return 0, io.EOF + } +} + +func getTestProtocolOptions(conn net.Conn) protocol.ProtocolOptions { + mux := muxer.New(conn) + return protocol.ProtocolOptions{ + ConnectionId: connection.ConnectionId{ + LocalAddr: testAddr{}, + RemoteAddr: testAddr{}, + }, + Muxer: mux, + Logger: slog.Default(), + } +} + +func TestNewTxSubmission(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + ts := New(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, ts.Client) + assert.NotNil(t, ts.Server) +} + +func TestConfigOptions(t *testing.T) { + t.Run("Default config", func(t *testing.T) { + cfg := NewConfig() + assert.Equal(t, 300*time.Second, cfg.IdleTimeout) + }) + + t.Run("Custom config", func(t *testing.T) { + cfg := NewConfig( + WithIdleTimeout(60*time.Second), + WithRequestTxIdsFunc(func(ctx CallbackContext, blocking bool, ack, req uint16) ([]TxIdAndSize, error) { + return nil, nil + }), + WithRequestTxsFunc(func(ctx CallbackContext, txIds []TxId) ([]TxBody, error) { + return nil, nil + }), + ) + assert.Equal(t, 60*time.Second, cfg.IdleTimeout) + assert.NotNil(t, cfg.RequestTxIdsFunc) + assert.NotNil(t, cfg.RequestTxsFunc) + }) +} +func TestCallbackRegistration(t *testing.T) { + conn := newTestConn() + defer conn.Close() + + t.Run("RequestTxIds callback registration", func(t *testing.T) { + requestTxIdsFunc := func(ctx CallbackContext, blocking bool, ack, req uint16) ([]TxIdAndSize, error) { + return nil, nil + } + cfg := NewConfig(WithRequestTxIdsFunc(requestTxIdsFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestTxIdsFunc) + }) + + t.Run("RequestTxs callback registration", func(t *testing.T) { + requestTxsFunc := func(ctx CallbackContext, txIds []TxId) ([]TxBody, error) { + return nil, nil + } + cfg := NewConfig(WithRequestTxsFunc(requestTxsFunc)) + server := NewServer(getTestProtocolOptions(conn), &cfg) + assert.NotNil(t, server) + assert.NotNil(t, cfg.RequestTxsFunc) + }) +} + +func TestClientMessageSending(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + client := NewClient(getTestProtocolOptions(conn), &cfg) + + t.Run("Client can send messages", func(t *testing.T) { + client.Start() + defer client.Stop() + + err := client.SendMessage(NewMsgInit()) + require.NoError(t, err) + + select { + case msg := <-conn.writeChan: + assert.NotEmpty(t, msg, "expected message to be written") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for message send") + } + }) +} + +func TestServerMessageHandling(t *testing.T) { + conn := newTestConn() + defer conn.Close() + cfg := NewConfig() + server := NewServer(getTestProtocolOptions(conn), &cfg) + + t.Run("Server can be started", func(t *testing.T) { + server.Start() + defer server.Stop() + assert.NotNil(t, server) + }) +}