From 446c2e13b9179f35c7ecbac6016f768bbb68da7a Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 7 Oct 2025 09:59:53 -0400 Subject: [PATCH] GODRIVER-3370 Add bypassEmptyTsReplacement option. --- internal/integration/client_test.go | 86 ++++- internal/integration/collection_test.go | 382 ++++++++++++++++++++ mongo/bulk_write.go | 7 + mongo/client.go | 5 + mongo/client_bulk_write.go | 8 + mongo/collection.go | 34 +- x/mongo/driver/operation/find_and_modify.go | 18 + x/mongo/driver/operation/insert.go | 19 + x/mongo/driver/operation/update.go | 18 + x/mongo/driver/xoptions/options.go | 82 +++++ 10 files changed, 656 insertions(+), 3 deletions(-) diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index d37e0fb514..d54a525776 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -32,6 +32,7 @@ import ( "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/xoptions" "golang.org/x/sync/errgroup" ) @@ -818,6 +819,87 @@ func TestClient_BulkWrite(t *testing.T) { }) } }) + mt.RunOpts("bulk write with bypassEmptyTsReplacement", mtBulkWriteOpts, func(mt *mtest.T) { + mt.Parallel() + + newOpts := func(option bson.D) *options.ClientBulkWriteOptionsBuilder { + opts := options.ClientBulkWrite() + err := xoptions.SetInternalClientBulkWriteOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + marshalValue := func(val interface{}) bson.RawValue { + t.Helper() + + valType, data, err := bson.MarshalValue(val) + require.Nil(t, err, "MarshalValue error: %v", err) + return bson.RawValue{ + Type: valType, + Value: data, + } + } + + models := []struct { + name string + model mongo.ClientWriteModel + }{ + { + "insert one", + mongo.NewClientInsertOneModel().SetDocument(bson.D{{"x", 1}}), + }, + { + "update one", + mongo.NewClientUpdateOneModel().SetFilter(bson.D{{"x", 1}}).SetUpdate(bson.D{{"$set", bson.D{{"x", 3.14159}}}}), + }, + { + "update many", + mongo.NewClientUpdateManyModel().SetFilter(bson.D{{"x", 1}}).SetUpdate(bson.D{{"$set", bson.D{{"x", 3.14159}}}}), + }, + { + "replace one", + mongo.NewClientReplaceOneModel().SetFilter(bson.D{{"x", 1}}).SetReplacement(bson.D{{"x", 3.14159}}), + }, + } + + testCases := []struct { + name string + opts *options.ClientBulkWriteOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + options.ClientBulkWrite(), + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, m := range models { + for _, tc := range testCases { + mt.Run(fmt.Sprintf("%s %s", m.name, tc.name), func(mt *mtest.T) { + writes := []mongo.ClientBulkWrite{{ + Database: "foo", + Collection: "bar", + Model: m.model, + }} + _, err := mt.Client.BulkWrite(context.Background(), writes, tc.opts) + require.NoError(mt, err, "BulkWrite error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + } + }) var bulkWrites int cmdMonitor := &event.CommandMonitor{ Started: func(_ context.Context, evt *event.CommandStartedEvent) { @@ -838,8 +920,8 @@ func TestClient_BulkWrite(t *testing.T) { } _, err := mt.Client.BulkWrite(context.Background(), writes) - require.NoError(t, err) - assert.Equal(t, 2, bulkWrites, "expected %d bulkWrites, got %d", 2, bulkWrites) + require.NoError(mt, err) + assert.Equal(mt, 2, bulkWrites, "expected %d bulkWrites, got %d", 2, bulkWrites) }) } diff --git a/internal/integration/collection_test.go b/internal/integration/collection_test.go index c55632022f..8d892c4cd8 100644 --- a/internal/integration/collection_test.go +++ b/internal/integration/collection_test.go @@ -9,6 +9,7 @@ package integration import ( "context" "errors" + "fmt" "strings" "testing" @@ -22,6 +23,7 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/xoptions" ) const ( @@ -2028,6 +2030,386 @@ func TestCollection(t *testing.T) { }) } +func TestBypassEmptyTsReplacement(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("5.0")) + + marshalValue := func(val interface{}) bson.RawValue { + t.Helper() + + valType, data, err := bson.MarshalValue(val) + require.NoError(t, err, "MarshalValue error: %v", err) + return bson.RawValue{ + Type: valType, + Value: data, + } + } + + mt.Run("insert one", func(mt *mtest.T) { + doc := bson.D{{"x", 42}} + + newOpts := func(option bson.D) *options.InsertOneOptionsBuilder { + opts := options.InsertOne() + err := xoptions.SetInternalInsertOneOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + testCases := []struct { + name string + opts *options.InsertOneOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + nil, + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + _, err := mt.Coll.InsertOne(context.Background(), doc, tc.opts) + require.NoError(mt, err, "InsertOne error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s in %v", tc.expected.String()) + }) + } + }) + mt.Run("insert many", func(mt *mtest.T) { + docs := []interface{}{ + bson.D{{"x", 42}}, + bson.D{{"y", "foo"}}, + } + + newOpts := func(option bson.D) *options.InsertManyOptionsBuilder { + opts := options.InsertMany() + err := xoptions.SetInternalInsertManyOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + testCases := []struct { + name string + opts *options.InsertManyOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + nil, + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + _, err := mt.Coll.InsertMany(context.Background(), docs, tc.opts) + require.NoError(mt, err, "InsertMany error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + }) + mt.Run("update one", func(mt *mtest.T) { + filter := bson.D{{"x", 42}} + update := bson.D{{"$inc", bson.D{{"x", 1}}}} + + newOpts := func(option bson.D) *options.UpdateOneOptionsBuilder { + opts := options.UpdateOne() + err := xoptions.SetInternalUpdateOneOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + testCases := []struct { + name string + opts *options.UpdateOneOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + nil, + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + _, err := mt.Coll.UpdateOne(context.Background(), filter, update, tc.opts) + require.NoError(mt, err, "UpdateOne error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + }) + mt.Run("update many", func(mt *mtest.T) { + filter := bson.D{{"x", 42}} + update := bson.D{{"$inc", bson.D{{"x", 1}}}} + + newOpts := func(option bson.D) *options.UpdateManyOptionsBuilder { + opts := options.UpdateMany() + err := xoptions.SetInternalUpdateManyOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + testCases := []struct { + name string + opts *options.UpdateManyOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + nil, + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + _, err := mt.Coll.UpdateMany(context.Background(), filter, update, tc.opts) + require.NoError(mt, err, "UpdateMany error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + }) + mt.Run("replace one", func(mt *mtest.T) { + filter := bson.D{{"x", 42}} + replacement := bson.D{{"y", "foo"}} + + newOpts := func(option bson.D) *options.ReplaceOptionsBuilder { + opts := options.Replace() + err := xoptions.SetInternalReplaceOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + testCases := []struct { + name string + opts *options.ReplaceOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + nil, + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + _, err := mt.Coll.ReplaceOne(context.Background(), filter, replacement, tc.opts) + require.NoError(mt, err, "ReplaceOne error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + }) + mt.Run("find one and update", func(mt *mtest.T) { + filter := bson.D{{"x", 1}} + update := bson.D{{"$inc", bson.D{{"x", 1}}}} + + newOpts := func(option bson.D) *options.FindOneAndUpdateOptionsBuilder { + opts := options.FindOneAndUpdate() + err := xoptions.SetInternalFindOneAndUpdateOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + testCases := []struct { + name string + opts *options.FindOneAndUpdateOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + nil, + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + mt.ClearEvents() + + _, err := mt.Coll.FindOneAndUpdate(context.Background(), filter, update, tc.opts).Raw() + require.NoError(mt, err, "FindOneAndUpdate error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + }) + mt.Run("find one and replace", func(mt *mtest.T) { + filter := bson.D{{"x", 1}} + replacement := bson.D{{"y", "foo"}} + + newOpts := func(option bson.D) *options.FindOneAndReplaceOptionsBuilder { + opts := options.FindOneAndReplace() + err := xoptions.SetInternalFindOneAndReplaceOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + testCases := []struct { + name string + opts *options.FindOneAndReplaceOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + nil, + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, tc := range testCases { + mt.Run(tc.name, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + mt.ClearEvents() + + _, err := mt.Coll.FindOneAndReplace(context.Background(), filter, replacement, tc.opts).Raw() + require.NoError(mt, err, "FindOneAndReplace error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + }) + mt.Run("bulk write", func(mt *mtest.T) { + newOpts := func(option bson.D) *options.BulkWriteOptionsBuilder { + opts := options.BulkWrite() + err := xoptions.SetInternalBulkWriteOptions(opts, "addCommandFields", option) + require.NoError(mt, err, "unexpected error: %v", err) + return opts + } + + models := []struct { + name string + model mongo.WriteModel + }{ + { + "insert one", + mongo.NewInsertOneModel().SetDocument(bson.D{{"_id", "id1"}}), + }, + { + "update one", + mongo.NewUpdateOneModel().SetFilter(bson.D{{"_id", "id3"}}).SetUpdate(bson.D{{"$set", bson.D{{"_id", 3.14159}}}}), + }, + { + "update many", + mongo.NewUpdateManyModel().SetFilter(bson.D{{"_id", "id3"}}).SetUpdate(bson.D{{"$set", bson.D{{"_id", 3.14159}}}}), + }, + { + "replace one", + mongo.NewReplaceOneModel().SetFilter(bson.D{{"_id", "id3"}}).SetReplacement(bson.D{{"_id", 3.14159}}), + }, + } + + testCases := []struct { + name string + opts *options.BulkWriteOptionsBuilder + expected bson.RawValue + }{ + { + "empty", + options.BulkWrite(), + bson.RawValue{}, + }, + { + "false", + newOpts(bson.D{{"bypassEmptyTsReplacement", false}}), + marshalValue(false), + }, + { + "true", + newOpts(bson.D{{"bypassEmptyTsReplacement", true}}), + marshalValue(true), + }, + } + for _, m := range models { + for _, tc := range testCases { + mt.Run(fmt.Sprintf("%s %s", m.name, tc.name), func(mt *mtest.T) { + _, err := mt.Coll.BulkWrite(context.Background(), []mongo.WriteModel{m.model}, tc.opts) + require.NoError(mt, err, "BulkWrite error: %v", err) + evt := mt.GetStartedEvent() + val := evt.Command.Lookup("bypassEmptyTsReplacement") + assert.Equal(mt, tc.expected, val, "expected bypassEmptyTsReplacement to be %s", tc.expected.String()) + }) + } + } + }) +} + func initCollection(tb testing.TB, coll *mongo.Collection) { tb.Helper() diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 036e3badb1..ca683c1dc7 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -40,6 +40,7 @@ type bulkWrite struct { result BulkWriteResult let any rawData *bool + additionalCmd bson.D } func (bw *bulkWrite) execute(ctx context.Context) error { @@ -213,6 +214,9 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera if bw.rawData != nil { op.RawData(*bw.rawData) } + if len(bw.additionalCmd) > 0 { + op.AdditionalCmd(bw.additionalCmd) + } err := op.Execute(ctx) @@ -427,6 +431,9 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera if bw.rawData != nil { op.RawData(*bw.rawData) } + if len(bw.additionalCmd) > 0 { + op.AdditionalCmd(bw.additionalCmd) + } err := op.Execute(ctx) diff --git a/mongo/client.go b/mongo/client.go index f0480a0c72..3ccdf8b792 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -955,6 +955,11 @@ func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite, op.rawData = &rawData } } + if additionalCmd := optionsutil.Value(bwo.Internal, "addCommandFields"); additionalCmd != nil { + if ac, ok := additionalCmd.(bson.D); ok { + op.additionalCmd = ac + } + } if bwo.VerboseResults == nil || !(*bwo.VerboseResults) { op.errorsOnly = true } else if !acknowledged { diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 27c3ad3ce4..cb9d8cb4d8 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -45,6 +45,7 @@ type clientBulkWrite struct { selector description.ServerSelector writeConcern *writeconcern.WriteConcern rawData *bool + additionalCmd bson.D result ClientBulkWriteResult } @@ -148,6 +149,13 @@ func (bw *clientBulkWrite) newCommand() func([]byte, description.SelectedServer) if bw.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) { dst = bsoncore.AppendBooleanElement(dst, "rawData", *bw.rawData) } + if len(bw.additionalCmd) > 0 { + doc, err := bson.Marshal(bw.additionalCmd) + if err != nil { + return nil, err + } + dst = append(dst, doc[4:len(doc)-1]...) + } return dst, nil } } diff --git a/mongo/collection.go b/mongo/collection.go index ef4188d67b..36f8bd307a 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -251,6 +251,11 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel, op.rawData = &rawData } } + if rawDataOpt := optionsutil.Value(args.Internal, "addCommandFields"); rawDataOpt != nil { + if d, ok := rawDataOpt.(bson.D); ok { + op.additionalCmd = d + } + } err = op.execute(ctx) @@ -335,6 +340,11 @@ func (coll *Collection) insert( op = op.RawData(rawData) } } + if rawDataOpt := optionsutil.Value(args.Internal, "addCommandFields"); rawDataOpt != nil { + if d, ok := rawDataOpt.(bson.D); ok { + op = op.AdditionalCmd(d) + } + } retry := driver.RetryNone if coll.client.retryWrites { retry = driver.RetryOncePerCommand @@ -388,7 +398,14 @@ func (coll *Collection) InsertOne(ctx context.Context, document any, } if rawDataOpt := optionsutil.Value(args.Internal, "rawData"); rawDataOpt != nil { imOpts.Opts = append(imOpts.Opts, func(opts *options.InsertManyOptions) error { - optionsutil.WithValue(opts.Internal, "rawData", rawDataOpt) + opts.Internal = optionsutil.WithValue(opts.Internal, "rawData", rawDataOpt) + + return nil + }) + } + if rawDataOpt := optionsutil.Value(args.Internal, "addCommandFields"); rawDataOpt != nil { + imOpts.Opts = append(imOpts.Opts, func(opts *options.InsertManyOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, "addCommandFields", rawDataOpt) return nil }) @@ -710,6 +727,11 @@ func (coll *Collection) updateOrReplace( op = op.RawData(rawData) } } + if rawDataOpt := optionsutil.Value(args.Internal, "addCommandFields"); rawDataOpt != nil { + if d, ok := rawDataOpt.(bson.D); ok { + op = op.AdditionalCmd(d) + } + } retry := driver.RetryNone // retryable writes are only enabled updateOne/replaceOne operations if !multi && coll.client.retryWrites { @@ -1864,6 +1886,11 @@ func (coll *Collection) FindOneAndReplace( op = op.RawData(rawData) } } + if rawDataOpt := optionsutil.Value(args.Internal, "addCommandFields"); rawDataOpt != nil { + if d, ok := rawDataOpt.(bson.D); ok { + op = op.AdditionalCmd(d) + } + } return coll.findAndModify(ctx, op) } @@ -1978,6 +2005,11 @@ func (coll *Collection) FindOneAndUpdate( op = op.RawData(rawData) } } + if rawDataOpt := optionsutil.Value(args.Internal, "addCommandFields"); rawDataOpt != nil { + if d, ok := rawDataOpt.(bson.D); ok { + op = op.AdditionalCmd(d) + } + } return coll.findAndModify(ctx, op) } diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 0920568110..0b3da9c4dc 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -51,6 +51,7 @@ type FindAndModify struct { let bsoncore.Document timeout *time.Duration rawData *bool + additionalCmd bson.D result FindAndModifyResult } @@ -216,6 +217,13 @@ func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ( if fam.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) { dst = bsoncore.AppendBooleanElement(dst, "rawData", *fam.rawData) } + if len(fam.additionalCmd) > 0 { + doc, err := bson.Marshal(fam.additionalCmd) + if err != nil { + return nil, err + } + dst = append(dst, doc[4:len(doc)-1]...) + } return dst, nil } @@ -491,3 +499,13 @@ func (fam *FindAndModify) RawData(rawData bool) *FindAndModify { fam.rawData = &rawData return fam } + +// AdditionalCmd sets additional command fields to be attached. +func (fam *FindAndModify) AdditionalCmd(d bson.D) *FindAndModify { + if fam == nil { + fam = new(FindAndModify) + } + + fam.additionalCmd = d + return fam +} diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index 57d461ae3b..d4f01e6b92 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -12,6 +12,7 @@ import ( "fmt" "time" + "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/driverutil" "go.mongodb.org/mongo-driver/v2/internal/logger" @@ -43,6 +44,7 @@ type Insert struct { serverAPI *driver.ServerAPIOptions timeout *time.Duration rawData *bool + additionalCmd bson.D logger *logger.Logger } @@ -137,6 +139,13 @@ func (i *Insert) command(dst []byte, desc description.SelectedServer) ([]byte, e if i.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) { dst = bsoncore.AppendBooleanElement(dst, "rawData", *i.rawData) } + if len(i.additionalCmd) > 0 { + doc, err := bson.Marshal(i.additionalCmd) + if err != nil { + return nil, err + } + dst = append(dst, doc[4:len(doc)-1]...) + } return dst, nil } @@ -333,3 +342,13 @@ func (i *Insert) RawData(rawData bool) *Insert { i.rawData = &rawData return i } + +// AdditionalCmd sets additional command fields to be attached. +func (i *Insert) AdditionalCmd(d bson.D) *Insert { + if i == nil { + i = new(Insert) + } + + i.additionalCmd = d + return i +} diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index 00e193ef49..07848a54fa 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -47,6 +47,7 @@ type Update struct { let bsoncore.Document timeout *time.Duration rawData *bool + additionalCmd bson.D logger *logger.Logger } @@ -208,6 +209,13 @@ func (u *Update) command(dst []byte, desc description.SelectedServer) ([]byte, e if u.rawData != nil && desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 27) { dst = bsoncore.AppendBooleanElement(dst, "rawData", *u.rawData) } + if len(u.additionalCmd) > 0 { + doc, err := bson.Marshal(u.additionalCmd) + if err != nil { + return nil, err + } + dst = append(dst, doc[4:len(doc)-1]...) + } return dst, nil } @@ -437,3 +445,13 @@ func (u *Update) RawData(rawData bool) *Update { u.rawData = &rawData return u } + +// AdditionalCmd sets additional command fields to be attached. +func (u *Update) AdditionalCmd(d bson.D) *Update { + if u == nil { + u = new(Update) + } + + u.additionalCmd = d + return u +} diff --git a/x/mongo/driver/xoptions/options.go b/x/mongo/driver/xoptions/options.go index fa11dd60b8..c3a0938ebe 100644 --- a/x/mongo/driver/xoptions/options.go +++ b/x/mongo/driver/xoptions/options.go @@ -9,6 +9,7 @@ package xoptions import ( "fmt" + "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" @@ -80,6 +81,15 @@ func SetInternalBulkWriteOptions(a *options.BulkWriteOptionsBuilder, key string, opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.BulkWriteOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -101,6 +111,15 @@ func SetInternalClientBulkWriteOptions(a *options.ClientBulkWriteOptionsBuilder, opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.ClientBulkWriteOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -332,6 +351,15 @@ func SetInternalFindOneAndReplaceOptions(a *options.FindOneAndReplaceOptionsBuil opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.FindOneAndReplaceOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -353,6 +381,15 @@ func SetInternalFindOneAndUpdateOptions(a *options.FindOneAndUpdateOptionsBuilde opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.FindOneAndUpdateOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -374,6 +411,15 @@ func SetInternalInsertManyOptions(a *options.InsertManyOptionsBuilder, key strin opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.InsertManyOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -395,6 +441,15 @@ func SetInternalInsertOneOptions(a *options.InsertOneOptionsBuilder, key string, opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.InsertOneOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -458,6 +513,15 @@ func SetInternalReplaceOptions(a *options.ReplaceOptionsBuilder, key string, opt opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.ReplaceOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -479,6 +543,15 @@ func SetInternalUpdateManyOptions(a *options.UpdateManyOptionsBuilder, key strin opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.UpdateManyOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) } @@ -500,6 +573,15 @@ func SetInternalUpdateOneOptions(a *options.UpdateOneOptionsBuilder, key string, opts.Internal = optionsutil.WithValue(opts.Internal, key, b) return nil }) + case "addCommandFields": + d, ok := option.(bson.D) + if !ok { + return typeErrFunc("bson.D") + } + a.Opts = append(a.Opts, func(opts *options.UpdateOneOptions) error { + opts.Internal = optionsutil.WithValue(opts.Internal, key, d) + return nil + }) default: return fmt.Errorf("unsupported option: %q", key) }