Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,6 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
// If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, the client's read
// concern, write concern, or read preference will be used, respectively.
func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (*Session, error) {
if c.sessionPool == nil {
return nil, ErrClientDisconnected
}

sessArgs, err := mongoutil.NewOptions(opts...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -454,10 +450,6 @@ func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (*
}

func (c *Client) endSessions(ctx context.Context) {
if c.sessionPool == nil {
return
}

sessionIDs := c.sessionPool.IDSlice()
op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment).
ServerSelector(&serverselector.ReadPref{ReadPref: readpref.PrimaryPreferred()}).
Expand Down Expand Up @@ -872,10 +864,6 @@ func (c *Client) UseSessionWithOptions(
// documentation).
func (c *Client) Watch(ctx context.Context, pipeline interface{},
opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) {
if c.sessionPool == nil {
return nil, ErrClientDisconnected
}

csConfig := changeStreamConfig{
readConcern: c.readConcern,
readPreference: c.readPreference,
Expand Down
15 changes: 8 additions & 7 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ import (
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/integtest"
"go.mongodb.org/mongo-driver/v2/internal/mongoutil"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/tag"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology"
)

Expand All @@ -37,7 +37,7 @@ func setupClient(opts ...options.Lister[options.ClientOptions]) *Client {
integtest.AddTestServerAPIVersion(clientOpts)
opts = append(opts, clientOpts)
}
client, _ := newClient(opts...)
client, _ := Connect(opts...)
return client
}

Expand All @@ -53,11 +53,14 @@ func TestClient(t *testing.T) {
assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name())
assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client())
})
t.Run("replace topology error", func(t *testing.T) {
t.Run("replaceErrors for disconnected topology", func(t *testing.T) {
client := setupClient()

_, err := client.StartSession()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
topo, ok := client.deployment.(*topology.Topology)
require.True(t, ok, "client deployment is not a topology")

err := topo.Disconnect(context.Background())
require.NoError(t, err)

_, err = client.ListDatabases(bgCtx, bson.D{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
Expand All @@ -72,9 +75,7 @@ func TestClient(t *testing.T) {
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("nil document error", func(t *testing.T) {
// manually set session pool to non-nil because Watch will return ErrClientDisconnected
client := setupClient()
client.sessionPool = &session.Pool{}

_, err := client.Watch(bgCtx, nil)
watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")
Expand Down
13 changes: 11 additions & 2 deletions mongo/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@ package mongo

import (
"bytes"
"context"
"errors"
"testing"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/ptrutil"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology"
)

const (
Expand Down Expand Up @@ -78,12 +81,18 @@ func TestCollection(t *testing.T) {
}
compareColls(t, expected, coll)
})
t.Run("replace topology error", func(t *testing.T) {
t.Run("replaceErrors for disconnected topology", func(t *testing.T) {
coll := setupColl("foo")
doc := bson.D{}
update := bson.D{{"$update", bson.D{{"x", 1}}}}

_, err := coll.InsertOne(bgCtx, doc)
topo, ok := coll.client.deployment.(*topology.Topology)
require.True(t, ok, "client deployment is not a topology")

err := topo.Disconnect(context.Background())
require.NoError(t, err)

_, err = coll.InsertOne(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.InsertMany(bgCtx, []interface{}{doc})
Expand Down
14 changes: 10 additions & 4 deletions mongo/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
Expand Down Expand Up @@ -83,9 +84,16 @@ func TestDatabase(t *testing.T) {
compareDbs(t, expected, got)
})
})
t.Run("replace topology error", func(t *testing.T) {
t.Run("replaceErrors for disconnected topology", func(t *testing.T) {
db := setupDb("foo")
err := db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err()

topo, ok := db.client.deployment.(*topology.Topology)
require.True(t, ok, "client deployment is not a topology")

err := topo.Disconnect(context.Background())
require.NoError(t, err)

err = db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = db.Drop(bgCtx)
Expand All @@ -96,9 +104,7 @@ func TestDatabase(t *testing.T) {
})
t.Run("TransientTransactionError label", func(t *testing.T) {
client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second))
err := client.connect()
defer func() { _ = client.Disconnect(bgCtx) }()
assert.Nil(t, err, "expected nil, got %v", err)

t.Run("negative case of non-transaction", func(t *testing.T) {
var sse topology.ServerSelectionError
Expand Down
Loading