Skip to content

Commit 67f3381

Browse files
GODRIVER-1605 Add SetBatchSize to mongo.Cursor (#1201)
Co-authored-by: Matt Dale <[email protected]>
1 parent cb5ffcf commit 67f3381

File tree

6 files changed

+108
-10
lines changed

6 files changed

+108
-10
lines changed

mongo/batch_cursor.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ type batchCursor interface {
3434

3535
// Close closes the cursor.
3636
Close(context.Context) error
37+
38+
// The SetBatchSize method is a modifier function used to adjust the
39+
// batch size of the cursor that implements it.
40+
SetBatchSize(int32)
3741
}
3842

3943
// changeStreamCursor is the interface implemented by batch cursors that also provide the functionality for retrieving

mongo/cursor.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,14 @@ func (c *Cursor) closeImplicitSession() {
314314
}
315315
}
316316

317+
// SetBatchSize sets the number of documents to fetch from the database with
318+
// each iteration of the cursor's "Next" method. Note that some operations set
319+
// an initial cursor batch size, so this setting only affects subsequent
320+
// document batches fetched from the database.
321+
func (c *Cursor) SetBatchSize(batchSize int32) {
322+
c.bc.SetBatchSize(batchSize)
323+
}
324+
317325
// BatchCursorFromCursor returns a driver.BatchCursor for the given Cursor. If there is no underlying
318326
// driver.BatchCursor, nil is returned.
319327
//

mongo/cursor_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ func (tbc *testBatchCursor) Close(context.Context) error {
8686
return nil
8787
}
8888

89+
func (tbc *testBatchCursor) SetBatchSize(int32) {}
90+
8991
func TestCursor(t *testing.T) {
9092
t.Run("loops until docs available", func(t *testing.T) {})
9193
t.Run("returns false on context cancellation", func(t *testing.T) {})

x/mongo/driver/batch_cursor.go

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,23 +308,37 @@ func (bc *BatchCursor) KillCursor(ctx context.Context) error {
308308
}.Execute(ctx)
309309
}
310310

311+
// calcGetMoreBatchSize calculates the number of documents to return in the
312+
// response of a "getMore" operation based on the given limit, batchSize, and
313+
// number of documents already returned. Returns false if a non-trivial limit is
314+
// lower than or equal to the number of documents already returned.
315+
func calcGetMoreBatchSize(bc BatchCursor) (int32, bool) {
316+
gmBatchSize := bc.batchSize
317+
318+
// Account for legacy operations that don't support setting a limit.
319+
if bc.limit != 0 && bc.numReturned+bc.batchSize >= bc.limit {
320+
gmBatchSize = bc.limit - bc.numReturned
321+
if gmBatchSize <= 0 {
322+
return gmBatchSize, false
323+
}
324+
}
325+
326+
return gmBatchSize, true
327+
}
328+
311329
func (bc *BatchCursor) getMore(ctx context.Context) {
312330
bc.clearBatch()
313331
if bc.id == 0 {
314332
return
315333
}
316334

317-
// Required for legacy operations which don't support limit.
318-
numToReturn := bc.batchSize
319-
if bc.limit != 0 && bc.numReturned+bc.batchSize >= bc.limit {
320-
numToReturn = bc.limit - bc.numReturned
321-
if numToReturn <= 0 {
322-
err := bc.Close(ctx)
323-
if err != nil {
324-
bc.err = err
325-
}
326-
return
335+
numToReturn, ok := calcGetMoreBatchSize(*bc)
336+
if !ok {
337+
if err := bc.Close(ctx); err != nil {
338+
bc.err = err
327339
}
340+
341+
return
328342
}
329343

330344
bc.err = Operation{

x/mongo/driver/batch_cursor_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ import (
1313
)
1414

1515
func TestBatchCursor(t *testing.T) {
16+
t.Parallel()
17+
1618
t.Run("setBatchSize", func(t *testing.T) {
19+
t.Parallel()
20+
1721
var size int32
1822
bc := &BatchCursor{
1923
batchSize: size,
@@ -24,4 +28,65 @@ func TestBatchCursor(t *testing.T) {
2428
bc.SetBatchSize(size)
2529
assert.Equal(t, size, bc.batchSize, "expected batchSize %v, got %v", size, bc.batchSize)
2630
})
31+
32+
t.Run("calcGetMoreBatchSize", func(t *testing.T) {
33+
t.Parallel()
34+
35+
for _, tcase := range []struct {
36+
name string
37+
size, limit, numReturned, expected int32
38+
ok bool
39+
}{
40+
{
41+
name: "empty",
42+
expected: 0,
43+
ok: true,
44+
},
45+
{
46+
name: "batchSize NEQ 0",
47+
size: 4,
48+
expected: 4,
49+
ok: true,
50+
},
51+
{
52+
name: "limit NEQ 0",
53+
limit: 4,
54+
expected: 0,
55+
ok: true,
56+
},
57+
{
58+
name: "limit NEQ and batchSize + numReturned EQ limit",
59+
size: 4,
60+
limit: 8,
61+
numReturned: 4,
62+
expected: 4,
63+
ok: true,
64+
},
65+
{
66+
name: "limit makes batchSize negative",
67+
numReturned: 4,
68+
limit: 2,
69+
expected: -2,
70+
ok: false,
71+
},
72+
} {
73+
tcase := tcase
74+
t.Run(tcase.name, func(t *testing.T) {
75+
t.Parallel()
76+
77+
bc := &BatchCursor{
78+
limit: tcase.limit,
79+
batchSize: tcase.size,
80+
numReturned: tcase.numReturned,
81+
}
82+
83+
bc.SetBatchSize(tcase.size)
84+
85+
size, ok := calcGetMoreBatchSize(*bc)
86+
87+
assert.Equal(t, tcase.expected, size, "expected batchSize %v, got %v", tcase.expected, size)
88+
assert.Equal(t, tcase.ok, ok, "expected ok %v, got %v", tcase.ok, ok)
89+
})
90+
}
91+
})
2792
}

x/mongo/driver/list_collections_batch_cursor.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,8 @@ func (*ListCollectionsBatchCursor) projectNameElement(rawDoc bsoncore.Document)
127127
filteredDoc = bsoncore.BuildDocument(filteredDoc, filteredElems)
128128
return filteredDoc, nil
129129
}
130+
131+
// SetBatchSize sets the batchSize for future getMores.
132+
func (lcbc *ListCollectionsBatchCursor) SetBatchSize(size int32) {
133+
lcbc.bc.SetBatchSize(size)
134+
}

0 commit comments

Comments
 (0)