From 2fa662e2ae8f132494c17eea23f9305c78ed253e Mon Sep 17 00:00:00 2001 From: Joy Wang Date: Fri, 13 Sep 2024 17:01:21 -0400 Subject: [PATCH 1/4] removed sessionPool nil checks, add closed atomic value in client struct, alter tests --- mongo/client.go | 24 +++++++++++++++--------- mongo/client_test.go | 21 +++++++-------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index c9859eff23..12aab3a786 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "net/http" + "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/bson" @@ -73,6 +74,7 @@ type Client struct { timeout *time.Duration httpClient *http.Client logger *logger.Logger + closed atomic.Value // client-side encryption fields keyVaultClientFLE *Client @@ -250,6 +252,8 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) { return nil, fmt.Errorf("invalid logger options: %w", err) } + client.closed.Store(false) + return client, nil } @@ -311,6 +315,10 @@ func (c *Client) connect() error { // or write operations. If this method returns with no errors, all connections // associated with this Client have been closed. func (c *Client) Disconnect(ctx context.Context) error { + if c.closed.Load().(bool) { + return ErrClientDisconnected + } + if c.logger != nil { defer c.logger.Close() } @@ -350,6 +358,8 @@ func (c *Client) Disconnect(ctx context.Context) error { c.cryptFLE.Close() } + c.closed.Store(true) + if disconnector, ok := c.deployment.(driver.Disconnector); ok { return replaceErrors(disconnector.Disconnect(ctx)) } @@ -369,6 +379,10 @@ func (c *Client) Disconnect(ctx context.Context) error { // Using Ping reduces application resilience because applications starting up will error if the server is temporarily // unavailable or is failing over (e.g. during autoscaling due to a load spike). func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { + if c.closed.Load().(bool) { + return ErrClientDisconnected + } + if ctx == nil { ctx = context.Background() } @@ -396,10 +410,6 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { // If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, the client's read // concern, write concern, or read preference will be used, respectively. func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (*Session, error) { - if c.sessionPool == nil { - return nil, ErrClientDisconnected - } - sessArgs, err := mongoutil.NewOptions(opts...) if err != nil { return nil, err @@ -454,10 +464,6 @@ func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (* } func (c *Client) endSessions(ctx context.Context) { - if c.sessionPool == nil { - return - } - sessionIDs := c.sessionPool.IDSlice() op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment). ServerSelector(&serverselector.ReadPref{ReadPref: readpref.PrimaryPreferred()}). @@ -872,7 +878,7 @@ func (c *Client) UseSessionWithOptions( // documentation). func (c *Client) Watch(ctx context.Context, pipeline interface{}, opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) { - if c.sessionPool == nil { + if c.closed.Load().(bool) { return nil, ErrClientDisconnected } diff --git a/mongo/client_test.go b/mongo/client_test.go index 72e3ee0962..e5db2f76dc 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -25,7 +25,6 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/tag" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt" - "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" ) @@ -37,7 +36,7 @@ func setupClient(opts ...options.Lister[options.ClientOptions]) *Client { integtest.AddTestServerAPIVersion(clientOpts) opts = append(opts, clientOpts) } - client, _ := newClient(opts...) + client, _ := Connect(opts...) return client } @@ -53,28 +52,22 @@ func TestClient(t *testing.T) { assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name()) assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client()) }) - t.Run("replace topology error", func(t *testing.T) { + t.Run("client disconnect error", func(t *testing.T) { client := setupClient() + assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool)) - _, err := client.StartSession() - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = client.ListDatabases(bgCtx, bson.D{}) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + err := client.Disconnect(bgCtx) + assert.Equal(t, nil, err, "expected nil, got %v", err) + assert.Equal(t, true, client.closed.Load().(bool), "expected error %v, got %v", true, client.closed.Load().(bool)) err = client.Ping(bgCtx, nil) assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - err = client.Disconnect(bgCtx) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = client.Watch(bgCtx, []bson.D{}) + _, err = client.Watch(bgCtx, nil, nil) assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) }) t.Run("nil document error", func(t *testing.T) { - // manually set session pool to non-nil because Watch will return ErrClientDisconnected client := setupClient() - client.sessionPool = &session.Pool{} _, err := client.Watch(bgCtx, nil) watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") From 0460e7ce29e187a15df9836363ceb59f25fb79c8 Mon Sep 17 00:00:00 2001 From: Joy Wang Date: Fri, 13 Sep 2024 17:21:48 -0400 Subject: [PATCH 2/4] remove tests that expect ErrClientDisconnected in collection and database --- mongo/collection_test.go | 53 ---------------------------------------- mongo/database_test.go | 11 --------- 2 files changed, 64 deletions(-) diff --git a/mongo/collection_test.go b/mongo/collection_test.go index 648f04a46f..53d856c860 100644 --- a/mongo/collection_test.go +++ b/mongo/collection_test.go @@ -78,59 +78,6 @@ func TestCollection(t *testing.T) { } compareColls(t, expected, coll) }) - t.Run("replace topology error", func(t *testing.T) { - coll := setupColl("foo") - doc := bson.D{} - update := bson.D{{"$update", bson.D{{"x", 1}}}} - - _, err := coll.InsertOne(bgCtx, doc) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.InsertMany(bgCtx, []interface{}{doc}) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.DeleteOne(bgCtx, doc) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.DeleteMany(bgCtx, doc) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.UpdateOne(bgCtx, doc, update) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.UpdateMany(bgCtx, doc, update) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.ReplaceOne(bgCtx, doc, doc) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.Aggregate(bgCtx, Pipeline{}) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.EstimatedDocumentCount(bgCtx) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.CountDocuments(bgCtx, doc) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - err = coll.Distinct(bgCtx, "x", doc).Err() - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = coll.Find(bgCtx, doc) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - err = coll.FindOne(bgCtx, doc).Err() - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - err = coll.FindOneAndDelete(bgCtx, doc).Err() - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - err = coll.FindOneAndReplace(bgCtx, doc, doc).Err() - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - err = coll.FindOneAndUpdate(bgCtx, doc, update).Err() - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - }) t.Run("database accessor", func(t *testing.T) { coll := setupColl("bar") dbName := coll.Database().Name() diff --git a/mongo/database_test.go b/mongo/database_test.go index 6b9b8df319..f209cec542 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -83,17 +83,6 @@ func TestDatabase(t *testing.T) { compareDbs(t, expected, got) }) }) - t.Run("replace topology error", func(t *testing.T) { - db := setupDb("foo") - err := db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err() - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - err = db.Drop(bgCtx) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - - _, err = db.ListCollections(bgCtx, bson.D{}) - assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - }) t.Run("TransientTransactionError label", func(t *testing.T) { client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second)) err := client.connect() From 8210c5bd28285697655310777089c1445a174670 Mon Sep 17 00:00:00 2001 From: Joy Wang Date: Fri, 13 Sep 2024 17:34:25 -0400 Subject: [PATCH 3/4] remove missed test for database --- mongo/database_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mongo/database_test.go b/mongo/database_test.go index f209cec542..5c7186d531 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -85,9 +85,8 @@ func TestDatabase(t *testing.T) { }) t.Run("TransientTransactionError label", func(t *testing.T) { client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second)) - err := client.connect() + assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool)) defer func() { _ = client.Disconnect(bgCtx) }() - assert.Nil(t, err, "expected nil, got %v", err) t.Run("negative case of non-transaction", func(t *testing.T) { var sse topology.ServerSelectionError From 3a7ec7630bf6ce90c466d992e1f3e55a0a6d6975 Mon Sep 17 00:00:00 2001 From: Joy Wang Date: Mon, 16 Sep 2024 17:04:29 -0400 Subject: [PATCH 4/4] removing atomic value and altering tests for replaceErrors for disconnected topology --- mongo/client.go | 18 ------------ mongo/client_test.go | 20 +++++++++---- mongo/collection_test.go | 62 ++++++++++++++++++++++++++++++++++++++++ mongo/database_test.go | 20 ++++++++++++- 4 files changed, 95 insertions(+), 25 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index 12aab3a786..d7dc15fcc7 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "net/http" - "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/bson" @@ -74,7 +73,6 @@ type Client struct { timeout *time.Duration httpClient *http.Client logger *logger.Logger - closed atomic.Value // client-side encryption fields keyVaultClientFLE *Client @@ -252,8 +250,6 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) { return nil, fmt.Errorf("invalid logger options: %w", err) } - client.closed.Store(false) - return client, nil } @@ -315,10 +311,6 @@ func (c *Client) connect() error { // or write operations. If this method returns with no errors, all connections // associated with this Client have been closed. func (c *Client) Disconnect(ctx context.Context) error { - if c.closed.Load().(bool) { - return ErrClientDisconnected - } - if c.logger != nil { defer c.logger.Close() } @@ -358,8 +350,6 @@ func (c *Client) Disconnect(ctx context.Context) error { c.cryptFLE.Close() } - c.closed.Store(true) - if disconnector, ok := c.deployment.(driver.Disconnector); ok { return replaceErrors(disconnector.Disconnect(ctx)) } @@ -379,10 +369,6 @@ func (c *Client) Disconnect(ctx context.Context) error { // Using Ping reduces application resilience because applications starting up will error if the server is temporarily // unavailable or is failing over (e.g. during autoscaling due to a load spike). func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error { - if c.closed.Load().(bool) { - return ErrClientDisconnected - } - if ctx == nil { ctx = context.Background() } @@ -878,10 +864,6 @@ func (c *Client) UseSessionWithOptions( // documentation). func (c *Client) Watch(ctx context.Context, pipeline interface{}, opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) { - if c.closed.Load().(bool) { - return nil, ErrClientDisconnected - } - csConfig := changeStreamConfig{ readConcern: c.readConcern, readPreference: c.readPreference, diff --git a/mongo/client_test.go b/mongo/client_test.go index e5db2f76dc..70e2075124 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -19,6 +19,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" @@ -52,18 +53,25 @@ func TestClient(t *testing.T) { assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name()) assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client()) }) - t.Run("client disconnect error", func(t *testing.T) { + t.Run("replaceErrors for disconnected topology", func(t *testing.T) { client := setupClient() - assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool)) - err := client.Disconnect(bgCtx) - assert.Equal(t, nil, err, "expected nil, got %v", err) - assert.Equal(t, true, client.closed.Load().(bool), "expected error %v, got %v", true, client.closed.Load().(bool)) + topo, ok := client.deployment.(*topology.Topology) + require.True(t, ok, "client deployment is not a topology") + + err := topo.Disconnect(context.Background()) + require.NoError(t, err) + + _, err = client.ListDatabases(bgCtx, bson.D{}) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) err = client.Ping(bgCtx, nil) assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - _, err = client.Watch(bgCtx, nil, nil) + err = client.Disconnect(bgCtx) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = client.Watch(bgCtx, []bson.D{}) assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) }) t.Run("nil document error", func(t *testing.T) { diff --git a/mongo/collection_test.go b/mongo/collection_test.go index 53d856c860..268f1c8c0e 100644 --- a/mongo/collection_test.go +++ b/mongo/collection_test.go @@ -8,17 +8,20 @@ package mongo import ( "bytes" + "context" "errors" "testing" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/ptrutil" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" ) const ( @@ -78,6 +81,65 @@ func TestCollection(t *testing.T) { } compareColls(t, expected, coll) }) + t.Run("replaceErrors for disconnected topology", func(t *testing.T) { + coll := setupColl("foo") + doc := bson.D{} + update := bson.D{{"$update", bson.D{{"x", 1}}}} + + topo, ok := coll.client.deployment.(*topology.Topology) + require.True(t, ok, "client deployment is not a topology") + + err := topo.Disconnect(context.Background()) + require.NoError(t, err) + + _, err = coll.InsertOne(bgCtx, doc) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.InsertMany(bgCtx, []interface{}{doc}) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.DeleteOne(bgCtx, doc) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.DeleteMany(bgCtx, doc) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.UpdateOne(bgCtx, doc, update) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.UpdateMany(bgCtx, doc, update) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.ReplaceOne(bgCtx, doc, doc) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.Aggregate(bgCtx, Pipeline{}) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.EstimatedDocumentCount(bgCtx) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.CountDocuments(bgCtx, doc) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + err = coll.Distinct(bgCtx, "x", doc).Err() + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = coll.Find(bgCtx, doc) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + err = coll.FindOne(bgCtx, doc).Err() + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + err = coll.FindOneAndDelete(bgCtx, doc).Err() + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + err = coll.FindOneAndReplace(bgCtx, doc, doc).Err() + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + err = coll.FindOneAndUpdate(bgCtx, doc, update).Err() + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + }) t.Run("database accessor", func(t *testing.T) { coll := setupColl("bar") dbName := coll.Database().Name() diff --git a/mongo/database_test.go b/mongo/database_test.go index 5c7186d531..0c324b9b09 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" @@ -83,9 +84,26 @@ func TestDatabase(t *testing.T) { compareDbs(t, expected, got) }) }) + t.Run("replaceErrors for disconnected topology", func(t *testing.T) { + db := setupDb("foo") + + topo, ok := db.client.deployment.(*topology.Topology) + require.True(t, ok, "client deployment is not a topology") + + err := topo.Disconnect(context.Background()) + require.NoError(t, err) + + err = db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err() + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + err = db.Drop(bgCtx) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + + _, err = db.ListCollections(bgCtx, bson.D{}) + assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) + }) t.Run("TransientTransactionError label", func(t *testing.T) { client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second)) - assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool)) defer func() { _ = client.Disconnect(bgCtx) }() t.Run("negative case of non-transaction", func(t *testing.T) {