Skip to content

Commit 60cfd91

Browse files
authored
GODRIVER-2658 Better guard against nil pinned connections. (#1153)
1 parent c05cb8a commit 60cfd91

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,10 @@ func (c initConnection) SupportsStreaming() bool {
578578
}
579579

580580
// Connection implements the driver.Connection interface to allow reading and writing wire
581-
// messages and the driver.Expirable interface to allow expiring.
581+
// messages and the driver.Expirable interface to allow expiring. It wraps an underlying
582+
// topology.connection to make it more goroutine-safe and nil-safe.
582583
type Connection struct {
583-
*connection
584+
connection *connection
584585
refCount int
585586
cleanupPoolFn func()
586587

@@ -602,7 +603,7 @@ func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
602603
if c.connection == nil {
603604
return ErrConnectionClosed
604605
}
605-
return c.writeWireMessage(ctx, wm)
606+
return c.connection.writeWireMessage(ctx, wm)
606607
}
607608

608609
// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
@@ -613,7 +614,7 @@ func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, e
613614
if c.connection == nil {
614615
return dst, ErrConnectionClosed
615616
}
616-
return c.readWireMessage(ctx, dst)
617+
return c.connection.readWireMessage(ctx, dst)
617618
}
618619

619620
// CompressWireMessage handles compressing the provided wire message using the underlying
@@ -656,7 +657,7 @@ func (c *Connection) Description() description.Server {
656657
if c.connection == nil {
657658
return description.Server{}
658659
}
659-
return c.desc
660+
return c.connection.desc
660661
}
661662

662663
// Close returns this connection to the connection pool. This method may not closeConnection the underlying
@@ -679,12 +680,12 @@ func (c *Connection) Expire() error {
679680
return nil
680681
}
681682

682-
_ = c.close()
683+
_ = c.connection.close()
683684
return c.cleanupReferences()
684685
}
685686

686687
func (c *Connection) cleanupReferences() error {
687-
err := c.pool.checkIn(c.connection)
688+
err := c.connection.pool.checkIn(c.connection)
688689
if c.cleanupPoolFn != nil {
689690
c.cleanupPoolFn()
690691
c.cleanupPoolFn = nil
@@ -709,14 +710,22 @@ func (c *Connection) ID() string {
709710
if c.connection == nil {
710711
return "<closed>"
711712
}
712-
return c.id
713+
return c.connection.id
714+
}
715+
716+
// ServerConnectionID returns the server connection ID of this connection.
717+
func (c *Connection) ServerConnectionID() *int32 {
718+
if c.connection == nil {
719+
return nil
720+
}
721+
return c.connection.serverConnectionID
713722
}
714723

715724
// Stale returns if the connection is stale.
716725
func (c *Connection) Stale() bool {
717726
c.mu.RLock()
718727
defer c.mu.RUnlock()
719-
return c.pool.stale(c.connection)
728+
return c.connection.pool.stale(c.connection)
720729
}
721730

722731
// Address returns the address of this connection.
@@ -726,27 +735,27 @@ func (c *Connection) Address() address.Address {
726735
if c.connection == nil {
727736
return address.Address("0.0.0.0")
728737
}
729-
return c.addr
738+
return c.connection.addr
730739
}
731740

732741
// LocalAddress returns the local address of the connection
733742
func (c *Connection) LocalAddress() address.Address {
734743
c.mu.RLock()
735744
defer c.mu.RUnlock()
736-
if c.connection == nil || c.nc == nil {
745+
if c.connection == nil || c.connection.nc == nil {
737746
return address.Address("0.0.0.0")
738747
}
739-
return address.Address(c.nc.LocalAddr().String())
748+
return address.Address(c.connection.nc.LocalAddr().String())
740749
}
741750

742751
// PinToCursor updates this connection to reflect that it is pinned to a cursor.
743752
func (c *Connection) PinToCursor() error {
744-
return c.pin("cursor", c.pool.pinConnectionToCursor, c.pool.unpinConnectionFromCursor)
753+
return c.pin("cursor", c.connection.pool.pinConnectionToCursor, c.connection.pool.unpinConnectionFromCursor)
745754
}
746755

747756
// PinToTransaction updates this connection to reflect that it is pinned to a transaction.
748757
func (c *Connection) PinToTransaction() error {
749-
return c.pin("transaction", c.pool.pinConnectionToTransaction, c.pool.unpinConnectionFromTransaction)
758+
return c.pin("transaction", c.connection.pool.pinConnectionToTransaction, c.connection.pool.unpinConnectionFromTransaction)
750759
}
751760

752761
func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {

x/mongo/driver/topology/connection_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,12 @@ func TestConnection(t *testing.T) {
763763
if !cmp.Equal(got, want) {
764764
t.Errorf("LocalAddresses do not match. got %v; want %v", got, want)
765765
}
766+
767+
want = (*int32)(nil)
768+
got = conn.ServerConnectionID()
769+
if !cmp.Equal(got, want) {
770+
t.Errorf("ServerConnectionIDs do not match. got %v; want %v", got, want)
771+
}
766772
})
767773

768774
t.Run("pinning", func(t *testing.T) {

0 commit comments

Comments
 (0)