diff --git a/core/network/conn.go b/core/network/conn.go
index f82da29e61..68cb8a58ae 100644
--- a/core/network/conn.go
+++ b/core/network/conn.go
@@ -101,6 +101,11 @@ type Conn interface {
 	//
 	// You very likely do not need to use this method.
 	As(target any) bool
+
+	// Context returns a context that is cancelled when the connection is closed.
+	// This can be used to clean up resources associated with the connection
+	// and to signal early cancellation of work that depends on the connection.
+	Context() context.Context
 }
 
 // ConnectionState holds information about the connection.
diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go
index 216e0d9b18..92310a8674 100644
--- a/p2p/net/mock/mock_conn.go
+++ b/p2p/net/mock/mock_conn.go
@@ -6,6 +6,7 @@ import (
 	"strconv"
 	"sync"
 	"sync/atomic"
+	"time"
 
 	ic "github.com/libp2p/go-libp2p/core/crypto"
 	"github.com/libp2p/go-libp2p/core/network"
@@ -193,3 +194,22 @@ func (c *conn) Scope() network.ConnScope {
 func (c *conn) CloseWithError(_ network.ConnErrorCode) error {
 	return c.Close()
 }
+
+// Context returns a context that is cancelled when the connection is closed.
+func (c *conn) Context() context.Context {
+	// For mock connections, we return a context that is cancelled when the connection is closed.
+	// This is a simplified implementation for testing purposes.
+	ctx, cancel := context.WithCancel(context.Background())
+	go func() {
+		// Poll for closure; the goroutine exits once the connection is closed.
+		ticker := time.NewTicker(10 * time.Millisecond)
+		defer ticker.Stop()
+		for range ticker.C {
+			if c.IsClosed() {
+				cancel()
+				return
+			}
+		}
+	}()
+	return ctx
+}
diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go
index b97fe669a9..42cc0efffd 100644
--- a/p2p/net/swarm/swarm.go
+++ b/p2p/net/swarm/swarm.go
@@ -374,11 +374,14 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
 	isLimited := stat.Limited
 
 	// Wrap and register the connection.
+ ctx, cancel := context.WithCancel(context.Background()) c := &Conn{ - conn: tc, - swarm: s, - stat: stat, - id: s.nextConnID.Add(1), + conn: tc, + swarm: s, + stat: stat, + id: s.nextConnID.Add(1), + ctx: ctx, + cancel: cancel, } // we ONLY check upgraded connections here so we can send them a Disconnect message. diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 7bcda66b9f..acae81373c 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -38,6 +38,10 @@ type Conn struct { } stat network.ConnStats + + // Context and cancel function for connection lifecycle management + ctx context.Context + cancel context.CancelFunc } var _ network.Conn = &Conn{} @@ -78,6 +82,9 @@ func (c *Conn) CloseWithError(errCode network.ConnErrorCode) error { func (c *Conn) doClose(errCode network.ConnErrorCode) { c.swarm.removeConn(c) + // Cancel the context to signal that the connection is closed + c.cancel() + // Prevent new streams from opening. c.streams.Lock() streams := c.streams.m @@ -297,3 +304,8 @@ func (c *Conn) GetStreams() []network.Stream { func (c *Conn) Scope() network.ConnScope { return c.conn.Scope() } + +// Context returns a context that is cancelled when the connection is closed. 
+func (c *Conn) Context() context.Context { + return c.ctx +} diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index eca721c2b4..07fd200d8a 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -626,3 +626,251 @@ func TestAddCertHashes(t *testing.T) { } } } + +func TestConnContext(t *testing.T) { + // Test that the context is cancelled when the connection is closed + t.Run("ContextCancelledOnClose", func(t *testing.T) { + s1 := GenSwarm(t, OptDisableReuseport) + s2 := GenSwarm(t, OptDisableReuseport) + defer s1.Close() + defer s2.Close() + + // Connect the swarms + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL) + _, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + // Get the connection + conns := s1.ConnsToPeer(s2.LocalPeer()) + require.Len(t, conns, 1) + conn := conns[0] + + // Get the context + ctx := conn.Context() + require.NotNil(t, ctx) + + // Context should not be cancelled initially + select { + case <-ctx.Done(): + t.Fatal("context should not be cancelled initially") + default: + } + + // Close the connection + err = conn.Close() + require.NoError(t, err) + + // Context should be cancelled now + select { + case <-ctx.Done(): + // Expected + case <-time.After(time.Second): + t.Fatal("context should be cancelled after connection close") + } + + // Verify the context error + require.Error(t, ctx.Err()) + require.Equal(t, context.Canceled, ctx.Err()) + }) + + // Test that the context is cancelled when the connection is closed with error + t.Run("ContextCancelledOnCloseWithError", func(t *testing.T) { + s1 := GenSwarm(t, OptDisableReuseport) + s2 := GenSwarm(t, OptDisableReuseport) + defer s1.Close() + defer s2.Close() + + // Connect the swarms + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL) + _, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + // Get the connection 
+ conns := s1.ConnsToPeer(s2.LocalPeer()) + require.Len(t, conns, 1) + conn := conns[0] + + // Get the context + ctx := conn.Context() + require.NotNil(t, ctx) + + // Context should not be cancelled initially + select { + case <-ctx.Done(): + t.Fatal("context should not be cancelled initially") + default: + } + + // Close the connection with error + err = conn.CloseWithError(network.ConnShutdown) + require.NoError(t, err) + + // Context should be cancelled now + select { + case <-ctx.Done(): + // Expected + case <-time.After(time.Second): + t.Fatal("context should be cancelled after connection close with error") + } + + // Verify the context error + require.Error(t, ctx.Err()) + require.Equal(t, context.Canceled, ctx.Err()) + }) + + // Test that the context can be used with context.AfterFunc + t.Run("ContextAfterFunc", func(t *testing.T) { + s1 := GenSwarm(t, OptDisableReuseport) + s2 := GenSwarm(t, OptDisableReuseport) + defer s1.Close() + defer s2.Close() + + // Connect the swarms + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL) + _, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + // Get the connection + conns := s1.ConnsToPeer(s2.LocalPeer()) + require.Len(t, conns, 1) + conn := conns[0] + + // Get the context + ctx := conn.Context() + require.NotNil(t, ctx) + + // Use context.AfterFunc to clean up resources + var cleanupCalled bool + var cleanupMutex sync.Mutex + context.AfterFunc(ctx, func() { + cleanupMutex.Lock() + cleanupCalled = true + cleanupMutex.Unlock() + }) + + // Close the connection + err = conn.Close() + require.NoError(t, err) + + // Wait for the cleanup function to be called + require.Eventually(t, func() bool { + cleanupMutex.Lock() + defer cleanupMutex.Unlock() + return cleanupCalled + }, time.Second, 10*time.Millisecond, "cleanup function should be called") + + // Verify the context error + require.Error(t, ctx.Err()) + require.Equal(t, context.Canceled, ctx.Err()) + 
}) + + // Test that multiple contexts from the same connection are all cancelled + t.Run("MultipleContextsCancelled", func(t *testing.T) { + s1 := GenSwarm(t, OptDisableReuseport) + s2 := GenSwarm(t, OptDisableReuseport) + defer s1.Close() + defer s2.Close() + + // Connect the swarms + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL) + _, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + // Get the connection + conns := s1.ConnsToPeer(s2.LocalPeer()) + require.Len(t, conns, 1) + conn := conns[0] + + // Get multiple contexts + ctx1 := conn.Context() + ctx2 := conn.Context() + require.NotNil(t, ctx1) + require.NotNil(t, ctx2) + + // Both contexts should be the same instance + require.Equal(t, ctx1, ctx2) + + // Contexts should not be cancelled initially + select { + case <-ctx1.Done(): + t.Fatal("context1 should not be cancelled initially") + case <-ctx2.Done(): + t.Fatal("context2 should not be cancelled initially") + default: + } + + // Close the connection + err = conn.Close() + require.NoError(t, err) + + // Both contexts should be cancelled now + select { + case <-ctx1.Done(): + // Expected + case <-time.After(time.Second): + t.Fatal("context1 should be cancelled after connection close") + } + + select { + case <-ctx2.Done(): + // Expected + case <-time.After(time.Second): + t.Fatal("context2 should be cancelled after connection close") + } + + // Verify both context errors + require.Error(t, ctx1.Err()) + require.Equal(t, context.Canceled, ctx1.Err()) + require.Error(t, ctx2.Err()) + require.Equal(t, context.Canceled, ctx2.Err()) + }) + + // Test that the context is cancelled when the remote peer closes the connection + t.Run("ContextCancelledOnRemoteClose", func(t *testing.T) { + s1 := GenSwarm(t, OptDisableReuseport) + s2 := GenSwarm(t, OptDisableReuseport) + defer s1.Close() + defer s2.Close() + + // Connect the swarms + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), 
peerstore.TempAddrTTL) + _, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + + // Get the connection + conns := s1.ConnsToPeer(s2.LocalPeer()) + require.Len(t, conns, 1) + conn := conns[0] + + // Get the context + ctx := conn.Context() + require.NotNil(t, ctx) + + // Context should not be cancelled initially + select { + case <-ctx.Done(): + t.Fatal("context should not be cancelled initially") + default: + } + + // Close the remote swarm (simulating remote peer closing) + s2.Close() + + // Wait for the connection to be closed + require.Eventually(t, func() bool { + return conn.IsClosed() + }, time.Second, 10*time.Millisecond, "connection should be closed") + + // Context should be cancelled now + select { + case <-ctx.Done(): + // Expected + case <-time.After(time.Second): + t.Fatal("context should be cancelled after remote close") + } + + // Verify the context error + require.Error(t, ctx.Err()) + require.Equal(t, context.Canceled, ctx.Err()) + }) +}