Skip to content

Commit 2fb348b

Browse files
committed
Disallow nil documents and filters in mongo pkg
GODRIVER 659 Change-Id: Ie85acc9d8ee7da7128b546bfbb95aebbb15e361f
1 parent 8051092 commit 2fb348b

9 files changed

+183
-57
lines changed

mongo/client_internal_test.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package mongo
88

99
import (
1010
"context"
11+
"errors"
1112
"os"
1213
"path"
1314
"testing"
@@ -141,7 +142,7 @@ func TestClientRegistryPassedToCursors(t *testing.T) {
141142
_, err = coll.InsertOne(ctx, NewCodec{ID: 10})
142143
require.NoError(t, err)
143144

144-
c, err := coll.Find(ctx, nil)
145+
c, err := coll.Find(ctx, bsonx.Doc{})
145146
require.NoError(t, err)
146147

147148
require.True(t, c.Next(ctx))
@@ -285,7 +286,7 @@ func TestClient_ReplaceTopologyError(t *testing.T) {
285286
_, err = c.StartSession()
286287
require.Equal(t, err, ErrClientDisconnected)
287288

288-
_, err = c.ListDatabases(ctx, nil)
289+
_, err = c.ListDatabases(ctx, bsonx.Doc{})
289290
require.Equal(t, err, ErrClientDisconnected)
290291

291292
err = c.Ping(ctx, nil)
@@ -314,7 +315,7 @@ func TestClient_ListDatabases_noFilter(t *testing.T) {
314315
)
315316
require.NoError(t, err)
316317

317-
dbs, err := c.ListDatabases(context.Background(), nil)
318+
dbs, err := c.ListDatabases(context.Background(), bsonx.Doc{})
318319
require.NoError(t, err)
319320
found := false
320321

@@ -378,7 +379,7 @@ func TestClient_ListDatabaseNames_noFilter(t *testing.T) {
378379
)
379380
require.NoError(t, err)
380381

381-
dbs, err := c.ListDatabaseNames(context.Background(), nil)
382+
dbs, err := c.ListDatabaseNames(context.Background(), bsonx.Doc{})
382383
found := false
383384

384385
for _, name := range dbs {
@@ -421,6 +422,21 @@ func TestClient_ListDatabaseNames_filter(t *testing.T) {
421422
require.Equal(t, dbName, dbs[0])
422423
}
423424

425+
func TestClient_NilDocumentError(t *testing.T) {
426+
t.Parallel()
427+
428+
c := createTestClient(t)
429+
430+
_, err := c.Watch(context.Background(), nil)
431+
require.Equal(t, err, errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid"))
432+
433+
_, err = c.ListDatabases(context.Background(), nil)
434+
require.Equal(t, err, ErrNilDocument)
435+
436+
_, err = c.ListDatabaseNames(context.Background(), nil)
437+
require.Equal(t, err, ErrNilDocument)
438+
}
439+
424440
func TestClient_ReadPreference(t *testing.T) {
425441
t.Parallel()
426442

mongo/collection.go

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel,
149149
opts ...*options.BulkWriteOptions) (*BulkWriteResult, error) {
150150

151151
if len(models) == 0 {
152-
return nil, errors.New("a bulk write must contain at least one write model")
152+
return nil, ErrEmptySlice
153153
}
154154

155155
if ctx == nil {
@@ -165,6 +165,9 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel,
165165

166166
dispatchModels := make([]driver.WriteModel, len(models))
167167
for i, model := range models {
168+
if model == nil {
169+
return nil, ErrNilDocument
170+
}
168171
dispatchModels[i] = model.convertModel()
169172
}
170173

@@ -271,10 +274,17 @@ func (coll *Collection) InsertMany(ctx context.Context, documents []interface{},
271274
ctx = context.Background()
272275
}
273276

277+
if len(documents) == 0 {
278+
return nil, ErrEmptySlice
279+
}
280+
274281
result := make([]interface{}, len(documents))
275282
docs := make([]bsonx.Doc, len(documents))
276283

277284
for i, doc := range documents {
285+
if doc == nil {
286+
return nil, ErrNilDocument
287+
}
278288
bdoc, insertedID, err := transformAndEnsureID(coll.registry, doc)
279289
if err != nil {
280290
return nil, err
@@ -879,13 +889,9 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i
879889
ctx = context.Background()
880890
}
881891

882-
var f bsonx.Doc
883-
var err error
884-
if filter != nil {
885-
f, err = transformDocument(coll.registry, filter)
886-
if err != nil {
887-
return nil, err
888-
}
892+
f, err := transformDocument(coll.registry, filter)
893+
if err != nil {
894+
return nil, err
889895
}
890896

891897
sess := sessionFromContext(ctx)
@@ -934,13 +940,9 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
934940
ctx = context.Background()
935941
}
936942

937-
var f bsonx.Doc
938-
var err error
939-
if filter != nil {
940-
f, err = transformDocument(coll.registry, filter)
941-
if err != nil {
942-
return nil, err
943-
}
943+
f, err := transformDocument(coll.registry, filter)
944+
if err != nil {
945+
return nil, err
944946
}
945947

946948
sess := sessionFromContext(ctx)
@@ -986,13 +988,9 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{},
986988
ctx = context.Background()
987989
}
988990

989-
var f bsonx.Doc
990-
var err error
991-
if filter != nil {
992-
f, err = transformDocument(coll.registry, filter)
993-
if err != nil {
994-
return &SingleResult{err: err}
995-
}
991+
f, err := transformDocument(coll.registry, filter)
992+
if err != nil {
993+
return &SingleResult{err: err}
996994
}
997995

998996
sess := sessionFromContext(ctx)
@@ -1065,13 +1063,9 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}
10651063
ctx = context.Background()
10661064
}
10671065

1068-
var f bsonx.Doc
1069-
var err error
1070-
if filter != nil {
1071-
f, err = transformDocument(coll.registry, filter)
1072-
if err != nil {
1073-
return &SingleResult{err: err}
1074-
}
1066+
f, err := transformDocument(coll.registry, filter)
1067+
if err != nil {
1068+
return &SingleResult{err: err}
10751069
}
10761070

10771071
sess := sessionFromContext(ctx)

mongo/collection_internal_test.go

Lines changed: 102 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package mongo
88

99
import (
1010
"context"
11+
"errors"
1112
"fmt"
1213
"os"
1314
"testing"
@@ -200,16 +201,16 @@ func TestCollection_ReplaceTopologyError(t *testing.T) {
200201
_, err = coll.Aggregate(context.Background(), pipeline, options.Aggregate())
201202
require.Equal(t, err, ErrClientDisconnected)
202203

203-
_, err = coll.Count(context.Background(), nil)
204+
_, err = coll.Count(context.Background(), bsonx.Doc{})
204205
require.Equal(t, err, ErrClientDisconnected)
205206

206-
_, err = coll.CountDocuments(context.Background(), nil)
207+
_, err = coll.CountDocuments(context.Background(), bsonx.Doc{})
207208
require.Equal(t, err, ErrClientDisconnected)
208209

209210
_, err = coll.EstimatedDocumentCount(context.Background())
210211
require.Equal(t, err, ErrClientDisconnected)
211212

212-
_, err = coll.Distinct(context.Background(), "x", nil)
213+
_, err = coll.Distinct(context.Background(), "x", bsonx.Doc{})
213214
require.Equal(t, err, ErrClientDisconnected)
214215

215216
_, err = coll.Find(context.Background(), doc1)
@@ -319,6 +320,95 @@ func TestCollection_InsertOne_WriteConcernError(t *testing.T) {
319320
}
320321
}
321322

323+
func TestCollection_NilDocumentError(t *testing.T) {
324+
if testing.Short() {
325+
t.Skip("skipping integration test in short mode")
326+
}
327+
328+
coll := createTestCollection(t, nil, nil)
329+
330+
_, err := coll.InsertOne(context.Background(), nil)
331+
require.Equal(t, err, ErrNilDocument)
332+
333+
_, err = coll.InsertMany(context.Background(), nil)
334+
require.Equal(t, err, ErrEmptySlice)
335+
336+
_, err = coll.InsertMany(context.Background(), []interface{}{})
337+
require.Equal(t, err, ErrEmptySlice)
338+
339+
_, err = coll.InsertMany(context.Background(), []interface{}{bsonx.Doc{bsonx.Elem{"_id", bsonx.Int32(1)}}, nil})
340+
require.Equal(t, err, ErrNilDocument)
341+
342+
_, err = coll.DeleteOne(context.Background(), nil)
343+
require.Equal(t, err, ErrNilDocument)
344+
345+
_, err = coll.DeleteMany(context.Background(), nil)
346+
require.Equal(t, err, ErrNilDocument)
347+
348+
_, err = coll.UpdateOne(context.Background(), nil, bsonx.Doc{{"$set", bsonx.Document(bsonx.Doc{{"_id", bsonx.Double(3.14159)}})}})
349+
require.Equal(t, err, ErrNilDocument)
350+
351+
_, err = coll.UpdateOne(context.Background(), bsonx.Doc{{"_id", bsonx.Double(3.14159)}}, nil)
352+
require.Equal(t, err, ErrNilDocument)
353+
354+
_, err = coll.UpdateMany(context.Background(), nil, bsonx.Doc{{"$set", bsonx.Document(bsonx.Doc{{"_id", bsonx.Double(3.14159)}})}})
355+
require.Equal(t, err, ErrNilDocument)
356+
357+
_, err = coll.UpdateMany(context.Background(), bsonx.Doc{{"_id", bsonx.Double(3.14159)}}, nil)
358+
require.Equal(t, err, ErrNilDocument)
359+
360+
_, err = coll.ReplaceOne(context.Background(), bsonx.Doc{{"_id", bsonx.Double(3.14159)}}, nil)
361+
require.Equal(t, err, ErrNilDocument)
362+
363+
_, err = coll.ReplaceOne(context.Background(), nil, bsonx.Doc{{"_id", bsonx.Double(3.14159)}})
364+
require.Equal(t, err, ErrNilDocument)
365+
366+
_, err = coll.Count(context.Background(), nil)
367+
require.Equal(t, err, ErrNilDocument)
368+
369+
_, err = coll.CountDocuments(context.Background(), nil)
370+
require.Equal(t, err, ErrNilDocument)
371+
372+
_, err = coll.Distinct(context.Background(), "field", nil)
373+
require.Equal(t, err, ErrNilDocument)
374+
375+
_, err = coll.Find(context.Background(), nil)
376+
require.Equal(t, err, ErrNilDocument)
377+
378+
res := coll.FindOne(context.Background(), nil)
379+
require.Equal(t, res.err, ErrNilDocument)
380+
381+
res = coll.FindOneAndDelete(context.Background(), nil)
382+
require.Equal(t, res.err, ErrNilDocument)
383+
384+
res = coll.FindOneAndReplace(context.Background(), bsonx.Doc{{"_id", bsonx.Double(3.14159)}}, nil)
385+
require.Equal(t, res.err, ErrNilDocument)
386+
387+
res = coll.FindOneAndReplace(context.Background(), nil, bsonx.Doc{{"_id", bsonx.Double(3.14159)}})
388+
require.Equal(t, res.err, ErrNilDocument)
389+
390+
res = coll.FindOneAndUpdate(context.Background(), bsonx.Doc{{"_id", bsonx.Double(3.14159)}}, nil)
391+
require.Equal(t, res.err, ErrNilDocument)
392+
393+
res = coll.FindOneAndUpdate(context.Background(), nil, bsonx.Doc{{"_id", bsonx.Double(3.14159)}})
394+
require.Equal(t, res.err, ErrNilDocument)
395+
396+
_, err = coll.BulkWrite(context.Background(), nil)
397+
require.Equal(t, err, ErrEmptySlice)
398+
399+
_, err = coll.BulkWrite(context.Background(), []WriteModel{})
400+
require.Equal(t, err, ErrEmptySlice)
401+
402+
_, err = coll.BulkWrite(context.Background(), []WriteModel{nil})
403+
require.Equal(t, err, ErrNilDocument)
404+
405+
_, err = coll.Aggregate(context.Background(), nil)
406+
require.Equal(t, err, errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid"))
407+
408+
_, err = coll.Watch(context.Background(), nil)
409+
require.Equal(t, err, errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid"))
410+
}
411+
322412
func TestCollection_InsertMany(t *testing.T) {
323413
if testing.Short() {
324414
t.Skip("skipping integration test in short mode")
@@ -1215,7 +1305,7 @@ func TestCollection_Count(t *testing.T) {
12151305
coll := createTestCollection(t, nil, nil)
12161306
initCollection(t, coll)
12171307

1218-
count, err := coll.Count(context.Background(), nil)
1308+
count, err := coll.Count(context.Background(), bsonx.Doc{})
12191309
require.Nil(t, err)
12201310
require.Equal(t, count, int64(5))
12211311
}
@@ -1243,7 +1333,7 @@ func TestCollection_Count_withOption(t *testing.T) {
12431333
coll := createTestCollection(t, nil, nil)
12441334
initCollection(t, coll)
12451335

1246-
count, err := coll.Count(context.Background(), nil, options.Count().SetLimit(int64(3)))
1336+
count, err := coll.Count(context.Background(), bsonx.Doc{}, options.Count().SetLimit(int64(3)))
12471337
require.Nil(t, err)
12481338
require.Equal(t, count, int64(3))
12491339
}
@@ -1256,7 +1346,7 @@ func TestCollection_CountDocuments(t *testing.T) {
12561346
col1 := createTestCollection(t, nil, nil)
12571347
initCollection(t, col1)
12581348

1259-
count, err := col1.CountDocuments(context.Background(), nil)
1349+
count, err := col1.CountDocuments(context.Background(), bsonx.Doc{})
12601350
require.Nil(t, err)
12611351
require.Equal(t, count, int64(5))
12621352
}
@@ -1285,7 +1375,7 @@ func TestCollection_CountDocuments_withLimitOptions(t *testing.T) {
12851375
coll := createTestCollection(t, nil, nil)
12861376
initCollection(t, coll)
12871377

1288-
count, err := coll.CountDocuments(context.Background(), nil, options.Count().SetLimit(3))
1378+
count, err := coll.CountDocuments(context.Background(), bsonx.Doc{}, options.Count().SetLimit(3))
12891379
require.Nil(t, err)
12901380
require.Equal(t, count, int64(3))
12911381
}
@@ -1298,7 +1388,7 @@ func TestCollection_CountDocuments_withSkipOptions(t *testing.T) {
12981388
coll := createTestCollection(t, nil, nil)
12991389
initCollection(t, coll)
13001390

1301-
count, err := coll.CountDocuments(context.Background(), nil, options.Count().SetSkip(3))
1391+
count, err := coll.CountDocuments(context.Background(), bsonx.Doc{}, options.Count().SetSkip(3))
13021392
require.Nil(t, err)
13031393
require.Equal(t, count, int64(2))
13041394
}
@@ -1338,7 +1428,7 @@ func TestCollection_Distinct(t *testing.T) {
13381428
coll := createTestCollection(t, nil, nil)
13391429
initCollection(t, coll)
13401430

1341-
results, err := coll.Distinct(context.Background(), "x", nil)
1431+
results, err := coll.Distinct(context.Background(), "x", bsonx.Doc{})
13421432
require.Nil(t, err)
13431433
require.Equal(t, results, []interface{}{int32(1), int32(2), int32(3), int32(4), int32(5)})
13441434
}
@@ -1366,7 +1456,7 @@ func TestCollection_Distinct_withOption(t *testing.T) {
13661456
coll := createTestCollection(t, nil, nil)
13671457
initCollection(t, coll)
13681458

1369-
results, err := coll.Distinct(context.Background(), "x", nil,
1459+
results, err := coll.Distinct(context.Background(), "x", bsonx.Doc{},
13701460
options.Distinct().SetMaxTime(5000000000))
13711461
require.Nil(t, err)
13721462
require.Equal(t, results, []interface{}{int32(1), int32(2), int32(3), int32(4), int32(5)})
@@ -1381,7 +1471,7 @@ func TestCollection_Find_found(t *testing.T) {
13811471
initCollection(t, coll)
13821472

13831473
cursor, err := coll.Find(context.Background(),
1384-
nil,
1474+
bsonx.Doc{},
13851475
options.Find().SetSort(bsonx.Doc{{"x", bsonx.Int32(1)}}),
13861476
)
13871477
require.Nil(t, err)
@@ -1466,7 +1556,7 @@ func TestCollection_Find_Error(t *testing.T) {
14661556
t.Run("TestKillCursor", func(t *testing.T) {
14671557
coll := createTestCollection(t, nil, nil)
14681558
initCollection(t, coll)
1469-
c, err := coll.Find(context.Background(), nil, options.Find().SetBatchSize(2))
1559+
c, err := coll.Find(context.Background(), bsonx.Doc{}, options.Find().SetBatchSize(2))
14701560
require.Nil(t, err, "error running find: %s", err)
14711561

14721562
// exhaust first batch

mongo/crud_util_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ func verifyRunCommandResult(t *testing.T, res bson.Raw, result json.RawMessage)
727727
}
728728

729729
func verifyCollectionContents(t *testing.T, coll *Collection, result json.RawMessage) {
730-
cursor, err := coll.Find(context.Background(), nil)
730+
cursor, err := coll.Find(context.Background(), bsonx.Doc{})
731731
require.NoError(t, err)
732732

733733
verifyCursorResult(t, cursor, result)

0 commit comments

Comments
 (0)