Skip to content

Commit 506e97a

Browse files
author
Divjot Arora
committed
Add All method to mongo.Cursor
GODRIVER-916 Change-Id: I6eca805d21a1950b6cee808248ec6edbd9dcf8bc
1 parent 57b22f2 commit 506e97a

File tree

4 files changed

+199
-6
lines changed

4 files changed

+199
-6
lines changed

mongo/cursor.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"errors"
1212
"io"
13+
"reflect"
1314

1415
"go.mongodb.org/mongo-driver/bson"
1516
"go.mongodb.org/mongo-driver/bson/bsoncodec"
@@ -128,3 +129,68 @@ func (c *Cursor) Err() error { return c.err }
128129

129130
// Close closes this cursor.
130131
func (c *Cursor) Close(ctx context.Context) error { return c.bc.Close(ctx) }
132+
133+
// All iterates the cursor and decodes each document into results.
134+
// The results parameter must be a pointer to a slice. The slice pointed to by results will be completely overwritten.
135+
// If the cursor has been iterated, any previously iterated documents will not be included in results.
136+
func (c *Cursor) All(ctx context.Context, results interface{}) error {
137+
resultsVal := reflect.ValueOf(results)
138+
if resultsVal.Kind() != reflect.Ptr {
139+
return errors.New("results argument must be a pointer to a slice")
140+
}
141+
142+
sliceVal := resultsVal.Elem()
143+
elementType := sliceVal.Type().Elem()
144+
var index int
145+
var err error
146+
147+
batch := c.batch // exhaust the current batch before iterating the batch cursor
148+
for {
149+
sliceVal, index, err = c.addFromBatch(sliceVal, elementType, batch, index)
150+
if err != nil {
151+
return err
152+
}
153+
154+
if !c.bc.Next(ctx) {
155+
break
156+
}
157+
158+
batch = c.bc.Batch()
159+
}
160+
161+
if err = c.bc.Err(); err != nil {
162+
return err
163+
}
164+
165+
resultsVal.Elem().Set(sliceVal.Slice(0, index))
166+
return nil
167+
}
168+
169+
// addFromBatch adds all documents from batch to sliceVal starting at the given index. It returns the new slice value,
170+
// the next empty index in the slice, and an error if one occurs.
171+
func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, batch *bsoncore.DocumentSequence,
172+
index int) (reflect.Value, int, error) {
173+
174+
docs, err := batch.Documents()
175+
if err != nil {
176+
return sliceVal, index, err
177+
}
178+
179+
for _, doc := range docs {
180+
if sliceVal.Len() == index {
181+
// slice is full
182+
newElem := reflect.New(elemType)
183+
sliceVal = reflect.Append(sliceVal, newElem.Elem())
184+
sliceVal = sliceVal.Slice(0, sliceVal.Cap())
185+
}
186+
187+
currElem := sliceVal.Index(index).Addr().Interface()
188+
if err = bson.UnmarshalWithRegistry(c.registry, doc, currElem); err != nil {
189+
return sliceVal, index, err
190+
}
191+
192+
index++
193+
}
194+
195+
return sliceVal, index, nil
196+
}

mongo/cursor_test.go

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,137 @@
11
package mongo
22

3-
import "testing"
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
"go.mongodb.org/mongo-driver/bson"
9+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
10+
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
11+
)
12+
13+
type testBatchCursor struct {
14+
batches []*bsoncore.DocumentSequence
15+
batch *bsoncore.DocumentSequence
16+
}
17+
18+
func newTestBatchCursor(numBatches, batchSize int) *testBatchCursor {
19+
batches := make([]*bsoncore.DocumentSequence, 0, numBatches)
20+
21+
counter := 0
22+
for batch := 0; batch < numBatches; batch++ {
23+
var docSequence []byte
24+
25+
for doc := 0; doc < batchSize; doc++ {
26+
var elem []byte
27+
elem = bsoncore.AppendInt32Element(elem, "foo", int32(counter))
28+
counter++
29+
30+
var doc []byte
31+
doc = bsoncore.BuildDocumentFromElements(doc, elem)
32+
docSequence = append(docSequence, doc...)
33+
}
34+
35+
batches = append(batches, &bsoncore.DocumentSequence{
36+
Style: bsoncore.SequenceStyle,
37+
Data: docSequence,
38+
})
39+
}
40+
41+
return &testBatchCursor{
42+
batches: batches,
43+
}
44+
}
45+
46+
func (tbc *testBatchCursor) ID() int64 {
47+
if len(tbc.batches) == 0 {
48+
return 0 // cursor exhausted
49+
}
50+
51+
return 10
52+
}
53+
54+
func (tbc *testBatchCursor) Next(context.Context) bool {
55+
if len(tbc.batches) == 0 {
56+
return false
57+
}
58+
59+
tbc.batch = tbc.batches[0]
60+
tbc.batches = tbc.batches[1:]
61+
return true
62+
}
63+
64+
func (tbc *testBatchCursor) Batch() *bsoncore.DocumentSequence {
65+
return tbc.batch
66+
}
67+
68+
func (tbc *testBatchCursor) Server() *topology.Server {
69+
return nil
70+
}
71+
72+
func (tbc *testBatchCursor) Err() error {
73+
return nil
74+
}
75+
76+
func (tbc *testBatchCursor) Close(context.Context) error {
77+
return nil
78+
}
479

580
func TestCursor(t *testing.T) {
681
t.Run("loops until docs available", func(t *testing.T) {})
782
t.Run("returns false on context cancellation", func(t *testing.T) {})
883
t.Run("returns false if error occurred", func(t *testing.T) {})
984
t.Run("returns false if ID is zero and no more docs", func(t *testing.T) {})
85+
86+
t.Run("TestAll", func(t *testing.T) {
87+
t.Run("errors if argument is not pointer to slice", func(t *testing.T) {
88+
cursor, err := newCursor(newTestBatchCursor(1, 5), nil)
89+
require.Nil(t, err)
90+
err = cursor.All(context.Background(), []bson.D{})
91+
require.NotNil(t, err)
92+
})
93+
94+
t.Run("fills slice with all documents", func(t *testing.T) {
95+
cursor, err := newCursor(newTestBatchCursor(1, 5), nil)
96+
require.Nil(t, err)
97+
98+
var docs []bson.D
99+
err = cursor.All(context.Background(), &docs)
100+
require.Nil(t, err)
101+
require.Equal(t, 5, len(docs))
102+
103+
for index, doc := range docs {
104+
require.Equal(t, doc, bson.D{{"foo", int32(index)}})
105+
}
106+
})
107+
108+
t.Run("decodes each document into slice type", func(t *testing.T) {
109+
cursor, err := newCursor(newTestBatchCursor(1, 5), nil)
110+
require.Nil(t, err)
111+
112+
type Document struct {
113+
Foo int32 `bson:"foo"`
114+
}
115+
var docs []Document
116+
err = cursor.All(context.Background(), &docs)
117+
require.Nil(t, err)
118+
require.Equal(t, 5, len(docs))
119+
120+
for index, doc := range docs {
121+
require.Equal(t, doc, Document{Foo: int32(index)})
122+
}
123+
})
124+
125+
t.Run("multiple batches are included", func(t *testing.T) {
126+
cursor, err := newCursor(newTestBatchCursor(2, 5), nil)
127+
var docs []bson.D
128+
err = cursor.All(context.Background(), &docs)
129+
require.Nil(t, err)
130+
require.Equal(t, 10, len(docs))
131+
132+
for index, doc := range docs {
133+
require.Equal(t, doc, bson.D{{"foo", int32(index)}})
134+
}
135+
})
136+
})
10137
}

x/bsonx/bsoncore/document_sequence.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ func (ds *DocumentSequence) ResetIterator() {
8383
ds.Pos = 0
8484
}
8585

86-
// documents returns a slice of the documents. If nil either the Data field is also nil or could not
86+
// Documents returns a slice of the documents. If nil either the Data field is also nil or could not
8787
// be properly read.
88-
func (ds *DocumentSequence) documents() ([]Document, error) {
88+
func (ds *DocumentSequence) Documents() ([]Document, error) {
8989
if ds == nil {
9090
return nil, nil
9191
}

x/bsonx/bsoncore/document_sequence_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func TestDocumentSequence(t *testing.T) {
103103
Style: tc.style,
104104
Data: tc.data,
105105
}
106-
documents, err := ds.documents()
106+
documents, err := ds.Documents()
107107
if !cmp.Equal(documents, tc.documents) {
108108
t.Errorf("Documents do not match. got %v; want %v", documents, tc.documents)
109109
}
@@ -252,7 +252,7 @@ func TestDocumentSequence(t *testing.T) {
252252
Style: tc.style,
253253
Data: tc.data,
254254
}
255-
docs, err := ds.documents()
255+
docs, err := ds.Documents()
256256
if err != nil {
257257
t.Fatalf("Unexpected error: %v", err)
258258
}
@@ -376,7 +376,7 @@ func TestDocumentSequence(t *testing.T) {
376376
t.Run("Documents", func(t *testing.T) {
377377
defer capturePanic()
378378
var ds *DocumentSequence
379-
_, _ = ds.documents()
379+
_, _ = ds.Documents()
380380
})
381381
t.Run("Next", func(t *testing.T) {
382382
defer capturePanic()

0 commit comments

Comments
 (0)