Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"net/http"
"sync/atomic"
"time"

"go.mongodb.org/mongo-driver/v2/bson"
Expand Down Expand Up @@ -73,6 +74,7 @@ type Client struct {
timeout *time.Duration
httpClient *http.Client
logger *logger.Logger
closed atomic.Value

// client-side encryption fields
keyVaultClientFLE *Client
Expand Down Expand Up @@ -250,6 +252,8 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) {
return nil, fmt.Errorf("invalid logger options: %w", err)
}

client.closed.Store(false)

return client, nil
}

Expand Down Expand Up @@ -311,6 +315,10 @@ func (c *Client) connect() error {
// or write operations. If this method returns with no errors, all connections
// associated with this Client have been closed.
func (c *Client) Disconnect(ctx context.Context) error {
if c.closed.Load().(bool) {
return ErrClientDisconnected
}

if c.logger != nil {
defer c.logger.Close()
}
Expand Down Expand Up @@ -350,6 +358,8 @@ func (c *Client) Disconnect(ctx context.Context) error {
c.cryptFLE.Close()
}

c.closed.Store(true)

if disconnector, ok := c.deployment.(driver.Disconnector); ok {
return replaceErrors(disconnector.Disconnect(ctx))
}
Expand All @@ -369,6 +379,10 @@ func (c *Client) Disconnect(ctx context.Context) error {
// Using Ping reduces application resilience because applications starting up will error if the server is temporarily
// unavailable or is failing over (e.g. during autoscaling due to a load spike).
func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
if c.closed.Load().(bool) {
return ErrClientDisconnected
}

if ctx == nil {
ctx = context.Background()
}
Expand Down Expand Up @@ -396,10 +410,6 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
// If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, the client's read
// concern, write concern, or read preference will be used, respectively.
func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (*Session, error) {
if c.sessionPool == nil {
return nil, ErrClientDisconnected
}

sessArgs, err := mongoutil.NewOptions(opts...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -454,10 +464,6 @@ func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (*
}

func (c *Client) endSessions(ctx context.Context) {
if c.sessionPool == nil {
return
}

sessionIDs := c.sessionPool.IDSlice()
op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment).
ServerSelector(&serverselector.ReadPref{ReadPref: readpref.PrimaryPreferred()}).
Expand Down Expand Up @@ -872,7 +878,7 @@ func (c *Client) UseSessionWithOptions(
// documentation).
func (c *Client) Watch(ctx context.Context, pipeline interface{},
opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) {
if c.sessionPool == nil {
if c.closed.Load().(bool) {
return nil, ErrClientDisconnected
}

Expand Down
21 changes: 7 additions & 14 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/tag"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology"
)

Expand All @@ -37,7 +36,7 @@ func setupClient(opts ...options.Lister[options.ClientOptions]) *Client {
integtest.AddTestServerAPIVersion(clientOpts)
opts = append(opts, clientOpts)
}
client, _ := newClient(opts...)
client, _ := Connect(opts...)
return client
}

Expand All @@ -53,28 +52,22 @@ func TestClient(t *testing.T) {
assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name())
assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client())
})
t.Run("replace topology error", func(t *testing.T) {
t.Run("client disconnect error", func(t *testing.T) {
client := setupClient()
assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool))

_, err := client.StartSession()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = client.ListDatabases(bgCtx, bson.D{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
err := client.Disconnect(bgCtx)
assert.Equal(t, nil, err, "expected nil, got %v", err)
assert.Equal(t, true, client.closed.Load().(bool), "expected error %v, got %v", true, client.closed.Load().(bool))

err = client.Ping(bgCtx, nil)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = client.Disconnect(bgCtx)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = client.Watch(bgCtx, []bson.D{})
_, err = client.Watch(bgCtx, nil, nil)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("nil document error", func(t *testing.T) {
// manually set session pool to non-nil because Watch will return ErrClientDisconnected
client := setupClient()
client.sessionPool = &session.Pool{}

_, err := client.Watch(bgCtx, nil)
watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")
Expand Down
53 changes: 0 additions & 53 deletions mongo/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,59 +78,6 @@ func TestCollection(t *testing.T) {
}
compareColls(t, expected, coll)
})
t.Run("replace topology error", func(t *testing.T) {
coll := setupColl("foo")
doc := bson.D{}
update := bson.D{{"$update", bson.D{{"x", 1}}}}

_, err := coll.InsertOne(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.InsertMany(bgCtx, []interface{}{doc})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.DeleteOne(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.DeleteMany(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.UpdateOne(bgCtx, doc, update)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.UpdateMany(bgCtx, doc, update)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.ReplaceOne(bgCtx, doc, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.Aggregate(bgCtx, Pipeline{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.EstimatedDocumentCount(bgCtx)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.CountDocuments(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.Distinct(bgCtx, "x", doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.Find(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOne(bgCtx, doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOneAndDelete(bgCtx, doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOneAndReplace(bgCtx, doc, doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOneAndUpdate(bgCtx, doc, update).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("database accessor", func(t *testing.T) {
coll := setupColl("bar")
dbName := coll.Database().Name()
Expand Down
14 changes: 1 addition & 13 deletions mongo/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,10 @@ func TestDatabase(t *testing.T) {
compareDbs(t, expected, got)
})
})
t.Run("replace topology error", func(t *testing.T) {
db := setupDb("foo")
err := db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = db.Drop(bgCtx)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = db.ListCollections(bgCtx, bson.D{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("TransientTransactionError label", func(t *testing.T) {
client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second))
err := client.connect()
assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool))
defer func() { _ = client.Disconnect(bgCtx) }()
assert.Nil(t, err, "expected nil, got %v", err)

t.Run("negative case of non-transaction", func(t *testing.T) {
var sse topology.ServerSelectionError
Expand Down
Loading