Skip to content

Commit 95034ee

Browse files
author
Isabella Siu
committed
Change driver.KillCursors to not do server selection and fix tests.
GODRIVER-817 GODRIVER-819 Co-authored-by: Divjot Arora <[email protected]> Change-Id: I78b457f32df9ea53ec650e1de3fb6aa75fc78174
1 parent cabca31 commit 95034ee

File tree

8 files changed

+60
-82
lines changed

8 files changed

+60
-82
lines changed

mongo/batch_cursor.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55

66
"github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
7+
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
78
)
89

910
// batchCursor is the interface implemented by types that can provide batches of document results.
@@ -19,6 +20,9 @@ type batchCursor interface {
1920
// DocumentSequence is only valid until the next call to Next or Close.
2021
Batch() *bsoncore.DocumentSequence
2122

23+
// Server returns a pointer to the cursor's server.
24+
Server() *topology.Server
25+
2226
// Err returns the last error encountered.
2327
Err() error
2428

mongo/change_stream.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -453,12 +453,8 @@ func (cs *ChangeStream) Next(ctx context.Context) bool {
453453
}
454454
}
455455

456-
killCursors := command.KillCursors{
457-
NS: cs.ns,
458-
IDs: []int64{cs.ID()},
459-
}
456+
_, _ = driver.KillCursors(ctx, cs.ns, cs.cursor.bc.Server(), cs.ID())
460457

461-
_, _ = driver.KillCursors(ctx, killCursors, cs.client.topology, cs.db.writeSelector)
462458
cs.err = cs.runCommand(ctx, true)
463459
if cs.err != nil {
464460
return false

mongo/change_stream_test.go

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ func (er *errorCursor) Close(ctx context.Context) error {
6363
return nil
6464
}
6565

66+
func killChangeStreamCursor(t *testing.T, cs *ChangeStream) {
67+
_, err := driver.KillCursors(context.Background(), cs.ns, cs.cursor.bc.Server(), cs.ID())
68+
if err != nil {
69+
t.Fatalf("error killing cursor: %v", err)
70+
}
71+
}
72+
6673
func skipIfBelow36(t *testing.T) {
6774
serverVersion, err := getServerVersion(createTestDatabase(t, nil))
6875
require.NoError(t, err)
@@ -355,15 +362,8 @@ func TestChangeStream_ReplicaSet(t *testing.T) {
355362
ensureResumeToken(t, coll, stream)
356363
cs := stream
357364

358-
kc := command.KillCursors{
359-
NS: cs.ns,
360-
IDs: []int64{cs.ID()},
361-
}
362-
363-
_, err := driver.KillCursors(ctx, kc, cs.client.topology, cs.db.writeSelector)
364-
testhelpers.RequireNil(t, err, "error running killCursors cmd: %s", err)
365-
366-
_, err = coll.InsertOne(ctx, doc1)
365+
killChangeStreamCursor(t, cs)
366+
_, err := coll.InsertOne(ctx, doc1)
367367
testhelpers.RequireNil(t, err, "error inserting doc: %s", err)
368368

369369
drainChannels()
@@ -478,13 +478,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) {
478478
cs := stream
479479

480480
// kill cursor to force a resumable error
481-
kc := command.KillCursors{
482-
NS: cs.ns,
483-
IDs: []int64{cs.ID()},
484-
}
485-
486-
_, err = driver.KillCursors(ctx, kc, cs.client.topology, cs.db.writeSelector)
487-
testhelpers.RequireNil(t, err, "error running killCursors cmd: %s", err)
481+
killChangeStreamCursor(t, cs)
488482

489483
adminDb := coll.client.Database("admin")
490484
modeDoc := bsonx.Doc{
@@ -527,14 +521,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) {
527521
cs := stream
528522

529523
// kill cursor to force a resumable error
530-
kc := command.KillCursors{
531-
NS: cs.ns,
532-
IDs: []int64{cs.ID()},
533-
}
534-
535-
_, err = driver.KillCursors(ctx, kc, cs.client.topology, cs.db.writeSelector)
536-
testhelpers.RequireNil(t, err, "error running killCursors cmd: %s", err)
537-
524+
killChangeStreamCursor(t, cs)
538525
drainChannels()
539526
stream.Next(ctx)
540527

@@ -685,14 +672,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) {
685672

686673
// kill the stream's underlying cursor to force a resumeable error
687674
cs := stream
688-
kc := command.KillCursors{
689-
NS: cs.ns,
690-
IDs: []int64{cs.ID()},
691-
}
692-
693-
_, err := driver.KillCursors(ctx, kc, cs.client.topology, cs.db.writeSelector)
694-
testhelpers.RequireNil(t, err, "error running killCursors cmd: %s", err)
695-
675+
killChangeStreamCursor(t, cs)
696676
ensureResumeToken(t, coll, stream)
697677
})
698678
}

mongo/collection_internal_test.go

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import (
2727
"github.com/mongodb/mongo-go-driver/mongo/writeconcern"
2828
"github.com/mongodb/mongo-go-driver/x/mongo/driver"
2929
"github.com/mongodb/mongo-go-driver/x/network/command"
30-
"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
3130
"github.com/stretchr/testify/assert"
3231
"github.com/stretchr/testify/require"
3332
)
@@ -1544,39 +1543,6 @@ func TestCollection_Find_notFound(t *testing.T) {
15441543
require.False(t, cursor.Next(context.Background()))
15451544
}
15461545

1547-
func killCursor(t *testing.T, c *Cursor, coll *Collection) {
1548-
version, err := getServerVersion(coll.db)
1549-
require.Nil(t, err, "error getting server version: %s", err)
1550-
ns := command.NewNamespace(coll.db.name, coll.name)
1551-
1552-
if compareVersions(t, version, "3.0") > 0 {
1553-
// not legacy
1554-
kc := command.KillCursors{
1555-
NS: ns,
1556-
IDs: []int64{c.ID()},
1557-
}
1558-
1559-
_, err := driver.KillCursors(ctx, kc, coll.client.topology, coll.db.writeSelector)
1560-
require.Nil(t, err, "error killing cursor: %s", err)
1561-
return
1562-
}
1563-
1564-
// legacy
1565-
kc := wiremessage.KillCursors{
1566-
NumberOfCursorIDs: 1,
1567-
CursorIDs: []int64{c.ID()},
1568-
CollectionName: ns.Collection,
1569-
DatabaseName: ns.DB,
1570-
}
1571-
topo := testutil.Topology(t)
1572-
ss, err := topo.SelectServer(ctx, coll.db.writeSelector)
1573-
require.Nil(t, err, "error selecting server: %s", err)
1574-
conn, err := ss.Connection(ctx)
1575-
require.Nil(t, err, "error getting connection: %s", err)
1576-
err = conn.WriteWireMessage(context.Background(), kc)
1577-
require.Nil(t, err, "error writing wire msg: %s", err)
1578-
}
1579-
15801546
func TestCollection_Find_Error(t *testing.T) {
15811547
t.Run("TestInvalidIdentifier", func(t *testing.T) {
15821548
coll := createTestCollection(t, nil, nil)
@@ -1595,7 +1561,11 @@ func TestCollection_Find_Error(t *testing.T) {
15951561
require.True(t, c.Next(context.Background()))
15961562
require.True(t, c.Next(context.Background()))
15971563

1598-
killCursor(t, c, coll)
1564+
_, err = driver.KillCursors(ctx, command.Namespace{
1565+
DB: coll.db.name,
1566+
Collection: coll.name,
1567+
}, c.bc.Server(), c.ID())
1568+
require.NoError(t, err)
15991569
require.False(t, c.Next(context.Background()))
16001570
require.NotNil(t, c.Err())
16011571
_ = c.Close(context.Background())

mongo/sessions_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ func createFuncMap(t *testing.T, dbName string, collName string, monitored bool)
9191
err := db.Drop(ctx)
9292
testhelpers.RequireNil(t, err, "error dropping database after creation: %s", err)
9393

94+
// ensure database exists
95+
_, _ = db.Collection("foo").InsertOne(context.Background(), doc)
96+
9497
coll := db.Collection(collName)
9598
iv := coll.Indexes()
9699

x/mongo/driver/batch_cursor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ func (bc *BatchCursor) Next(ctx context.Context) bool {
174174
// DocumentSequence is only valid until the next call to Next or Close.
175175
func (bc *BatchCursor) Batch() *bsoncore.DocumentSequence { return bc.currentBatch }
176176

177+
// Server returns a pointer to the cursor's server.
178+
func (bc *BatchCursor) Server() *topology.Server { return bc.server }
179+
177180
// Err returns the latest error encountered.
178181
func (bc *BatchCursor) Err() error { return bc.err }
179182

x/mongo/driver/kill_cursors.go

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,48 @@ package driver
88

99
import (
1010
"context"
11+
"github.com/mongodb/mongo-go-driver/x/network/connection"
12+
"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
1113

1214
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
1315
"github.com/mongodb/mongo-go-driver/x/network/command"
14-
"github.com/mongodb/mongo-go-driver/x/network/description"
1516
"github.com/mongodb/mongo-go-driver/x/network/result"
1617
)
1718

1819
// KillCursors handles the full cycle dispatch and execution of an aggregate command against the provided
1920
// topology.
2021
func KillCursors(
2122
ctx context.Context,
22-
cmd command.KillCursors,
23-
topo *topology.Topology,
24-
selector description.ServerSelector,
23+
ns command.Namespace,
24+
server *topology.Server,
25+
cursorID int64,
2526
) (result.KillCursors, error) {
26-
ss, err := topo.SelectServer(ctx, selector)
27-
if err != nil {
28-
return result.KillCursors{}, err
29-
}
30-
desc := ss.Description()
31-
conn, err := ss.Connection(ctx)
27+
desc := server.SelectedDescription()
28+
conn, err := server.Connection(ctx)
3229
if err != nil {
3330
return result.KillCursors{}, err
3431
}
3532
defer conn.Close()
33+
34+
if desc.WireVersion.Max < 4 {
35+
return result.KillCursors{}, legacyKillCursors(ctx, ns, cursorID, conn)
36+
}
37+
38+
cmd := command.KillCursors{
39+
NS: ns,
40+
IDs: []int64{cursorID},
41+
}
42+
3643
return cmd.RoundTrip(ctx, desc, conn)
3744
}
45+
46+
func legacyKillCursors(ctx context.Context, ns command.Namespace, cursorID int64, conn connection.Connection) error {
47+
kc := wiremessage.KillCursors{
48+
NumberOfCursorIDs: 1,
49+
CursorIDs: []int64{cursorID},
50+
CollectionName: ns.Collection,
51+
DatabaseName: ns.DB,
52+
}
53+
54+
return conn.WriteWireMessage(ctx, kc)
55+
}

x/mongo/driver/list_collections_batch_cursor.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88

99
"github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
10+
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
1011
)
1112

1213
// ListCollectionsBatchCursor is a special batch cursor returned from ListCollections that properly
@@ -83,6 +84,9 @@ func (lcbc *ListCollectionsBatchCursor) Next(ctx context.Context) bool {
8384
// DocumentSequence is only valid until the next call to Next or Close.
8485
func (lcbc *ListCollectionsBatchCursor) Batch() *bsoncore.DocumentSequence { return lcbc.currentBatch }
8586

87+
// Server returns a pointer to the cursor's server.
88+
func (lcbc *ListCollectionsBatchCursor) Server() *topology.Server { return lcbc.bc.server }
89+
8690
// Err returns the latest error encountered.
8791
func (lcbc *ListCollectionsBatchCursor) Err() error {
8892
if lcbc.err != nil {

0 commit comments

Comments
 (0)