Skip to content

Commit cd0e792

Browse files
author
Divjot Arora
committed
Fix type checks for update pipelines
GODRIVER-1231 Change-Id: Ie60d386efce586d0a271b487698fd0bfd13cd74c
1 parent b88fd0d commit cd0e792

File tree

3 files changed

+76
-20
lines changed

3 files changed

+76
-20
lines changed

mongo/collection.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,8 @@ func (coll *Collection) DeleteMany(ctx context.Context, filter interface{},
430430
return coll.delete(ctx, filter, false, rrMany, opts...)
431431
}
432432

433-
func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Document, update interface{}, multi bool, expectedRr returnResult,
434-
opts ...*options.UpdateOptions) (*UpdateResult, error) {
433+
func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Document, update interface{}, multi bool,
434+
expectedRr returnResult, checkDollarKey bool, opts ...*options.UpdateOptions) (*UpdateResult, error) {
435435

436436
if ctx == nil {
437437
ctx = context.Background()
@@ -441,16 +441,11 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc
441441
uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
442442
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", filter)
443443

444-
switch update.(type) {
445-
case bsoncore.Document:
446-
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "u", update.(bsoncore.Document))
447-
default:
448-
u, err := transformUpdateValue(coll.registry, update, true)
449-
if err != nil {
450-
return nil, err
451-
}
452-
updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
444+
u, err := transformUpdateValue(coll.registry, update, checkDollarKey)
445+
if err != nil {
446+
return nil, err
453447
}
448+
updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
454449
if multi {
455450
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
456451
}
@@ -482,7 +477,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc
482477
defer sess.EndSession()
483478
}
484479

485-
err := coll.client.validSession(sess)
480+
err = coll.client.validSession(sess)
486481
if err != nil {
487482
return nil, err
488483
}
@@ -546,7 +541,7 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat
546541
return nil, err
547542
}
548543

549-
return coll.updateOrReplace(ctx, f, update, false, rrOne, opts...)
544+
return coll.updateOrReplace(ctx, f, update, false, rrOne, true, opts...)
550545
}
551546

552547
// UpdateMany updates multiple documents in the collection.
@@ -562,7 +557,7 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda
562557
return nil, err
563558
}
564559

565-
return coll.updateOrReplace(ctx, f, update, true, rrMany, opts...)
560+
return coll.updateOrReplace(ctx, f, update, true, rrMany, true, opts...)
566561
}
567562

568563
// ReplaceOne replaces a single document in the collection.
@@ -596,7 +591,7 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{},
596591
updateOptions = append(updateOptions, uOpts)
597592
}
598593

599-
return coll.updateOrReplace(ctx, f, r, false, rrOne, updateOptions...)
594+
return coll.updateOrReplace(ctx, f, r, false, rrOne, false, updateOptions...)
600595
}
601596

602597
// Aggregate runs an aggregation framework pipeline.

mongo/collection_internal_test.go

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ import (
1212
"fmt"
1313
"os"
1414
"testing"
15-
16-
"go.mongodb.org/mongo-driver/mongo/options"
17-
"go.mongodb.org/mongo-driver/x/bsonx"
18-
1915
"time"
2016

2117
"github.com/google/go-cmp/cmp"
@@ -25,9 +21,12 @@ import (
2521
"go.mongodb.org/mongo-driver/bson/primitive"
2622
"go.mongodb.org/mongo-driver/event"
2723
"go.mongodb.org/mongo-driver/internal/testutil"
24+
"go.mongodb.org/mongo-driver/mongo/options"
2825
"go.mongodb.org/mongo-driver/mongo/readconcern"
2926
"go.mongodb.org/mongo-driver/mongo/readpref"
3027
"go.mongodb.org/mongo-driver/mongo/writeconcern"
28+
"go.mongodb.org/mongo-driver/x/bsonx"
29+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
3130
)
3231

3332
var impossibleWriteConcern = writeconcern.New(writeconcern.W(50), writeconcern.WTimeout(time.Second))
@@ -2221,3 +2220,44 @@ func TestCollection_BulkWrite(t *testing.T) {
22212220
}
22222221
})
22232222
}
2223+
2224+
// test special types that should be converted to a document for updates even though the underlying type is a
2225+
// slice/array
2226+
func TestCollection_Update_SpecialSliceTypes(t *testing.T) {
2227+
doc := bson.D{{"$set", bson.D{{"x", 2}}}}
2228+
docBytes, err := bson.Marshal(doc)
2229+
require.NoError(t, err, "error getting document bytes: %v", err)
2230+
xUpdate := bsonx.Doc{{"x", bsonx.Int32(2)}}
2231+
xDoc := bsonx.Doc{{"$set", bsonx.Document(xUpdate)}}
2232+
2233+
testCases := []struct {
2234+
name string
2235+
update interface{}
2236+
}{
2237+
{"bsoncore Document", bsoncore.Document(docBytes)},
2238+
{"bson Raw", bson.Raw(docBytes)},
2239+
{"bson D", doc},
2240+
{"bsonx Document", xDoc},
2241+
{"byte slice", docBytes},
2242+
}
2243+
2244+
for _, tc := range testCases {
2245+
t.Run(tc.name, func(t *testing.T) {
2246+
coll := createTestCollection(t, nil, &tc.name)
2247+
defer func() {
2248+
_ = coll.Drop(ctx)
2249+
}()
2250+
2251+
insertedDoc := bson.D{{"x", 1}}
2252+
_, err = coll.InsertOne(ctx, insertedDoc)
2253+
require.NoError(t, err, "error inserting document: %v", err)
2254+
2255+
res, err := coll.UpdateOne(ctx, insertedDoc, tc.update)
2256+
require.NoError(t, err, "error updating document: %v", err)
2257+
require.Equal(t, int64(1), res.MatchedCount,
2258+
"matched count mismatch; expected %d, got %d", 1, res.MatchedCount)
2259+
require.Equal(t, int64(1), res.ModifiedCount,
2260+
"modified count mismatch; expected %d, got %d", 1, res.ModifiedCount)
2261+
})
2262+
}
2263+
}

mongo/mongo.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,34 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, chec
339339
switch t := update.(type) {
340340
case nil:
341341
return u, ErrNilDocument
342-
case primitive.D:
342+
case primitive.D, bsonx.Doc:
343343
u.Type = bsontype.EmbeddedDocument
344344
u.Data, err = transformBsoncoreDocument(registry, update)
345345
if err != nil {
346346
return u, err
347347
}
348348

349+
if checkDocDollarKey {
350+
err = ensureDollarKeyv2(u.Data)
351+
}
352+
return u, err
353+
case bson.Raw:
354+
u.Type = bsontype.EmbeddedDocument
355+
u.Data = t
356+
if checkDocDollarKey {
357+
err = ensureDollarKeyv2(u.Data)
358+
}
359+
return u, err
360+
case bsoncore.Document:
361+
u.Type = bsontype.EmbeddedDocument
362+
u.Data = t
363+
if checkDocDollarKey {
364+
err = ensureDollarKeyv2(u.Data)
365+
}
366+
return u, err
367+
case []byte:
368+
u.Type = bsontype.EmbeddedDocument
369+
u.Data = t
349370
if checkDocDollarKey {
350371
err = ensureDollarKeyv2(u.Data)
351372
}

0 commit comments

Comments
 (0)