Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 66 additions & 60 deletions mongo/bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,44 +333,39 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera

switch converted := model.(type) {
case *ReplaceOneModel:
doc, err = createUpdateDoc(
converted.Filter,
converted.Replacement,
converted.Hint,
nil,
converted.Collation,
converted.Upsert,
false,
false,
bw.collection.bsonOpts,
bw.collection.registry)
doc, err = updateDoc{
filter: converted.Filter,
update: converted.Replacement,
hint: converted.Hint,
sort: converted.Sort,
collation: converted.Collation,
upsert: converted.Upsert,
}.marshal(bw.collection.bsonOpts, bw.collection.registry)
hasHint = hasHint || (converted.Hint != nil)
case *UpdateOneModel:
doc, err = createUpdateDoc(
converted.Filter,
converted.Update,
converted.Hint,
converted.ArrayFilters,
converted.Collation,
converted.Upsert,
false,
true,
bw.collection.bsonOpts,
bw.collection.registry)
doc, err = updateDoc{
filter: converted.Filter,
update: converted.Update,
hint: converted.Hint,
sort: converted.Sort,
arrayFilters: converted.ArrayFilters,
collation: converted.Collation,
upsert: converted.Upsert,
checkDollarKey: true,
}.marshal(bw.collection.bsonOpts, bw.collection.registry)
hasHint = hasHint || (converted.Hint != nil)
hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
case *UpdateManyModel:
doc, err = createUpdateDoc(
converted.Filter,
converted.Update,
converted.Hint,
converted.ArrayFilters,
converted.Collation,
converted.Upsert,
true,
true,
bw.collection.bsonOpts,
bw.collection.registry)
doc, err = updateDoc{
filter: converted.Filter,
update: converted.Update,
hint: converted.Hint,
arrayFilters: converted.ArrayFilters,
collation: converted.Collation,
upsert: converted.Upsert,
multi: true,
checkDollarKey: true,
}.marshal(bw.collection.bsonOpts, bw.collection.registry)
hasHint = hasHint || (converted.Hint != nil)
hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
}
Expand Down Expand Up @@ -420,62 +415,73 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera
return op.Result(), err
}

func createUpdateDoc(
filter interface{},
update interface{},
hint interface{},
arrayFilters *options.ArrayFilters,
collation *options.Collation,
upsert *bool,
multi bool,
checkDollarKey bool,
bsonOpts *options.BSONOptions,
registry *bsoncodec.Registry,
) (bsoncore.Document, error) {
f, err := marshal(filter, bsonOpts, registry)
type updateDoc struct {
filter interface{}
update interface{}
hint interface{}
sort interface{}
arrayFilters *options.ArrayFilters
collation *options.Collation
upsert *bool
multi bool
checkDollarKey bool
}

func (doc updateDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (bsoncore.Document, error) {
f, err := marshal(doc.filter, bsonOpts, registry)
if err != nil {
return nil, err
}

uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f)

u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey)
u, err := marshalUpdateValue(doc.update, bsonOpts, registry, doc.checkDollarKey)
if err != nil {
return nil, err
}

updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)

if multi {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
if doc.multi {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", doc.multi)
}
if doc.sort != nil {
if isUnorderedMap(doc.sort) {
return nil, ErrMapForOrderedArgument{"sort"}
}
s, err := marshal(doc.sort, bsonOpts, registry)
if err != nil {
return nil, err
}
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "sort", s)
}

if arrayFilters != nil {
if doc.arrayFilters != nil {
reg := registry
if arrayFilters.Registry != nil {
reg = arrayFilters.Registry
if doc.arrayFilters.Registry != nil {
reg = doc.arrayFilters.Registry
}
arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg)
arr, err := marshalValue(doc.arrayFilters.Filters, bsonOpts, reg)
if err != nil {
return nil, err
}
updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr.Data)
}

if collation != nil {
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument()))
if doc.collation != nil {
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(doc.collation.ToDocument()))
}

if upsert != nil {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert)
if doc.upsert != nil {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *doc.upsert)
}

if hint != nil {
if isUnorderedMap(hint) {
if doc.hint != nil {
if isUnorderedMap(doc.hint) {
return nil, ErrMapForOrderedArgument{"hint"}
}
hintVal, err := marshalValue(hint, bsonOpts, registry)
hintVal, err := marshalValue(doc.hint, bsonOpts, registry)
if err != nil {
return nil, err
}
Expand Down
18 changes: 18 additions & 0 deletions mongo/bulk_write_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ type ReplaceOneModel struct {
Filter interface{}
Replacement interface{}
Hint interface{}
Sort interface{}
}

// NewReplaceOneModel creates a new ReplaceOneModel.
Expand Down Expand Up @@ -173,6 +174,14 @@ func (rom *ReplaceOneModel) SetUpsert(upsert bool) *ReplaceOneModel {
return rom
}

// SetSort specifies which document the operation replaces if the query matches multiple documents. The first document
// matched by the sort order will be replaced. This option is only valid for MongoDB versions >= 8.0. The driver will
// return an error if the sort parameter is a multi-key map. The default value is nil.
func (rom *ReplaceOneModel) SetSort(sort interface{}) *ReplaceOneModel {
rom.Sort = &sort
return rom
}

func (*ReplaceOneModel) writeModel() {}

// UpdateOneModel is used to update at most one document in a BulkWrite operation.
Expand All @@ -183,6 +192,7 @@ type UpdateOneModel struct {
Update interface{}
ArrayFilters *options.ArrayFilters
Hint interface{}
Sort interface{}
}

// NewUpdateOneModel creates a new UpdateOneModel.
Expand Down Expand Up @@ -238,6 +248,14 @@ func (uom *UpdateOneModel) SetUpsert(upsert bool) *UpdateOneModel {
return uom
}

// SetSort specifies which document the operation updates if the query matches multiple documents. The first document
// matched by the sort order will be updated. This option is only valid for MongoDB versions >= 8.0. The driver will
// return an error if the sort parameter is a multi-key map. The default value is nil.
func (uom *UpdateOneModel) SetSort(sort interface{}) *UpdateOneModel {
uom.Sort = sort
return uom
}

func (*UpdateOneModel) writeModel() {}

// UpdateManyModel is used to update multiple documents in a BulkWrite operation.
Expand Down
24 changes: 12 additions & 12 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,17 +548,17 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc

// collation, arrayFilters, upsert, and hint are included on the individual update documents rather than as part of the
// command
updateDoc, err := createUpdateDoc(
filter,
update,
uo.Hint,
uo.ArrayFilters,
uo.Collation,
uo.Upsert,
multi,
checkDollarKey,
coll.bsonOpts,
coll.registry)
updateDoc, err := updateDoc{
filter: filter,
update: update,
hint: uo.Hint,
sort: uo.Sort,
arrayFilters: uo.ArrayFilters,
collation: uo.Collation,
upsert: uo.Upsert,
multi: multi,
checkDollarKey: checkDollarKey,
}.marshal(coll.bsonOpts, coll.registry)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -598,7 +598,6 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc
}
op = op.Let(let)
}

if uo.BypassDocumentValidation != nil && *uo.BypassDocumentValidation {
op = op.BypassDocumentValidation(*uo.BypassDocumentValidation)
}
Expand Down Expand Up @@ -760,6 +759,7 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{},
uOpts.Hint = opt.Hint
uOpts.Let = opt.Let
uOpts.Comment = opt.Comment
uOpts.Sort = opt.Sort
updateOptions = append(updateOptions, uOpts)
}

Expand Down
11 changes: 11 additions & 0 deletions mongo/integration/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,17 @@ func TestCollection(t *testing.T) {
assert.Equal(mt, int64(0), res.ModifiedCount, "expected modified count 0, got %v", res.ModifiedCount)
assert.NotNil(mt, res.UpsertedID, "expected upserted ID, got nil")
})
// Require 8.0 servers for sort support.
mt.RunOpts("error with sort", mtest.NewOptions().MinServerVersion("8.0"), func(mt *mtest.T) {
filter := bson.D{{"x", bson.D{{"$gte", 3}}}}
update := bson.D{{"$inc", bson.D{{"x", 1}}}}

_, err := mt.Coll.UpdateMany(context.Background(),
filter, update,
&options.UpdateOptions{Sort: bson.D{{"_id", -1}}},
)
assert.ErrorContains(t, err, "Cannot specify sort with multi=true", "expected an error on UpdateMany with sort")
})
mt.Run("write error", func(mt *mtest.T) {
filter := bson.D{{"_id", "foo"}}
update := bson.D{{"$set", bson.D{{"_id", 3.14159}}}}
Expand Down
12 changes: 12 additions & 0 deletions mongo/integration/unified/bulkwrite_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) {
return nil, fmt.Errorf("error creating hint: %w", err)
}
uom.SetHint(hint)
case "sort":
sort, err := createSort(val)
if err != nil {
return nil, fmt.Errorf("error creating sort: %w", err)
}
uom.SetSort(sort)
case "update":
update, err = createUpdateValue(val)
if err != nil {
Expand Down Expand Up @@ -242,6 +248,12 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) {
return nil, fmt.Errorf("error creating hint: %w", err)
}
rom.SetHint(hint)
case "sort":
sort, err := createSort(val)
if err != nil {
return nil, fmt.Errorf("error creating sort: %w", err)
}
rom.SetSort(sort)
case "replacement":
replacement = val.Document()
case "upsert":
Expand Down
6 changes: 6 additions & 0 deletions mongo/integration/unified/collection_operation_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,12 @@ func executeReplaceOne(ctx context.Context, operation *operation) (*operationRes
return nil, fmt.Errorf("error creating hint: %w", err)
}
opts.SetHint(hint)
case "sort":
sort, err := createSort(val)
if err != nil {
return nil, fmt.Errorf("error creating sort: %w", err)
}
opts.SetSort(sort)
case "replacement":
replacement = val.Document()
case "upsert":
Expand Down
14 changes: 14 additions & 0 deletions mongo/integration/unified/crud_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func createUpdateArguments(args bson.Raw) (*updateArguments, error) {
ua.opts.SetHint(hint)
case "let":
ua.opts.SetLet(val.Document())
case "sort":
ua.opts.SetSort(val.Document())
case "update":
ua.update, err = createUpdateValue(val)
if err != nil {
Expand Down Expand Up @@ -160,6 +162,18 @@ func createHint(val bson.RawValue) (interface{}, error) {
return hint, nil
}

func createSort(val bson.RawValue) (interface{}, error) {
var sort interface{}

switch val.Type {
case bsontype.EmbeddedDocument:
sort = val.Document()
default:
return nil, fmt.Errorf("unrecognized sort value type %s", val.Type)
}
return sort, nil
}

func createCommentString(val bson.RawValue) (string, error) {
switch val.Type {
case bsontype.String:
Expand Down
15 changes: 15 additions & 0 deletions mongo/options/replaceoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ type ReplaceOptions struct {
// Values must be constant or closed expressions that do not reference document fields. Parameters can then be
// accessed as variables in an aggregate expression context (e.g. "$$var").
Let interface{}

// A document specifying which document should be replaced if the filter used by the operation matches multiple
// documents in the collection. If set, the first document in the sorted order will be replaced. This option is
// only valid for MongoDB versions >= 8.0. The driver will return an error if the sort parameter is a multi-key
// map. The default value is nil.
Sort interface{}
}

// Replace creates a new ReplaceOptions instance.
Expand Down Expand Up @@ -83,6 +89,12 @@ func (ro *ReplaceOptions) SetLet(l interface{}) *ReplaceOptions {
return ro
}

// SetSort sets the value for the Sort field.
func (ro *ReplaceOptions) SetSort(s interface{}) *ReplaceOptions {
ro.Sort = s
return ro
}

// MergeReplaceOptions combines the given ReplaceOptions instances into a single ReplaceOptions in a last-one-wins
// fashion.
//
Expand Down Expand Up @@ -112,6 +124,9 @@ func MergeReplaceOptions(opts ...*ReplaceOptions) *ReplaceOptions {
if ro.Let != nil {
rOpts.Let = ro.Let
}
if ro.Sort != nil {
rOpts.Sort = ro.Sort
}
}

return rOpts
Expand Down
Loading
Loading